diff --git a/tokio-xmpp/src/xmlstream/capture.rs b/tokio-xmpp/src/xmlstream/capture.rs
new file mode 100644
index 0000000000000000000000000000000000000000..4043877ab3b80536710e955a91cddafdbd943725
--- /dev/null
+++ b/tokio-xmpp/src/xmlstream/capture.rs
@@ -0,0 +1,205 @@
+// Copyright (c) 2024 Jonas Schäfer
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//! Small helper struct to capture data read from an AsyncBufRead.
+
+use core::pin::Pin;
+use core::task::{Context, Poll};
+use std::io::{self, IoSlice};
+
+use futures::ready;
+
+use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
+
+use super::LogXsoBuf;
+
+pin_project_lite::pin_project! {
+    /// Wrapper around [`AsyncBufRead`] which stores bytes which have been
+    /// read in an internal vector for later inspection.
+    ///
+    /// This struct implements [`AsyncRead`] and [`AsyncBufRead`] and passes
+    /// read requests down to the wrapped [`AsyncBufRead`].
+    ///
+    /// After capturing has been enabled using [`Self::enable_capture`], any
+    /// data which is read via the struct will be stored in an internal buffer
+    /// and can be extracted with [`Self::take_capture`] or discarded using
+    /// [`Self::discard_capture`].
+    ///
+    /// This can be used to log data which is being read from a source.
+    ///
+    /// In addition, this struct implements [`AsyncWrite`] if and only if `T`
+    /// implements [`AsyncWrite`]. Writing is unaffected by capturing and is
+    /// implemented solely for convenience purposes (to allow duplex usage
+    /// of a wrapped I/O object).
+    pub(super) struct CaptureBufRead<T> {
+        #[pin]
+        inner: T, // the wrapped I/O object
+        buf: Option<(Vec<u8>, usize)>, // (captured bytes, length of consumed prefix)
+    }
+}
+
+impl<T> CaptureBufRead<T> {
+    /// Wrap a given [`AsyncBufRead`].
+    ///
+    /// Note that capturing of data which is being read is disabled by default
+    /// and needs to be enabled using [`Self::enable_capture`].
+    pub fn wrap(inner: T) -> Self {
+        let buf = None;
+        Self { inner, buf }
+    }
+
+    /// Unwrap the inner [`AsyncBufRead`], dropping the capture buffer.
+    pub fn into_inner(self) -> T {
+        self.inner
+    }
+
+    /// Borrow the inner [`AsyncBufRead`].
+    pub fn inner(&self) -> &T {
+        &self.inner
+    }
+
+    /// Start capturing data which is read through this wrapper.
+    ///
+    /// From this call on, bytes which are read get copied into the internal
+    /// buffer. The buffer keeps growing until it is emptied through
+    /// [`Self::take_capture`] or [`Self::discard_capture`]. Any previously
+    /// captured data is replaced by a fresh, empty buffer.
+    pub fn enable_capture(&mut self) {
+        let empty_capture = (Vec::new(), 0);
+        self.buf = Some(empty_capture);
+    }
+
+    /// Throw away the captured data which has already been consumed.
+    ///
+    /// This does nothing while capturing is disabled. Capturing itself stays
+    /// enabled, so data which is read afterwards is recorded again.
+    pub(super) fn discard_capture(self: Pin<&mut Self>) {
+        let this = self.project();
+        let Some((captured, consumed)) = this.buf.as_mut() else {
+            return;
+        };
+        // Only the consumed prefix is removed; bytes behind it stay in place.
+        captured.drain(..*consumed);
+        *consumed = 0;
+    }
+
+    /// Take the currently captured data out of the inner buffer.
+    ///
+    /// Returns `None` unless capturing has been enabled using
+    /// [`Self::enable_capture`].
+    pub(super) fn take_capture(self: Pin<&mut Self>) -> Option<Vec<u8>> {
+        let this = self.project();
+        let (buf, consumed_up_to) = this.buf.as_mut()?;
+        // Hand out only the consumed prefix. Collecting the `drain` iterator
+        // already removes those bytes from `buf`, so they must not be drained
+        // a second time: `Vec::drain` panics if the range end exceeds the
+        // remaining length, which is the case whenever fewer than
+        // `consumed_up_to` unconsumed bytes are left over.
+        let result = buf.drain(..*consumed_up_to).collect();
+        *consumed_up_to = 0;
+        Some(result)
+    }
+}
+
+impl<T: AsyncRead> AsyncRead for CaptureBufRead<T> {
+    /// Forward the read to the inner object and, if capturing is enabled,
+    /// record the bytes which were newly placed into `read_buf`.
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context,
+        read_buf: &mut ReadBuf,
+    ) -> Poll<io::Result<()>> {
+        let this = self.project();
+        // Remember the fill level before the call so that only the bytes
+        // added by this read are captured.
+        let prev_len = read_buf.filled().len();
+        let result = ready!(this.inner.poll_read(cx, read_buf));
+        if let Some((buf, consumed_up_to)) = this.buf.as_mut() {
+            // Drop the tail recorded by `poll_fill_buf` which was never
+            // consumed (presumably the inner reader hands those bytes out
+            // again), then append what has just been read.
+            buf.truncate(*consumed_up_to);
+            buf.extend_from_slice(&read_buf.filled()[prev_len..]);
+            // Everything handed out via `poll_read` counts as consumed.
+            *consumed_up_to = buf.len();
+        }
+        Poll::Ready(result)
+    }
+}
+
+impl<T: AsyncBufRead> AsyncBufRead for CaptureBufRead<T> {
+    /// Forward to the inner object and, if capturing is enabled, mirror the
+    /// returned buffer contents behind the already consumed prefix.
+    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
+        let this = self.project();
+        let result = ready!(this.inner.poll_fill_buf(cx))?;
+        if let Some((buf, consumed_up_to)) = this.buf.as_mut() {
+            // Replace the previous unconsumed tail with the current view of
+            // the inner buffer. `consumed_up_to` is only advanced by
+            // `consume`, so these bytes are not part of the capture yet.
+            buf.truncate(*consumed_up_to);
+            buf.extend_from_slice(result);
+        }
+        Poll::Ready(Ok(result))
+    }
+
+    /// Mark `amt` bytes as consumed, both in the inner object and in the
+    /// capture buffer.
+    fn consume(self: Pin<&mut Self>, amt: usize) {
+        let this = self.project();
+        this.inner.consume(amt);
+        if let Some((_, consumed_up_to)) = this.buf.as_mut() {
+            // Increase the amount of data to preserve.
+            *consumed_up_to += amt;
+        }
+    }
+}
+
+// Writing bypasses the capture buffer entirely: every method below simply
+// forwards to the wrapped I/O object.
+impl<T: AsyncWrite> AsyncWrite for CaptureBufRead<T> {
+    fn poll_write(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &[u8],
+    ) -> Poll<io::Result<usize>> {
+        self.project().inner.poll_write(cx, buf)
+    }
+
+    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        self.project().inner.poll_shutdown(cx)
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
+        self.project().inner.poll_flush(cx)
+    }
+
+    fn is_write_vectored(&self) -> bool {
+        self.inner.is_write_vectored()
+    }
+
+    fn poll_write_vectored(
+        self: Pin<&mut Self>,
+        cx: &mut Context,
+        bufs: &[IoSlice],
+    ) -> Poll<io::Result<usize>> {
+        self.project().inner.poll_write_vectored(cx, bufs)
+    }
+}
+
+/// Return true if logging via [`log_recv`] or [`log_send`] might be visible
+/// to the user.
+///
+/// This is the case if the `Trace` log level is enabled.
+pub(super) fn log_enabled() -> bool {
+    log::log_enabled!(log::Level::Trace)
+}
+
+/// Log received data.
+///
+/// `err` is an error which may be logged alongside the received data.
+/// `capture` is the data which has been received and which should be logged.
+/// If built with the `syntax-highlighting` feature, `capture` data will be
+/// logged with XML syntax highlighting.
+///
+/// If `err` is given without `capture`, the error is logged together with a
+/// note that data capture is disabled.
+///
+/// If both `err` and `capture` are None, nothing will be logged.
+pub(super) fn log_recv(err: Option<&xmpp_parsers::Error>, capture: Option<Vec<u8>>) {
+    match (err, capture) {
+        (Some(err), Some(capture)) => {
+            log::trace!("RECV (error: {}) {}", err, LogXsoBuf(&capture));
+        }
+        (Some(err), None) => {
+            log::trace!("RECV (error: {}) [data capture disabled]", err);
+        }
+        (None, Some(capture)) => {
+            log::trace!("RECV (ok) {}", LogXsoBuf(&capture));
+        }
+        (None, None) => (),
+    }
+}
+
+/// Log sent data.
+///
+/// If built with the `syntax-highlighting` feature, `data` will be
+/// logged with XML syntax highlighting.
+pub(super) fn log_send(data: &[u8]) { + log::trace!("SEND {}", LogXsoBuf(data)); +} diff --git a/tokio-xmpp/src/xmlstream/common.rs b/tokio-xmpp/src/xmlstream/common.rs index 9d0d6e8e9d1a445e9fe7132a7ca39e3bbfc9fd49..c0969564c0d47a896a4c12456ec701c61b84beee 100644 --- a/tokio-xmpp/src/xmlstream/common.rs +++ b/tokio-xmpp/src/xmlstream/common.rs @@ -18,9 +18,11 @@ use tokio::io::{AsyncBufRead, AsyncWrite}; use xso::{ exports::rxml::{self, writer::TrackNamespace, xml_ncname, Event, Namespace}, - FromEventsBuilder, FromXml, Item, + AsXml, FromEventsBuilder, FromXml, Item, }; +use super::capture::{log_enabled, log_recv, log_send, CaptureBufRead}; + use xmpp_parsers::ns::STREAM as XML_STREAM_NS; pin_project_lite::pin_project! { @@ -30,7 +32,7 @@ pin_project_lite::pin_project! { pub(super) struct RawXmlStream { // The parser used for deserialising data. #[pin] - parser: rxml::AsyncReader, + parser: rxml::AsyncReader>, // The writer used for serialising data. writer: rxml::writer::Encoder, @@ -44,6 +46,10 @@ pin_project_lite::pin_project! { // happens in `start_send`. tx_buffer: BytesMut, + // Position inside tx_buffer up to which to-be-sent data has already + // been logged. 
+ tx_buffer_logged: usize, + // This signifies the limit at the point of which the Sink will // refuse to accept more data: if the `tx_buffer`'s size grows beyond // that high water mark, poll_ready will return Poll::Pending until @@ -108,9 +114,14 @@ impl RawXmlStream { pub(super) fn new(io: Io, stream_ns: &'static str) -> Self { let parser = rxml::Parser::default(); + let mut io = CaptureBufRead::wrap(io); + if log_enabled() { + io.enable_capture(); + } Self { parser: rxml::AsyncReader::wrap(io, parser), writer: Self::new_writer(stream_ns), + tx_buffer_logged: 0, stream_ns, tx_buffer: BytesMut::new(), @@ -129,7 +140,37 @@ impl RawXmlStream { } pub(super) fn into_inner(self) -> Io { - self.parser.into_inner().0 + self.parser.into_inner().0.into_inner() + } +} + +impl RawXmlStream { + /// Start sending an entire XSO. + /// + /// Unlike the `Sink` implementation, this provides nice syntax + /// highlighting for the serialised data in log outputs (if enabled) *and* + /// is error safe: if the XSO fails to serialise completely, it will be as + /// if it hadn't been attempted to serialise it at all. + /// + /// Note that, like with `start_send`, the caller is responsible for + /// ensuring that the stream is ready by polling + /// [`::poll_ready`] as needed. + pub(super) fn start_send_xso(self: Pin<&mut Self>, xso: &T) -> io::Result<()> { + let mut this = self.project(); + let prev_len = this.tx_buffer.len(); + match this.try_send_xso(xso) { + Ok(()) => Ok(()), + Err(e) => { + let curr_len = this.tx_buffer.len(); + this.tx_buffer.truncate(prev_len); + log::trace!( + "SEND failed: {}. 
Rewinding buffer by {} bytes.", + e, + curr_len - prev_len + ); + Err(e) + } + } } } @@ -138,8 +179,12 @@ impl RawXmlStream { self.project().parser.parser_pinned() } + fn stream_pinned(self: Pin<&mut Self>) -> Pin<&mut CaptureBufRead> { + self.project().parser.inner_pinned() + } + pub(super) fn get_stream(&self) -> &Io { - self.parser.inner() + self.parser.inner().inner() } } @@ -161,7 +206,38 @@ impl Stream for RawXmlStream { } impl<'x, Io: AsyncWrite> RawXmlStreamProj<'x, Io> { + fn flush_tx_log(&mut self) { + let range = &self.tx_buffer[*self.tx_buffer_logged..]; + if range.len() == 0 { + return; + } + log_send(range); + *self.tx_buffer_logged = self.tx_buffer.len(); + } + + fn start_send(&mut self, item: &xso::Item<'_>) -> io::Result<()> { + self.writer + .encode_into_bytes(item.as_rxml_item(), self.tx_buffer) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) + } + + fn try_send_xso(&mut self, xso: &T) -> io::Result<()> { + let iter = match xso.as_xml_iter() { + Ok(v) => v, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), + }; + for item in iter { + let item = match item { + Ok(v) => v, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), + }; + self.start_send(&item)?; + } + Ok(()) + } + fn progress_write(&mut self, cx: &mut Context<'_>) -> Poll> { + self.flush_tx_log(); while self.tx_buffer.len() > 0 { let written = match ready!(self .parser @@ -173,6 +249,10 @@ impl<'x, Io: AsyncWrite> RawXmlStreamProj<'x, Io> { Err(e) => return Poll::Ready(Err(e)), }; self.tx_buffer.advance(written); + *self.tx_buffer_logged = self + .tx_buffer_logged + .checked_sub(written) + .expect("Buffer arithmetic error"); } Poll::Ready(Ok(())) } @@ -212,10 +292,8 @@ impl<'x, Io: AsyncWrite> Sink> for RawXmlStream { } fn start_send(self: Pin<&mut Self>, item: xso::Item<'x>) -> Result<(), Self::Error> { - let this = self.project(); - this.writer - .encode_into_bytes(item.as_rxml_item(), this.tx_buffer) - .map_err(|e| 
io::Error::new(io::ErrorKind::InvalidInput, e)) + let mut this = self.project(); + this.start_send(&item) } } @@ -346,42 +424,48 @@ impl ReadXsoState { .as_mut() .parser_pinned() .set_text_buffering(text_buffering); + let ev = ready!(source.as_mut().poll_next(cx)).transpose()?; match self { - ReadXsoState::PreData => match ev { - Some(rxml::Event::XmlDeclaration(_, _)) => (), - Some(rxml::Event::Text(_, data)) => { - if xso::is_xml_whitespace(data.as_bytes()) { - continue; - } else { + ReadXsoState::PreData => { + log::trace!("ReadXsoState::PreData ev = {:?}", ev); + match ev { + Some(rxml::Event::XmlDeclaration(_, _)) => (), + Some(rxml::Event::Text(_, data)) => { + if xso::is_xml_whitespace(data.as_bytes()) { + log::trace!("Received {} bytes of whitespace", data.len()); + source.as_mut().stream_pinned().discard_capture(); + continue; + } else { + *self = ReadXsoState::Done; + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::InvalidData, + "non-whitespace text content before XSO", + ) + .into())); + } + } + Some(rxml::Event::StartElement(_, name, attrs)) => { + *self = ReadXsoState::Parsing( + as FromXml>::from_events(name, attrs) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?, + ); + } + // Amounts to EOF, as we expect to start on the stream level. + Some(rxml::Event::EndElement(_)) => { + *self = ReadXsoState::Done; + return Poll::Ready(Err(ReadXsoError::Footer)); + } + None => { *self = ReadXsoState::Done; return Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidData, - "non-whitespace text content before XSO", + "end of parent element before XSO started", ) .into())); } } - Some(rxml::Event::StartElement(_, name, attrs)) => { - *self = ReadXsoState::Parsing( - as FromXml>::from_events(name, attrs) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?, - ); - } - // Amounts to EOF, as we expect to start on the stream level. 
- Some(rxml::Event::EndElement(_)) => { - *self = ReadXsoState::Done; - return Poll::Ready(Err(ReadXsoError::Footer)); - } - None => { - *self = ReadXsoState::Done; - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::UnexpectedEof, - "end of parent element before XSO started", - ) - .into())); - } - }, + } ReadXsoState::Parsing(builder) => { let Some(ev) = ev else { *self = ReadXsoState::Done; @@ -395,6 +479,7 @@ impl ReadXsoState { match builder.feed(ev) { Err(err) => { *self = ReadXsoState::Done; + source.as_mut().stream_pinned().discard_capture(); return Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidData, err, @@ -403,10 +488,12 @@ impl ReadXsoState { } Ok(Some(Err(err))) => { *self = ReadXsoState::Done; + log_recv(Some(&err), source.as_mut().stream_pinned().take_capture()); return Poll::Ready(Err(ReadXsoError::Parse(err))); } Ok(Some(Ok(value))) => { *self = ReadXsoState::Done; + log_recv(None, source.as_mut().stream_pinned().take_capture()); return Poll::Ready(Ok(value)); } Ok(None) => (), diff --git a/tokio-xmpp/src/xmlstream/mod.rs b/tokio-xmpp/src/xmlstream/mod.rs index 796020fc9c154b1f2b4a84a2b85e4df09839b7fb..04af83c493b0d8a28e522011ac75870d853efa13 100644 --- a/tokio-xmpp/src/xmlstream/mod.rs +++ b/tokio-xmpp/src/xmlstream/mod.rs @@ -56,9 +56,12 @@ //! [`XmlStream::accept_reset`] handles sending the last pre-reset element and //! resetting the stream in a single step. 
+use core::fmt;
 use core::pin::Pin;
 use core::task::{Context, Poll};
 use std::io;
+#[cfg(feature = "syntax-highlighting")]
+use std::sync::OnceLock;

 use futures::{ready, Sink, SinkExt, Stream};

@@ -66,6 +69,7 @@ use tokio::io::{AsyncBufRead, AsyncWrite};

 use xso::{AsXml, FromXml, Item};

+mod capture;
 mod common;
 mod initiator;
 mod responder;
@@ -79,6 +83,46 @@ pub use self::initiator::{InitiatingStream, PendingFeaturesRecv};
 pub use self::responder::{AcceptedStream, PendingFeaturesSend};
 pub use self::xmpp::XmppStreamElement;

+#[cfg(feature = "syntax-highlighting")]
+static PS: OnceLock<syntect::parsing::SyntaxSet> = OnceLock::new();
+#[cfg(feature = "syntax-highlighting")]
+static SYNTAX: OnceLock<syntect::parsing::SyntaxReference> = OnceLock::new();
+#[cfg(feature = "syntax-highlighting")]
+static THEME: OnceLock<syntect::highlighting::Theme> = OnceLock::new();
+
+/// Render `xml` with 24-bit terminal colour escapes for XML syntax
+/// highlighting, followed by a reset sequence.
+#[cfg(feature = "syntax-highlighting")]
+fn highlight_xml(xml: &str) -> String {
+    let ps = PS.get_or_init(syntect::parsing::SyntaxSet::load_defaults_newlines);
+    let mut h = syntect::easy::HighlightLines::new(
+        SYNTAX.get_or_init(|| ps.find_syntax_by_extension("xml").unwrap().clone()),
+        THEME.get_or_init(|| {
+            syntect::highlighting::ThemeSet::load_defaults().themes["Solarized (dark)"].clone()
+        }),
+    );
+
+    let ranges: Vec<_> = h.highlight_line(&xml, ps).unwrap();
+    let escaped = syntect::util::as_24_bit_terminal_escaped(&ranges[..], false);
+    format!("{}\x1b[0m", escaped)
+}
+
+/// Display adapter which writes a byte buffer holding XML as text, with
+/// syntax highlighting if the `syntax-highlighting` feature is enabled.
+struct LogXsoBuf<'x>(&'x [u8]);
+
+impl<'x> fmt::Display for LogXsoBuf<'x> {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        // Logging must never panic. The buffer may hold bytes captured from
+        // the peer, so invalid UTF-8 is replaced with U+FFFD instead of
+        // being unwrapped.
+        let text = String::from_utf8_lossy(self.0);
+        #[cfg(feature = "syntax-highlighting")]
+        let text = highlight_xml(&text);
+        f.write_str(&text)
+    }
+}
+
 /// Initiate a new stream
 ///
 /// Initiate a new stream using the given I/O object `io`. 
The default @@ -212,7 +250,7 @@ impl XmlStream { } } -impl XmlStream { +impl XmlStream { /// Initiate a stream reset /// /// To actually send the stream header, call @@ -277,7 +315,7 @@ impl XmlStream } } -impl Stream for XmlStream { +impl Stream for XmlStream { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -300,7 +338,7 @@ impl Stream for XmlStream { } } -impl<'x, Io: AsyncWrite, T: FromXml + AsXml> Sink<&'x T> for XmlStream { +impl<'x, Io: AsyncWrite, T: FromXml + AsXml + fmt::Debug> Sink<&'x T> for XmlStream { type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -347,20 +385,9 @@ impl<'x, Io: AsyncWrite, T: FromXml + AsXml> Sink<&'x T> for XmlStream { } fn start_send(self: Pin<&mut Self>, item: &'x T) -> Result<(), Self::Error> { - let mut this = self.project(); + let this = self.project(); this.write_state.check_writable()?; - let iter = match item.as_xml_iter() { - Ok(v) => v, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), - }; - for item in iter { - let item = match item { - Ok(v) => v, - Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), - }; - this.inner.as_mut().start_send(item)?; - } - Ok(()) + this.inner.start_send_xso(item) } } diff --git a/tokio-xmpp/src/xmlstream/responder.rs b/tokio-xmpp/src/xmlstream/responder.rs index 3b0d7972397bc736fa7ea7484f3bcd056948b9d0..dea183ffdde5e6e0cc6eeaffb86b3d553c31c7f6 100644 --- a/tokio-xmpp/src/xmlstream/responder.rs +++ b/tokio-xmpp/src/xmlstream/responder.rs @@ -83,14 +83,7 @@ impl PendingFeaturesSend { features: &'_ StreamFeatures, ) -> io::Result> { let Self { mut stream } = self; - let iter = features - .as_xml_iter() - .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; - - for item in iter { - let item = item.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; - stream.send(item).await?; - } + Pin::new(&mut stream).start_send_xso(features)?; 
stream.flush().await?; Ok(XmlStream::wrap(stream)) diff --git a/xso/src/lib.rs b/xso/src/lib.rs index 34422d01508144dd66be022f188aa5f4baf1997e..f8f05ce9b54120bc5a993710a29f25233c1da183 100644 --- a/xso/src/lib.rs +++ b/xso/src/lib.rs @@ -416,6 +416,18 @@ pub fn from_bytes(mut buf: &[u8]) -> Result { Err(self::error::Error::XmlError(rxml::Error::InvalidEof(None))) } +/// Attempt to serialise a type implementing [`AsXml`] to a vector of bytes. +pub fn to_vec(xso: &T) -> Result, self::error::Error> { + let iter = xso.as_xml_iter()?; + let mut writer = rxml::writer::Encoder::new(); + let mut buf = Vec::new(); + for item in iter { + let item = item?; + writer.encode(item.as_rxml_item(), &mut buf)?; + } + Ok(buf) +} + /// Return true if the string contains exclusively XML whitespace. /// /// XML whitespace is defined as U+0020 (space), U+0009 (tab), U+000a