Detailed changes
@@ -39,6 +39,7 @@ tokio-rustls = { version = "0.26", optional = true }
[dev-dependencies]
env_logger = { version = "0.11", default-features = false, features = ["auto-color", "humantime"] }
# this is needed for the echo-component example
+tokio = { version = "1", features = ["test-util"] }
tokio-xmpp = { path = ".", features = ["insecure-tcp"]}
[features]
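The `test-util` feature is what allows the new soft-timeout test further down to use `#[tokio::test(start_paused = true)]`: the runtime then starts with a paused clock and auto-advances virtual time to the next pending timer, so the multi-minute timeouts elapse instantly. A minimal sketch of that mechanism, independent of this crate:

```rust
use std::time::Duration;

// Requires tokio with the "test-util" feature (plus "macros" and "time").
#[tokio::test(start_paused = true)]
async fn virtual_time_auto_advances() {
    let start = tokio::time::Instant::now();
    // With the clock paused, this does not block for five real minutes:
    // tokio jumps virtual time forward to the sleep's deadline.
    tokio::time::sleep(Duration::from_secs(300)).await;
    assert!(start.elapsed() >= Duration::from_secs(300));
}
```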
@@ -31,9 +31,7 @@ async fn main() {
// If you don't need a custom server but the default localhost:5347, you can
// use Component::new() directly
- let mut component = Component::new_plaintext(jid, password, server)
- .await
- .unwrap();
+ let mut component = Component::new(jid, password).await.unwrap();
// Make the two interfaces for sending and receiving independent
// of each other so we can move one into a closure.
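Callers that do need a custom server address keep using `new_plaintext`, which now takes the timeout configuration as an explicit fourth argument. A hedged sketch of that variant, with `jid` and `password` as in the example above:

```rust
// Variant of the example above for a custom server address; Timeouts is
// re-exported from tokio_xmpp::xmlstream.
let server = DnsConfig::addr("127.0.0.1:5347");
let mut component = Component::new_plaintext(jid, password, server, Timeouts::tight())
    .await
    .unwrap();
```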
@@ -2,7 +2,7 @@ use futures::{SinkExt, StreamExt};
use tokio::{self, io, net::TcpSocket};
use tokio_xmpp::parsers::stream_features::StreamFeatures;
-use tokio_xmpp::xmlstream::{accept_stream, StreamHeader};
+use tokio_xmpp::xmlstream::{accept_stream, StreamHeader, Timeouts};
#[tokio::main]
async fn main() -> Result<(), io::Error> {
@@ -19,6 +19,7 @@ async fn main() -> Result<(), io::Error> {
let stream = accept_stream(
tokio::io::BufStream::new(stream),
tokio_xmpp::parsers::ns::DEFAULT_NS,
+ Timeouts::default(),
)
.await?;
let stream = stream.send_header(StreamHeader::default()).await?;
@@ -19,7 +19,9 @@ use crate::{
client::bind::bind,
connect::ServerConnector,
error::{AuthError, Error, ProtocolError},
- xmlstream::{xmpp::XmppStreamElement, InitiatingStream, ReadError, StreamHeader, XmppStream},
+ xmlstream::{
+ xmpp::XmppStreamElement, InitiatingStream, ReadError, StreamHeader, Timeouts, XmppStream,
+ },
};
pub async fn auth<S: AsyncBufRead + AsyncWrite + Unpin>(
@@ -107,11 +109,12 @@ pub async fn client_login<C: ServerConnector>(
server: C,
jid: Jid,
password: String,
+ timeouts: Timeouts,
) -> Result<(Option<FullJid>, StreamFeatures, XmppStream<C::Stream>), Error> {
let username = jid.node().unwrap().as_str();
let password = password;
- let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT).await?;
+ let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?;
let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
let channel_binding = C::channel_binding(xmpp_stream.get_stream())?;
@@ -5,6 +5,7 @@ use crate::{
client::{login::client_login, stream::ClientState},
connect::ServerConnector,
error::Error,
+ xmlstream::Timeouts,
Stanza,
};
@@ -30,6 +31,7 @@ pub struct Client<C: ServerConnector> {
password: String,
connector: C,
state: ClientState<C::Stream>,
+ timeouts: Timeouts,
reconnect: bool,
// TODO: tls_required=true
}
@@ -95,6 +97,7 @@ impl Client<StartTlsServerConnector> {
jid.clone(),
password,
DnsConfig::srv(&jid.domain().to_string(), "_xmpp-client._tcp", 5222),
+ Timeouts::default(),
);
client.set_reconnect(true);
client
@@ -105,8 +108,14 @@ impl Client<StartTlsServerConnector> {
jid: J,
password: P,
dns_config: DnsConfig,
+ timeouts: Timeouts,
) -> Self {
- Self::new_with_connector(jid, password, StartTlsServerConnector::from(dns_config))
+ Self::new_with_connector(
+ jid,
+ password,
+ StartTlsServerConnector::from(dns_config),
+ timeouts,
+ )
}
}
@@ -117,8 +126,14 @@ impl Client<TcpServerConnector> {
jid: J,
password: P,
dns_config: DnsConfig,
+ timeouts: Timeouts,
) -> Self {
- Self::new_with_connector(jid, password, TcpServerConnector::from(dns_config))
+ Self::new_with_connector(
+ jid,
+ password,
+ TcpServerConnector::from(dns_config),
+ timeouts,
+ )
}
}
@@ -128,6 +143,7 @@ impl<C: ServerConnector> Client<C> {
jid: J,
password: P,
connector: C,
+ timeouts: Timeouts,
) -> Self {
let jid = jid.into();
let password = password.into();
@@ -136,6 +152,7 @@ impl<C: ServerConnector> Client<C> {
connector.clone(),
jid.clone(),
password.clone(),
+ timeouts,
));
let client = Client {
jid,
@@ -143,6 +160,7 @@ impl<C: ServerConnector> Client<C> {
connector,
state: ClientState::Connecting(connect),
reconnect: false,
+ timeouts,
};
client
}
@@ -56,6 +56,7 @@ impl<C: ServerConnector> Stream for Client<C> {
self.connector.clone(),
self.jid.clone(),
self.password.clone(),
+ self.timeouts,
));
self.state = ClientState::Connecting(connect);
self.poll_next(cx)
@@ -6,16 +6,17 @@ use xmpp_parsers::{component::Handshake, jid::Jid, ns};
use crate::component::ServerConnector;
use crate::error::{AuthError, Error};
-use crate::xmlstream::{ReadError, XmppStream, XmppStreamElement};
+use crate::xmlstream::{ReadError, Timeouts, XmppStream, XmppStreamElement};
/// Log into an XMPP server as a component with a jid+pass
pub async fn component_login<C: ServerConnector>(
connector: C,
jid: Jid,
password: String,
+ timeouts: Timeouts,
) -> Result<XmppStream<C::Stream>, Error> {
let password = password;
- let mut stream = connector.connect(&jid, ns::COMPONENT).await?;
+ let mut stream = connector.connect(&jid, ns::COMPONENT, timeouts).await?;
let header = stream.take_header();
let mut stream = stream.skip_features();
let stream_id = match header.id {
@@ -6,8 +6,10 @@ use std::str::FromStr;
use xmpp_parsers::jid::Jid;
use crate::{
- component::login::component_login, connect::ServerConnector, xmlstream::XmppStream, Error,
- Stanza,
+ component::login::component_login,
+ connect::ServerConnector,
+ xmlstream::{Timeouts, XmppStream},
+ Error, Stanza,
};
#[cfg(any(feature = "starttls", feature = "insecure-tcp"))]
@@ -46,7 +48,13 @@ impl Component<TcpServerConnector> {
/// Start a new XMPP component over plaintext TCP to localhost:5347
#[cfg(feature = "insecure-tcp")]
pub async fn new(jid: &str, password: &str) -> Result<Self, Error> {
- Self::new_plaintext(jid, password, DnsConfig::addr("127.0.0.1:5347")).await
+ Self::new_plaintext(
+ jid,
+ password,
+ DnsConfig::addr("127.0.0.1:5347"),
+ Timeouts::tight(),
+ )
+ .await
}
/// Start a new XMPP component over plaintext TCP
@@ -55,8 +63,15 @@ impl Component<TcpServerConnector> {
jid: &str,
password: &str,
dns_config: DnsConfig,
+ timeouts: Timeouts,
) -> Result<Self, Error> {
- Component::new_with_connector(jid, password, TcpServerConnector::from(dns_config)).await
+ Component::new_with_connector(
+ jid,
+ password,
+ TcpServerConnector::from(dns_config),
+ timeouts,
+ )
+ .await
}
}
@@ -69,10 +84,11 @@ impl<C: ServerConnector> Component<C> {
jid: &str,
password: &str,
connector: C,
+ timeouts: Timeouts,
) -> Result<Self, Error> {
let jid = Jid::from_str(jid)?;
let password = password.to_owned();
- let stream = component_login(connector, jid.clone(), password).await?;
+ let stream = component_login(connector, jid.clone(), password, timeouts).await?;
Ok(Component { jid, stream })
}
}
@@ -4,7 +4,7 @@ use sasl::common::ChannelBinding;
use tokio::io::{AsyncBufRead, AsyncWrite};
use xmpp_parsers::jid::Jid;
-use crate::xmlstream::PendingFeaturesRecv;
+use crate::xmlstream::{PendingFeaturesRecv, Timeouts};
use crate::Error;
#[cfg(feature = "starttls")]
@@ -36,6 +36,7 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
&self,
jid: &Jid,
ns: &'static str,
+ timeouts: Timeouts,
) -> impl std::future::Future<Output = Result<PendingFeaturesRecv<Self::Stream>, Error>> + Send;
/// Return channel binding data if available
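Implementations and callers alike simply thread the value through to stream setup. A hedged sketch of invoking the updated trait method; the wrapper function is illustrative only, while the paths and the `connect` signature come from this patch:

```rust
use xmpp_parsers::jid::Jid;

use crate::{
    connect::ServerConnector,
    xmlstream::{PendingFeaturesRecv, Timeouts},
    Error,
};

// Illustrative wrapper: open a stream through any ServerConnector with
// explicit timeouts, mirroring the client_login change above.
async fn open_stream<C: ServerConnector>(
    connector: &C,
    jid: &Jid,
    ns: &'static str,
    timeouts: Timeouts,
) -> Result<PendingFeaturesRecv<C::Stream>, Error> {
    connector.connect(jid, ns, timeouts).await
}
```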
@@ -44,7 +44,7 @@ use crate::{
connect::{DnsConfig, ServerConnector, ServerConnectorError},
error::{Error, ProtocolError},
xmlstream::{
- initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, XmppStream,
+ initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream,
XmppStreamElement,
},
Client,
@@ -70,6 +70,7 @@ impl ServerConnector for StartTlsServerConnector {
&self,
jid: &Jid,
ns: &'static str,
+ timeouts: Timeouts,
) -> Result<PendingFeaturesRecv<Self::Stream>, Error> {
let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
@@ -82,6 +83,7 @@ impl ServerConnector for StartTlsServerConnector {
from: None,
id: None,
},
+ timeouts,
)
.await?;
let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
@@ -98,6 +100,7 @@ impl ServerConnector for StartTlsServerConnector {
from: None,
id: None,
},
+ timeouts,
)
.await?)
} else {
@@ -6,7 +6,7 @@ use tokio::{io::BufStream, net::TcpStream};
use crate::{
connect::{DnsConfig, ServerConnector},
- xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader},
+ xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts},
Client, Component, Error,
};
@@ -35,6 +35,7 @@ impl ServerConnector for TcpServerConnector {
&self,
jid: &xmpp_parsers::jid::Jid,
ns: &'static str,
+ timeouts: Timeouts,
) -> Result<PendingFeaturesRecv<Self::Stream>, Error> {
let stream = BufStream::new(self.0.resolve().await?);
Ok(initiate_stream(
@@ -45,6 +46,7 @@ impl ServerConnector for TcpServerConnector {
from: None,
id: None,
},
+ timeouts,
)
.await?)
}
@@ -7,6 +7,7 @@
use core::future::Future;
use core::pin::Pin;
use core::task::{Context, Poll};
+use core::time::Duration;
use std::borrow::Cow;
use std::io;
@@ -14,7 +15,10 @@ use futures::{ready, Sink, SinkExt, Stream, StreamExt};
use bytes::{Buf, BytesMut};
-use tokio::io::{AsyncBufRead, AsyncWrite};
+use tokio::{
+ io::{AsyncBufRead, AsyncWrite},
+ time::Instant,
+};
use xso::{
exports::rxml::{self, writer::TrackNamespace, xml_ncname, Event, Namespace},
@@ -25,6 +29,129 @@ use super::capture::{log_enabled, log_recv, log_send, CaptureBufRead};
use xmpp_parsers::ns::STREAM as XML_STREAM_NS;
+/// Configuration for timeouts on an XML stream.
+///
+/// The defaults are tuned toward common desktop/laptop use and may not hold
+/// up under extreme conditions (arctic satellite link, mobile internet on a
+/// train in Brandenburg, Germany, and similar) and may be inefficient in
+/// other conditions (stable server link, localhost communication).
+#[derive(Debug, Clone, Copy)]
+pub struct Timeouts {
+ /// Maximum silence time before a
+ /// [`ReadError::SoftTimeout`][`super::ReadError::SoftTimeout`] is
+ /// returned.
+ ///
+ /// Soft timeouts are not fatal, but they must be handled by user code so
+ /// that more data is read after at most [`Self::response_timeout`],
+ /// starting from the moment the soft timeout is returned.
+ pub read_timeout: Duration,
+
+ /// Maximum silence after a soft timeout.
+ ///
+ /// If the stream is silent for longer than this time after a soft timeout
+ /// has been emitted, a hard [`TimedOut`][`std::io::ErrorKind::TimedOut`]
+ /// I/O error is returned and the stream is to be considered dead.
+ pub response_timeout: Duration,
+}
+
+impl Default for Timeouts {
+ fn default() -> Self {
+ Self {
+ read_timeout: Duration::new(300, 0),
+ response_timeout: Duration::new(300, 0),
+ }
+ }
+}
+
+impl Timeouts {
+ /// Tight timeouts suitable for communicating on a fast LAN or localhost.
+ pub fn tight() -> Self {
+ Self {
+ read_timeout: Duration::new(60, 0),
+ response_timeout: Duration::new(15, 0),
+ }
+ }
+
+ fn data_to_soft(&self) -> Duration {
+ self.read_timeout
+ }
+
+ fn soft_to_warn(&self) -> Duration {
+ self.response_timeout / 2
+ }
+
+ fn warn_to_hard(&self) -> Duration {
+ self.response_timeout / 2
+ }
+}
+
+#[derive(Clone, Copy)]
+enum TimeoutLevel {
+ Soft,
+ Warn,
+ Hard,
+}
+
+#[derive(Debug)]
+pub(super) enum RawError {
+ Io(io::Error),
+ SoftTimeout,
+}
+
+impl From<io::Error> for RawError {
+ fn from(other: io::Error) -> Self {
+ Self::Io(other)
+ }
+}
+
+struct TimeoutState {
+ /// Configuration for the timeouts.
+ timeouts: Timeouts,
+
+ /// Level of the next timeout which will trip.
+ level: TimeoutLevel,
+
+ /// Sleep timer used for read timeouts.
+ // NOTE: even though we pretend we could deal with an !Unpin
+ // RawXmlStream, we really can't: box_stream for example needs it,
+ // but also all the typestate around the initial stream setup needs
+ // to be able to move the stream around.
+ deadline: Pin<Box<tokio::time::Sleep>>,
+}
+
+impl TimeoutState {
+ fn new(timeouts: Timeouts) -> Self {
+ Self {
+ deadline: Box::pin(tokio::time::sleep(timeouts.data_to_soft())),
+ level: TimeoutLevel::Soft,
+ timeouts,
+ }
+ }
+
+ fn poll(&mut self, cx: &mut Context) -> Poll<TimeoutLevel> {
+ ready!(self.deadline.as_mut().poll(cx));
+ // Deadline elapsed!
+ let to_return = self.level;
+ let (next_level, next_duration) = match self.level {
+ TimeoutLevel::Soft => (TimeoutLevel::Warn, self.timeouts.soft_to_warn()),
+ TimeoutLevel::Warn => (TimeoutLevel::Hard, self.timeouts.warn_to_hard()),
+ // Something short-ish so that we fire this over and over until
+ // someone finally kills the stream for good.
+ TimeoutLevel::Hard => (TimeoutLevel::Hard, Duration::new(1, 0)),
+ };
+ self.level = next_level;
+ self.deadline.as_mut().reset(Instant::now() + next_duration);
+ Poll::Ready(to_return)
+ }
+
+ fn reset(&mut self) {
+ self.level = TimeoutLevel::Soft;
+ self.deadline
+ .as_mut()
+ .reset((Instant::now() + self.timeouts.data_to_soft()).into());
+ }
+}
+
pin_project_lite::pin_project! {
// NOTE: due to limitations of pin_project_lite, the field comments are
// no doc comments. Luckily, this struct is only `pub(super)` anyway.
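Taken together, the state machine escalates in three phases: after `read_timeout` of silence one `SoftTimeout` is emitted, half of `response_timeout` later the internal warning level trips, and after the remaining half the stream yields a hard `io::ErrorKind::TimedOut`. A hedged sketch of choosing custom values downstream; the numbers are purely illustrative, not a recommendation from this patch:

```rust
use core::time::Duration;

use tokio_xmpp::xmlstream::Timeouts;

fn high_latency_timeouts() -> Timeouts {
    // Illustrative values: tolerate ten minutes of silence, then give the
    // peer one minute to respond to whatever the SoftTimeout triggered.
    // Escalation: SoftTimeout after 600 s, internal warning at 630 s,
    // hard io::ErrorKind::TimedOut at 660 s of total silence.
    Timeouts {
        read_timeout: Duration::new(600, 0),
        response_timeout: Duration::new(60, 0),
    }
}
```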
@@ -37,6 +164,8 @@ pin_project_lite::pin_project! {
// The writer used for serialising data.
writer: rxml::writer::Encoder<rxml::writer::SimpleNamespaces>,
+ timeouts: TimeoutState,
+
// The default namespace to declare on the stream header.
stream_ns: &'static str,
@@ -112,7 +241,7 @@ impl<Io: AsyncBufRead + AsyncWrite> RawXmlStream<Io> {
writer
}
- pub(super) fn new(io: Io, stream_ns: &'static str) -> Self {
+ pub(super) fn new(io: Io, stream_ns: &'static str, timeouts: Timeouts) -> Self {
let parser = rxml::Parser::default();
let mut io = CaptureBufRead::wrap(io);
if log_enabled() {
@@ -121,6 +250,7 @@ impl<Io: AsyncBufRead + AsyncWrite> RawXmlStream<Io> {
Self {
parser: rxml::AsyncReader::wrap(io, parser),
writer: Self::new_writer(stream_ns),
+ timeouts: TimeoutState::new(timeouts),
tx_buffer_logged: 0,
stream_ns,
tx_buffer: BytesMut::new(),
@@ -189,18 +319,34 @@ impl<Io> RawXmlStream<Io> {
}
impl<Io: AsyncBufRead> Stream for RawXmlStream<Io> {
- type Item = Result<rxml::Event, io::Error>;
+ type Item = Result<rxml::Event, RawError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();
loop {
- return Poll::Ready(
- match ready!(this.parser.as_mut().poll_read(cx)).transpose() {
- // Skip the XML declaration, nobody wants to hear about that.
- Some(Ok(rxml::Event::XmlDeclaration(_, _))) => continue,
- other => other,
- },
- );
+ match this.parser.as_mut().poll_read(cx) {
+ Poll::Pending => (),
+ Poll::Ready(v) => {
+ this.timeouts.reset();
+ match v.transpose() {
+ // Skip the XML declaration, nobody wants to hear about that.
+ Some(Ok(rxml::Event::XmlDeclaration(_, _))) => continue,
+ other => return Poll::Ready(other.map(|x| x.map_err(RawError::Io))),
+ }
+ }
+ };
+
+ // poll_read returned pending... what do the timeouts have to say?
+ match ready!(this.timeouts.poll(cx)) {
+ TimeoutLevel::Soft => return Poll::Ready(Some(Err(RawError::SoftTimeout))),
+ TimeoutLevel::Warn => (),
+ TimeoutLevel::Hard => {
+ return Poll::Ready(Some(Err(RawError::Io(io::Error::new(
+ io::ErrorKind::TimedOut,
+ "read and response timeouts elapsed",
+ )))))
+ }
+ }
}
}
}
@@ -312,6 +458,20 @@ pub(super) enum ReadXsoError {
/// not well-formed document.
Hard(io::Error),
+ /// The underlying stream signalled a soft read timeout before a child
+ /// element could be read.
+ ///
+ /// Note that soft timeouts which are triggered in the middle of receiving
+ /// an element are converted to hard timeouts (i.e. I/O errors).
+ ///
+ /// This masking is intentional, because:
+ /// - Returning a [`Self::SoftTimeout`] from the middle of parsing is not
+ /// possible without complicating the API.
+ /// - There is no reason why the remote side should interrupt sending data
+ /// in the middle of an element except if it or the transport has failed
+ /// fatally.
+ SoftTimeout,
+
/// A parse error occurred.
///
/// The XML structure was well-formed, but the data contained did not
@@ -324,19 +484,6 @@ pub(super) enum ReadXsoError {
Parse(xso::error::Error),
}
-impl From<ReadXsoError> for io::Error {
- fn from(other: ReadXsoError) -> Self {
- match other {
- ReadXsoError::Hard(v) => v,
- ReadXsoError::Parse(e) => io::Error::new(io::ErrorKind::InvalidData, e),
- ReadXsoError::Footer => io::Error::new(
- io::ErrorKind::UnexpectedEof,
- "element footer while waiting for XSO element start",
- ),
- }
- }
-}
-
impl From<io::Error> for ReadXsoError {
fn from(other: io::Error) -> Self {
Self::Hard(other)
@@ -425,13 +572,13 @@ impl<T: FromXml> ReadXsoState<T> {
.parser_pinned()
.set_text_buffering(text_buffering);
- let ev = ready!(source.as_mut().poll_next(cx)).transpose()?;
+ let ev = ready!(source.as_mut().poll_next(cx)).transpose();
match self {
ReadXsoState::PreData => {
log::trace!("ReadXsoState::PreData ev = {:?}", ev);
match ev {
- Some(rxml::Event::XmlDeclaration(_, _)) => (),
- Some(rxml::Event::Text(_, data)) => {
+ Ok(Some(rxml::Event::XmlDeclaration(_, _))) => (),
+ Ok(Some(rxml::Event::Text(_, data))) => {
if xso::is_xml_whitespace(data.as_bytes()) {
log::trace!("Received {} bytes of whitespace", data.len());
source.as_mut().stream_pinned().discard_capture();
@@ -445,18 +592,18 @@ impl<T: FromXml> ReadXsoState<T> {
.into()));
}
}
- Some(rxml::Event::StartElement(_, name, attrs)) => {
+ Ok(Some(rxml::Event::StartElement(_, name, attrs))) => {
*self = ReadXsoState::Parsing(
<Result<T, xso::error::Error> as FromXml>::from_events(name, attrs)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?,
);
}
// Amounts to EOF, as we expect to start on the stream level.
- Some(rxml::Event::EndElement(_)) => {
+ Ok(Some(rxml::Event::EndElement(_))) => {
*self = ReadXsoState::Done;
return Poll::Ready(Err(ReadXsoError::Footer));
}
- None => {
+ Ok(None) => {
*self = ReadXsoState::Done;
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -464,17 +611,42 @@ impl<T: FromXml> ReadXsoState<T> {
)
.into()));
}
+ Err(RawError::SoftTimeout) => {
+ *self = ReadXsoState::Done;
+ return Poll::Ready(Err(ReadXsoError::SoftTimeout));
+ }
+ Err(RawError::Io(e)) => {
+ *self = ReadXsoState::Done;
+ return Poll::Ready(Err(ReadXsoError::Hard(e)));
+ }
}
}
ReadXsoState::Parsing(builder) => {
log::trace!("ReadXsoState::Parsing ev = {:?}", ev);
- let Some(ev) = ev else {
- *self = ReadXsoState::Done;
- return Poll::Ready(Err(io::Error::new(
- io::ErrorKind::UnexpectedEof,
- "eof during XSO parsing",
- )
- .into()));
+ let ev = match ev {
+ Ok(Some(ev)) => ev,
+ Ok(None) => {
+ *self = ReadXsoState::Done;
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::UnexpectedEof,
+ "eof during XSO parsing",
+ )
+ .into()));
+ }
+ Err(RawError::Io(e)) => {
+ *self = ReadXsoState::Done;
+ return Poll::Ready(Err(e.into()));
+ }
+ Err(RawError::SoftTimeout) => {
+ // See also [`ReadXsoError::SoftTimeout`] for why
+ // we mask the SoftTimeout condition here.
+ *self = ReadXsoState::Done;
+ return Poll::Ready(Err(io::Error::new(
+ io::ErrorKind::TimedOut,
+ "read timeout during XSO parsing",
+ )
+ .into()));
+ }
};
match builder.feed(ev) {
@@ -622,8 +794,10 @@ impl StreamHeader<'static> {
mut stream: Pin<&mut RawXmlStream<Io>>,
) -> io::Result<Self> {
loop {
- match stream.as_mut().next().await.transpose()? {
- Some(Event::StartElement(_, (ns, name), mut attrs)) => {
+ match stream.as_mut().next().await {
+ Some(Err(RawError::Io(e))) => return Err(e),
+ Some(Err(RawError::SoftTimeout)) => (),
+ Some(Ok(Event::StartElement(_, (ns, name), mut attrs))) => {
if ns != XML_STREAM_NS || name != "stream" {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
@@ -666,7 +840,7 @@ impl StreamHeader<'static> {
id: id.map(Cow::Owned),
});
}
- Some(Event::Text(_, _)) | Some(Event::EndElement(_)) => {
+ Some(Ok(Event::Text(_, _))) | Some(Ok(Event::EndElement(_))) => {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected content before stream header",
@@ -674,7 +848,7 @@ impl StreamHeader<'static> {
}
// We cannot loop infinitely here because the XML parser will
// prevent more than one XML declaration from being parsed.
- Some(Event::XmlDeclaration(_, _)) => (),
+ Some(Ok(Event::XmlDeclaration(_, _))) => (),
None => {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
@@ -17,7 +17,7 @@ use xmpp_parsers::stream_features::StreamFeatures;
use xso::{AsXml, FromXml};
use super::{
- common::{RawXmlStream, ReadXso, StreamHeader},
+ common::{RawXmlStream, ReadXso, ReadXsoError, StreamHeader},
XmlStream,
};
@@ -80,7 +80,22 @@ impl<Io: AsyncBufRead + AsyncWrite + Unpin> PendingFeaturesRecv<Io> {
mut stream,
header: _,
} = self;
- let features = ReadXso::read_from(Pin::new(&mut stream)).await?;
+ let features = loop {
+ match ReadXso::read_from(Pin::new(&mut stream)).await {
+ Ok(v) => break v,
+ Err(ReadXsoError::SoftTimeout) => (),
+ Err(ReadXsoError::Hard(e)) => return Err(e),
+ Err(ReadXsoError::Parse(e)) => {
+ return Err(io::Error::new(io::ErrorKind::InvalidData, e))
+ }
+ Err(ReadXsoError::Footer) => {
+ return Err(io::Error::new(
+ io::ErrorKind::UnexpectedEof,
+ "unexpected stream footer",
+ ))
+ }
+ }
+ };
Ok((features, XmlStream::wrap(stream)))
}
@@ -77,8 +77,8 @@ mod responder;
mod tests;
pub(crate) mod xmpp;
-pub use self::common::StreamHeader;
-use self::common::{RawXmlStream, ReadXsoError, ReadXsoState};
+use self::common::{RawError, RawXmlStream, ReadXsoError, ReadXsoState};
+pub use self::common::{StreamHeader, Timeouts};
pub use self::initiator::{InitiatingStream, PendingFeaturesRecv};
pub use self::responder::{AcceptedStream, PendingFeaturesSend};
pub use self::xmpp::XmppStreamElement;
@@ -129,8 +129,9 @@ pub async fn initiate_stream<Io: AsyncBufRead + AsyncWrite + Unpin>(
io: Io,
stream_ns: &'static str,
stream_header: StreamHeader<'_>,
+ timeouts: Timeouts,
) -> Result<PendingFeaturesRecv<Io>, io::Error> {
- let stream = InitiatingStream(RawXmlStream::new(io, stream_ns));
+ let stream = InitiatingStream(RawXmlStream::new(io, stream_ns, timeouts));
stream.send_header(stream_header).await
}
@@ -144,8 +145,9 @@ pub async fn initiate_stream<Io: AsyncBufRead + AsyncWrite + Unpin>(
pub async fn accept_stream<Io: AsyncBufRead + AsyncWrite + Unpin>(
io: Io,
stream_ns: &'static str,
+ timeouts: Timeouts,
) -> Result<AcceptedStream<Io>, io::Error> {
- let mut stream = RawXmlStream::new(io, stream_ns);
+ let mut stream = RawXmlStream::new(io, stream_ns, timeouts);
let header = StreamHeader::recv(Pin::new(&mut stream)).await?;
Ok(AcceptedStream { stream, header })
}
@@ -319,14 +321,21 @@ impl<Io: AsyncBufRead, T: FromXml + AsXml + fmt::Debug> Stream for XmlStream<Io,
type Item = Result<T, ReadError>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
- let this = self.project();
+ let mut this = self.project();
let result = match this.read_state.as_mut() {
None => {
// awaiting eof.
- return match ready!(this.inner.poll_next(cx)) {
- None => Poll::Ready(None),
- Some(Ok(_)) => unreachable!("xml parser allowed data after stream footer"),
- Some(Err(e)) => Poll::Ready(Some(Err(ReadError::HardError(e)))),
+ return loop {
+ match ready!(this.inner.as_mut().poll_next(cx)) {
+ None => break Poll::Ready(None),
+ Some(Ok(_)) => unreachable!("xml parser allowed data after stream footer"),
+ Some(Err(RawError::Io(e))) => {
+ break Poll::Ready(Some(Err(ReadError::HardError(e))))
+ }
+ // Swallow the soft timeout; we don't want the user to
+ // trigger anything here.
+ Some(Err(RawError::SoftTimeout)) => continue,
+ }
};
}
Some(read_state) => ready!(read_state.poll_advance(this.inner, cx)),
@@ -341,6 +350,7 @@ impl<Io: AsyncBufRead, T: FromXml + AsXml + fmt::Debug> Stream for XmlStream<Io,
// another read state.
return Poll::Ready(Some(Err(ReadError::StreamFooterReceived)));
}
+ Err(ReadXsoError::SoftTimeout) => Poll::Ready(Some(Err(ReadError::SoftTimeout))),
};
*this.read_state = Some(ReadXsoState::default());
result
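With that, `ReadError::SoftTimeout` reaches the consumer as an ordinary stream item between stanzas. Per the `Timeouts` contract, user code must then provoke traffic before `response_timeout` runs out. A hedged sketch of such a consumer loop (presuming `futures::StreamExt` in scope); `handle` and `send_ping` are hypothetical application helpers, e.g. a XEP-0199 ping or a whitespace keep-alive:

```rust
// Hedged sketch, not from this patch: `stream` is an XmlStream; `handle`
// and `send_ping` are hypothetical application helpers.
while let Some(item) = stream.next().await {
    match item {
        Ok(stanza) => handle(stanza),
        // Not fatal: provoke traffic. If the peer stays silent for
        // response_timeout after this, the next item is a hard
        // io::ErrorKind::TimedOut error.
        Err(ReadError::SoftTimeout) => send_ping(&mut stream).await,
        // Hard errors, parse errors, stream footer: give up.
        Err(_) => break,
    }
}
```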
@@ -4,6 +4,8 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+use std::time::Duration;
+
use futures::{SinkExt, StreamExt};
use xmpp_parsers::stream_features::StreamFeatures;
@@ -29,12 +31,18 @@ async fn test_initiate_accept_stream() {
to: Some("server".into()),
id: Some("client-id".into()),
},
+ Timeouts::tight(),
)
.await?;
Ok::<_, io::Error>(stream.take_header())
});
let responder = tokio::spawn(async move {
- let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?;
+ let stream = accept_stream(
+ tokio::io::BufStream::new(rhs),
+ "jabber:client",
+ Timeouts::tight(),
+ )
+ .await?;
assert_eq!(stream.header().from.unwrap(), "client");
assert_eq!(stream.header().to.unwrap(), "server");
assert_eq!(stream.header().id.unwrap(), "client-id");
@@ -61,13 +69,19 @@ async fn test_exchange_stream_features() {
tokio::io::BufStream::new(lhs),
"jabber:client",
StreamHeader::default(),
+ Timeouts::tight(),
)
.await?;
let (features, _) = stream.recv_features::<Data>().await?;
Ok::<_, io::Error>(features)
});
let responder = tokio::spawn(async move {
- let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?;
+ let stream = accept_stream(
+ tokio::io::BufStream::new(rhs),
+ "jabber:client",
+ Timeouts::tight(),
+ )
+ .await?;
let stream = stream.send_header(StreamHeader::default()).await?;
stream
.send_features::<Data>(&StreamFeatures::default())
@@ -88,6 +102,7 @@ async fn test_exchange_data() {
tokio::io::BufStream::new(lhs),
"jabber:client",
StreamHeader::default(),
+ Timeouts::tight(),
)
.await?;
let (_, mut stream) = stream.recv_features::<Data>().await?;
@@ -104,7 +119,12 @@ async fn test_exchange_data() {
});
let responder = tokio::spawn(async move {
- let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?;
+ let stream = accept_stream(
+ tokio::io::BufStream::new(rhs),
+ "jabber:client",
+ Timeouts::tight(),
+ )
+ .await?;
let stream = stream.send_header(StreamHeader::default()).await?;
let mut stream = stream
.send_features::<Data>(&StreamFeatures::default())
@@ -134,6 +154,7 @@ async fn test_clean_shutdown() {
tokio::io::BufStream::new(lhs),
"jabber:client",
StreamHeader::default(),
+ Timeouts::tight(),
)
.await?;
let (_, mut stream) = stream.recv_features::<Data>().await?;
@@ -146,7 +167,12 @@ async fn test_clean_shutdown() {
});
let responder = tokio::spawn(async move {
- let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?;
+ let stream = accept_stream(
+ tokio::io::BufStream::new(rhs),
+ "jabber:client",
+ Timeouts::tight(),
+ )
+ .await?;
let stream = stream.send_header(StreamHeader::default()).await?;
let mut stream = stream
.send_features::<Data>(&StreamFeatures::default())
@@ -172,6 +198,7 @@ async fn test_exchange_data_stream_reset_and_shutdown() {
tokio::io::BufStream::new(lhs),
"jabber:client",
StreamHeader::default(),
+ Timeouts::tight(),
)
.await?;
let (_, mut stream) = stream.recv_features::<Data>().await?;
@@ -215,7 +242,12 @@ async fn test_exchange_data_stream_reset_and_shutdown() {
});
let responder = tokio::spawn(async move {
- let stream = accept_stream(tokio::io::BufStream::new(rhs), "jabber:client").await?;
+ let stream = accept_stream(
+ tokio::io::BufStream::new(rhs),
+ "jabber:client",
+ Timeouts::tight(),
+ )
+ .await?;
let stream = stream.send_header(StreamHeader::default()).await?;
let mut stream = stream
.send_features::<Data>(&StreamFeatures::default())
@@ -262,3 +294,104 @@ async fn test_exchange_data_stream_reset_and_shutdown() {
responder.await.unwrap().expect("responder failed");
initiator.await.unwrap().expect("initiator failed");
}
+
+#[tokio::test(start_paused = true)]
+async fn test_emits_soft_timeout_after_silence() {
+ let (lhs, rhs) = tokio::io::duplex(65536);
+
+ let client_timeouts = Timeouts {
+ read_timeout: Duration::new(300, 0),
+ response_timeout: Duration::new(15, 0),
+ };
+
+ // We want to trigger only one set of timeouts, so we make the server
+ // timeouts much longer than the client timeouts.
+ let server_timeouts = Timeouts {
+ read_timeout: Duration::new(900, 0),
+ response_timeout: Duration::new(15, 0),
+ };
+
+ let initiator = tokio::spawn(async move {
+ let stream = initiate_stream(
+ tokio::io::BufStream::new(lhs),
+ "jabber:client",
+ StreamHeader::default(),
+ client_timeouts,
+ )
+ .await?;
+ let (_, mut stream) = stream.recv_features::<Data>().await?;
+ stream
+ .send(&Data {
+ contents: "hello".to_owned(),
+ })
+ .await?;
+ match stream.next().await {
+ Some(Ok(Data { contents })) => assert_eq!(contents, "world!"),
+ other => panic!("unexpected stream message: {:?}", other),
+ }
+ // Here we prove that the stream doesn't see any data and also does
+ // not see the SoftTimeout too early.
+ // (Well, not exactly a proof: We only check until half of the read
+ // timeout, because that was easy to write and I deem it good enough.)
+ match tokio::time::timeout(client_timeouts.read_timeout / 2, stream.next()).await {
+ Err(_) => (),
+ Ok(ev) => panic!("early stream message (before soft timeout): {:?}", ev),
+ };
+ // Now the next thing that happens is the soft timeout ...
+ match stream.next().await {
+ Some(Err(ReadError::SoftTimeout)) => (),
+ other => panic!("unexpected stream message: {:?}", other),
+ }
+ // Another check that there is some time between the soft and the hard
+ // timeout.
+ match tokio::time::timeout(client_timeouts.response_timeout / 3, stream.next()).await {
+ Err(_) => (),
+ Ok(ev) => {
+ panic!("early stream message (before hard timeout): {:?}", ev);
+ }
+ };
+ // ... and thereafter the hard timeout in form of an I/O error.
+ match stream.next().await {
+ Some(Err(ReadError::HardError(e))) if e.kind() == io::ErrorKind::TimedOut => (),
+ other => panic!("unexpected stream message: {:?}", other),
+ }
+ Ok::<_, io::Error>(())
+ });
+
+ let responder = tokio::spawn(async move {
+ let stream = accept_stream(
+ tokio::io::BufStream::new(rhs),
+ "jabber:client",
+ server_timeouts,
+ )
+ .await?;
+ let stream = stream.send_header(StreamHeader::default()).await?;
+ let mut stream = stream
+ .send_features::<Data>(&StreamFeatures::default())
+ .await?;
+ stream
+ .send(&Data {
+ contents: "world!".to_owned(),
+ })
+ .await?;
+ match stream.next().await {
+ Some(Ok(Data { contents })) => assert_eq!(contents, "hello"),
+ other => panic!("unexpected stream message: {:?}", other),
+ }
+ match stream.next().await {
+ Some(Err(ReadError::HardError(e))) if e.kind() == io::ErrorKind::InvalidData => {
+ match e.downcast::<rxml::Error>() {
+ // the initiator closes the stream by dropping it once the
+ // timeout trips, so we get a hard eof here.
+ Ok(rxml::Error::InvalidEof(_)) => (),
+ other => panic!("unexpected error: {:?}", other),
+ }
+ }
+ other => panic!("unexpected stream message: {:?}", other),
+ }
+ Ok::<_, io::Error>(())
+ });
+
+ responder.await.unwrap().expect("responder failed");
+ initiator.await.unwrap().expect("initiator failed");
+}
@@ -15,6 +15,7 @@ use tokio_xmpp::{
disco::{DiscoInfoResult, Feature, Identity},
ns,
},
+ xmlstream::Timeouts,
Client as TokioXmppClient,
};
@@ -51,6 +52,7 @@ pub struct ClientBuilder<'a, C: ServerConnector> {
disco: (ClientType, String),
features: Vec<ClientFeature>,
resource: Option<String>,
+ timeouts: Timeouts,
}
#[cfg(any(feature = "starttls-rust", feature = "starttls-native"))]
@@ -80,6 +82,7 @@ impl<C: ServerConnector> ClientBuilder<'_, C> {
disco: (ClientType::default(), String::from("tokio-xmpp")),
features: vec![],
resource: None,
+ timeouts: Timeouts::default(),
}
}
@@ -109,6 +112,15 @@ impl<C: ServerConnector> ClientBuilder<'_, C> {
self
}
+ /// Configure the timeouts used.
+ ///
+ /// See [`Timeouts`] for more information on the semantics and the
+ /// defaults (which are used unless you call this method).
+ pub fn set_timeouts(mut self, timeouts: Timeouts) -> Self {
+ self.timeouts = timeouts;
+ self
+ }
+
pub fn enable_feature(mut self, feature: ClientFeature) -> Self {
self.features.push(feature);
self
@@ -146,8 +158,12 @@ impl<C: ServerConnector> ClientBuilder<'_, C> {
self.jid.clone().into()
};
- let mut client =
- TokioXmppClient::new_with_connector(jid, self.password, self.server_connector.clone());
+ let mut client = TokioXmppClient::new_with_connector(
+ jid,
+ self.password,
+ self.server_connector.clone(),
+ self.timeouts,
+ );
client.set_reconnect(true);
self.build_impl(client)
}
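Downstream, selecting timeouts through the builder then looks as follows. A hedged sketch: `set_timeouts` is from this patch, while the `ClientBuilder::new` arguments and the `build()` finisher are assumed from the crate's usual builder API:

```rust
use tokio_xmpp::xmlstream::Timeouts;

// Hedged sketch; the constructor arguments are assumptions, the
// set_timeouts call is the API added above.
let client = ClientBuilder::new(jid, "password")
    .set_timeouts(Timeouts::tight())
    .build();
```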