capture.rs

  1// Copyright (c) 2024 Jonas Schäfer <jonas@zombofant.net>
  2//
  3// This Source Code Form is subject to the terms of the Mozilla Public
  4// License, v. 2.0. If a copy of the MPL was not distributed with this
  5// file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6
  7//! Small helper struct to capture data read from an AsyncBufRead.
  8
  9use core::pin::Pin;
 10use core::task::{Context, Poll};
 11use std::io::{self, IoSlice};
 12
 13use futures::ready;
 14
 15use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
 16
 17use super::LogXsoBuf;
 18
pin_project_lite::pin_project! {
    /// Wrapper around [`AsyncBufRead`] which stores bytes which have been
    /// read in an internal vector for later inspection.
    ///
    /// This struct implements [`AsyncRead`] and [`AsyncBufRead`] and passes
    /// read requests down to the wrapped [`AsyncBufRead`].
    ///
    /// After capturing has been enabled using [`Self::enable_capture`], any
    /// data which is read via the struct will be stored in an internal buffer
    /// and can be extracted with [`Self::take_capture`] or discarded using
    /// [`Self::discard_capture`].
    ///
    /// This can be used to log data which is being read from a source.
    ///
    /// In addition, this struct implements [`AsyncWrite`] if and only if `T`
    /// implements [`AsyncWrite`]. Writing is unaffected by capturing and is
    /// implemented solely for convenience purposes (to allow duplex usage
    /// of a wrapped I/O object).
    pub(super) struct CaptureBufRead<T> {
        // The wrapped reader (and, if it implements `AsyncWrite`, writer).
        #[pin]
        inner: T,
        // Capture state; `None` while capturing is disabled (the state after
        // `wrap`), `Some` after `enable_capture`.
        //
        // For `Some((data, consumed_up_to))`:
        // - `data[..consumed_up_to]` holds bytes which have been consumed
        //   (via `consume` or `poll_read`) while capturing; this prefix is
        //   what `take_capture` returns and `discard_capture` drops.
        // - `data[consumed_up_to..]` holds bytes which were peeked via
        //   `poll_fill_buf` but have not been consumed yet.
        buf: Option<(Vec<u8>, usize)>,
    }
}
 43
 44impl<T> CaptureBufRead<T> {
 45    /// Wrap a given [`AsyncBufRead`].
 46    ///
 47    /// Note that capturing of data which is being read is disabled by default
 48    /// and needs to be enabled using [`Self::enable_capture`].
 49    pub fn wrap(inner: T) -> Self {
 50        Self { inner, buf: None }
 51    }
 52
 53    /// Extract the inner [`AsyncBufRead`] and discard the capture buffer.
 54    pub fn into_inner(self) -> T {
 55        self.inner
 56    }
 57
 58    /// Obtain a reference to the inner [`AsyncBufRead`].
 59    pub fn inner(&self) -> &T {
 60        &self.inner
 61    }
 62
 63    /// Enable capturing of read data into the inner buffer.
 64    ///
 65    /// Any data which is read from now on will be copied into the internal
 66    /// buffer. That buffer will grow indefinitely until calls to
 67    /// [`Self::take_capture`] or [`Self::discard_capture`].
 68    pub fn enable_capture(&mut self) {
 69        self.buf = Some((Vec::new(), 0));
 70    }
 71
 72    /// Discard the current buffer data, if any.
 73    ///
 74    /// Further data which is read will be captured again.
 75    pub(super) fn discard_capture(self: Pin<&mut Self>) {
 76        let this = self.project();
 77        if let Some((buf, consumed_up_to)) = this.buf.as_mut() {
 78            buf.drain(..*consumed_up_to);
 79            *consumed_up_to = 0;
 80        }
 81    }
 82
 83    /// Take the currently captured data out of the inner buffer.
 84    ///
 85    /// Returns `None` unless capturing has been enabled using
 86    /// [`Self::enable_capture`].
 87    pub(super) fn take_capture(self: Pin<&mut Self>) -> Option<Vec<u8>> {
 88        let this = self.project();
 89        let (buf, consumed_up_to) = this.buf.as_mut()?;
 90        let result = buf.drain(..*consumed_up_to).collect();
 91        *consumed_up_to = 0;
 92        Some(result)
 93    }
 94}
 95
 96impl<T: AsyncRead> AsyncRead for CaptureBufRead<T> {
 97    fn poll_read(
 98        self: Pin<&mut Self>,
 99        cx: &mut Context,
100        read_buf: &mut ReadBuf,
101    ) -> Poll<io::Result<()>> {
102        let this = self.project();
103        let prev_len = read_buf.filled().len();
104        let result = ready!(this.inner.poll_read(cx, read_buf));
105        if let Some((buf, consumed_up_to)) = this.buf.as_mut() {
106            buf.truncate(*consumed_up_to);
107            buf.extend(&read_buf.filled()[prev_len..]);
108            *consumed_up_to = buf.len();
109        }
110        Poll::Ready(result)
111    }
112}
113
114impl<T: AsyncBufRead> AsyncBufRead for CaptureBufRead<T> {
115    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
116        let this = self.project();
117        let result = ready!(this.inner.poll_fill_buf(cx))?;
118        if let Some((buf, consumed_up_to)) = this.buf.as_mut() {
119            buf.truncate(*consumed_up_to);
120            buf.extend(result);
121        }
122        Poll::Ready(Ok(result))
123    }
124
125    fn consume(self: Pin<&mut Self>, amt: usize) {
126        let this = self.project();
127        this.inner.consume(amt);
128        if let Some((_, consumed_up_to)) = this.buf.as_mut() {
129            // Increase the amount of data to preserve.
130            *consumed_up_to += amt;
131        }
132    }
133}
134
135impl<T: AsyncWrite> AsyncWrite for CaptureBufRead<T> {
136    fn poll_write(
137        self: Pin<&mut Self>,
138        cx: &mut Context<'_>,
139        buf: &[u8],
140    ) -> Poll<io::Result<usize>> {
141        self.project().inner.poll_write(cx, buf)
142    }
143
144    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
145        self.project().inner.poll_shutdown(cx)
146    }
147
148    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
149        self.project().inner.poll_flush(cx)
150    }
151
152    fn is_write_vectored(&self) -> bool {
153        self.inner.is_write_vectored()
154    }
155
156    fn poll_write_vectored(
157        self: Pin<&mut Self>,
158        cx: &mut Context,
159        bufs: &[IoSlice],
160    ) -> Poll<io::Result<usize>> {
161        self.project().inner.poll_write_vectored(cx, bufs)
162    }
163}
164
165/// Return true if logging via [`log_recv`] or [`log_send`] might be visible
166/// to the user.
167pub(super) fn log_enabled() -> bool {
168    log::log_enabled!(log::Level::Trace)
169}
170
171/// Log received data.
172///
173/// `err` is an error which may be logged alongside the received data.
174/// `capture` is the data which has been received and which should be logged.
175/// If built with the `syntax-highlighting` feature, `capture` data will be
176/// logged with XML syntax highlighting.
177///
178/// If both `err` and `capture` are None, nothing will be logged.
179pub(super) fn log_recv(err: Option<&xmpp_parsers::Error>, capture: Option<Vec<u8>>) {
180    match err {
181        Some(err) => match capture {
182            Some(capture) => {
183                log::trace!("RECV (error: {}) {}", err, LogXsoBuf(&capture));
184            }
185            None => {
186                log::trace!("RECV (error: {}) [data capture disabled]", err);
187            }
188        },
189        None => {
190            if let Some(capture) = capture {
191                log::trace!("RECV (ok) {}", LogXsoBuf(&capture));
192            }
193        }
194    }
195}
196
197/// Log sent data.
198///
199/// If built with the `syntax-highlighting` feature, `data` data will be
200/// logged with XML syntax highlighting.
201pub(super) fn log_send(data: &[u8]) {
202    log::trace!("SEND {}", LogXsoBuf(data));
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::io::{AsyncBufReadExt, AsyncReadExt};

    /// The in-memory payload every test reads from.
    const PAYLOAD: &[u8] = b"Hello World!";

    /// The concrete wrapper type exercised by the tests.
    type Reader = CaptureBufRead<tokio::io::BufReader<&'static [u8]>>;

    /// Build a buffered reader over [`PAYLOAD`] with capturing enabled.
    fn capturing_reader() -> Reader {
        let mut reader = CaptureBufRead::wrap(tokio::io::BufReader::new(PAYLOAD));
        reader.enable_capture();
        reader
    }

    /// Take the captured bytes out of `reader`.
    fn take(reader: &mut Reader) -> Vec<u8> {
        Pin::new(reader).take_capture().unwrap()
    }

    #[tokio::test]
    async fn captures_data_read_via_async_read() {
        let mut reader = capturing_reader();
        let mut dst = [0u8; 8];

        let n = reader.read(&mut dst[..]).await.unwrap();
        assert_eq!(n, 8);
        assert_eq!(&dst, b"Hello Wo");
        assert_eq!(take(&mut reader), b"Hello Wo");
    }

    #[tokio::test]
    async fn captures_data_read_via_async_buf_read() {
        let mut reader = capturing_reader();

        assert_eq!(reader.fill_buf().await.unwrap().len(), PAYLOAD.len());
        // Peeking alone does not count as reading: nothing captured yet.
        assert!(take(&mut reader).is_empty());

        reader.consume(5);
        assert_eq!(take(&mut reader), b"Hello");

        reader.consume(6);
        assert_eq!(take(&mut reader), b" World");
    }

    #[tokio::test]
    async fn discard_capture_drops_consumed_data() {
        let mut reader = capturing_reader();

        assert_eq!(reader.fill_buf().await.unwrap().len(), PAYLOAD.len());
        // Peeking alone does not count as reading: nothing captured yet.
        assert!(take(&mut reader).is_empty());

        reader.consume(5);
        Pin::new(&mut reader).discard_capture();

        reader.consume(6);
        assert_eq!(take(&mut reader), b" World");
    }

    #[tokio::test]
    async fn captured_data_accumulates() {
        let mut reader = capturing_reader();

        assert_eq!(reader.fill_buf().await.unwrap().len(), PAYLOAD.len());
        // Peeking alone does not count as reading: nothing captured yet.
        assert!(take(&mut reader).is_empty());

        reader.consume(5);
        reader.consume(6);
        assert_eq!(take(&mut reader), b"Hello World");
    }
}