1// Copyright (c) 2024 Jonas Schäfer <jonas@zombofant.net>
2//
3// This Source Code Form is subject to the terms of the Mozilla Public
4// License, v. 2.0. If a copy of the MPL was not distributed with this
5// file, You can obtain one at http://mozilla.org/MPL/2.0/.
6
//! Small helper struct to capture data read from an AsyncBufRead, plus helpers to log received and sent stream data.
8
9use core::pin::Pin;
10use core::task::{Context, Poll};
11use std::io::{self, IoSlice};
12
13use futures::ready;
14
15use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
16
17use super::LogXsoBuf;
18
pin_project_lite::pin_project! {
    /// Wrapper around [`AsyncBufRead`] which stores bytes which have been
    /// read in an internal vector for later inspection.
    ///
    /// This struct implements [`AsyncRead`] and [`AsyncBufRead`] and passes
    /// read requests down to the wrapped [`AsyncBufRead`].
    ///
    /// After capturing has been enabled using [`Self::enable_capture`], any
    /// data which is read via the struct will be stored in an internal buffer
    /// and can be extracted with [`Self::take_capture`] or discarded using
    /// [`Self::discard_capture`].
    ///
    /// Data obtained through [`AsyncBufRead::poll_fill_buf`] only becomes part
    /// of the capture once it has been passed to [`AsyncBufRead::consume`];
    /// data obtained through [`AsyncRead::poll_read`] is captured immediately.
    ///
    /// This can be used to log data which is being read from a source.
    ///
    /// In addition, this struct implements [`AsyncWrite`] if and only if `T`
    /// implements [`AsyncWrite`]. Writing is unaffected by capturing and is
    /// implemented solely for convenience purposes (to allow duplex usage
    /// of a wrapped I/O object).
    pub(super) struct CaptureBufRead<T> {
        // The wrapped I/O object; all read and write calls are forwarded to
        // it. It is structurally pinned so that it can be polled through
        // `Pin<&mut Self>`.
        #[pin]
        inner: T,
        // Capture state; `None` while capturing is disabled (the state set up
        // by `wrap`). When `Some((data, consumed_up_to))`:
        //
        // - `data[..consumed_up_to]` is the committed capture: bytes returned
        //   by `poll_read` or passed to `consume` which have not yet been
        //   removed by `take_capture` / `discard_capture`.
        // - `data[consumed_up_to..]` mirrors bytes returned by the latest
        //   `poll_fill_buf` which have not been consumed yet. This tail is
        //   truncated and rewritten by the next `poll_read` / `poll_fill_buf`.
        buf: Option<(Vec<u8>, usize)>,
    }
}
43
44impl<T> CaptureBufRead<T> {
45 /// Wrap a given [`AsyncBufRead`].
46 ///
47 /// Note that capturing of data which is being read is disabled by default
48 /// and needs to be enabled using [`Self::enable_capture`].
49 pub fn wrap(inner: T) -> Self {
50 Self { inner, buf: None }
51 }
52
53 /// Extract the inner [`AsyncBufRead`] and discard the capture buffer.
54 pub fn into_inner(self) -> T {
55 self.inner
56 }
57
58 /// Obtain a reference to the inner [`AsyncBufRead`].
59 pub fn inner(&self) -> &T {
60 &self.inner
61 }
62
63 /// Enable capturing of read data into the inner buffer.
64 ///
65 /// Any data which is read from now on will be copied into the internal
66 /// buffer. That buffer will grow indefinitely until calls to
67 /// [`Self::take_capture`] or [`Self::discard_capture`].
68 pub fn enable_capture(&mut self) {
69 self.buf = Some((Vec::new(), 0));
70 }
71
72 /// Discard the current buffer data, if any.
73 ///
74 /// Further data which is read will be captured again.
75 pub(super) fn discard_capture(self: Pin<&mut Self>) {
76 let this = self.project();
77 if let Some((buf, consumed_up_to)) = this.buf.as_mut() {
78 buf.drain(..*consumed_up_to);
79 *consumed_up_to = 0;
80 }
81 }
82
83 /// Take the currently captured data out of the inner buffer.
84 ///
85 /// Returns `None` unless capturing has been enabled using
86 /// [`Self::enable_capture`].
87 pub(super) fn take_capture(self: Pin<&mut Self>) -> Option<Vec<u8>> {
88 let this = self.project();
89 let (buf, consumed_up_to) = this.buf.as_mut()?;
90 let result = buf.drain(..*consumed_up_to).collect();
91 *consumed_up_to = 0;
92 Some(result)
93 }
94}
95
96impl<T: AsyncRead> AsyncRead for CaptureBufRead<T> {
97 fn poll_read(
98 self: Pin<&mut Self>,
99 cx: &mut Context,
100 read_buf: &mut ReadBuf,
101 ) -> Poll<io::Result<()>> {
102 let this = self.project();
103 let prev_len = read_buf.filled().len();
104 let result = ready!(this.inner.poll_read(cx, read_buf));
105 if let Some((buf, consumed_up_to)) = this.buf.as_mut() {
106 buf.truncate(*consumed_up_to);
107 buf.extend(&read_buf.filled()[prev_len..]);
108 *consumed_up_to = buf.len();
109 }
110 Poll::Ready(result)
111 }
112}
113
114impl<T: AsyncBufRead> AsyncBufRead for CaptureBufRead<T> {
115 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
116 let this = self.project();
117 let result = ready!(this.inner.poll_fill_buf(cx))?;
118 if let Some((buf, consumed_up_to)) = this.buf.as_mut() {
119 buf.truncate(*consumed_up_to);
120 buf.extend(result);
121 }
122 Poll::Ready(Ok(result))
123 }
124
125 fn consume(self: Pin<&mut Self>, amt: usize) {
126 let this = self.project();
127 this.inner.consume(amt);
128 if let Some((_, consumed_up_to)) = this.buf.as_mut() {
129 // Increase the amount of data to preserve.
130 *consumed_up_to += amt;
131 }
132 }
133}
134
135impl<T: AsyncWrite> AsyncWrite for CaptureBufRead<T> {
136 fn poll_write(
137 self: Pin<&mut Self>,
138 cx: &mut Context<'_>,
139 buf: &[u8],
140 ) -> Poll<io::Result<usize>> {
141 self.project().inner.poll_write(cx, buf)
142 }
143
144 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
145 self.project().inner.poll_shutdown(cx)
146 }
147
148 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
149 self.project().inner.poll_flush(cx)
150 }
151
152 fn is_write_vectored(&self) -> bool {
153 self.inner.is_write_vectored()
154 }
155
156 fn poll_write_vectored(
157 self: Pin<&mut Self>,
158 cx: &mut Context,
159 bufs: &[IoSlice],
160 ) -> Poll<io::Result<usize>> {
161 self.project().inner.poll_write_vectored(cx, bufs)
162 }
163}
164
165/// Return true if logging via [`log_recv`] or [`log_send`] might be visible
166/// to the user.
167pub(super) fn log_enabled() -> bool {
168 log::log_enabled!(log::Level::Trace)
169}
170
171/// Log received data.
172///
173/// `err` is an error which may be logged alongside the received data.
174/// `capture` is the data which has been received and which should be logged.
175/// If built with the `syntax-highlighting` feature, `capture` data will be
176/// logged with XML syntax highlighting.
177///
178/// If both `err` and `capture` are None, nothing will be logged.
179pub(super) fn log_recv(err: Option<&xmpp_parsers::Error>, capture: Option<Vec<u8>>) {
180 match err {
181 Some(err) => match capture {
182 Some(capture) => {
183 log::trace!("RECV (error: {}) {}", err, LogXsoBuf(&capture));
184 }
185 None => {
186 log::trace!("RECV (error: {}) [data capture disabled]", err);
187 }
188 },
189 None => {
190 if let Some(capture) = capture {
191 log::trace!("RECV (ok) {}", LogXsoBuf(&capture));
192 }
193 }
194 }
195}
196
197/// Log sent data.
198///
199/// If built with the `syntax-highlighting` feature, `data` data will be
200/// logged with XML syntax highlighting.
201pub(super) fn log_send(data: &[u8]) {
202 log::trace!("SEND {}", LogXsoBuf(data));
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};

    /// Reader under test: an in-memory byte slice behind a tokio
    /// `BufReader`, wrapped for capturing.
    type Capturing = CaptureBufRead<BufReader<&'static [u8]>>;

    /// Build a reader over `data` with capturing already switched on.
    fn capturing(data: &'static [u8]) -> Capturing {
        let mut reader = CaptureBufRead::wrap(BufReader::new(data));
        reader.enable_capture();
        reader
    }

    /// Take the current capture; capturing is always enabled in these tests.
    fn take(reader: &mut Capturing) -> Vec<u8> {
        Pin::new(reader)
            .take_capture()
            .expect("capturing is enabled by `capturing()`")
    }

    #[tokio::test]
    async fn captures_data_read_via_async_read() {
        let mut reader = capturing(b"Hello World!");

        let mut dst = [0u8; 8];
        assert_eq!(reader.read(&mut dst[..]).await.unwrap(), 8);
        assert_eq!(&dst, b"Hello Wo");
        assert_eq!(take(&mut reader), b"Hello Wo");
    }

    #[tokio::test]
    async fn captures_data_read_via_async_buf_read() {
        let mut reader = capturing(b"Hello World!");

        assert_eq!(reader.fill_buf().await.unwrap().len(), 12);
        // Peeked-but-unconsumed bytes are not part of the capture yet.
        assert!(take(&mut reader).is_empty());

        reader.consume(5);
        assert_eq!(take(&mut reader), b"Hello");

        reader.consume(6);
        assert_eq!(take(&mut reader), b" World");
    }

    #[tokio::test]
    async fn discard_capture_drops_consumed_data() {
        let mut reader = capturing(b"Hello World!");

        assert_eq!(reader.fill_buf().await.unwrap().len(), 12);
        // Peeked-but-unconsumed bytes are not part of the capture yet.
        assert!(take(&mut reader).is_empty());

        reader.consume(5);
        Pin::new(&mut reader).discard_capture();

        reader.consume(6);
        assert_eq!(take(&mut reader), b" World");
    }

    #[tokio::test]
    async fn captured_data_accumulates() {
        let mut reader = capturing(b"Hello World!");

        assert_eq!(reader.fill_buf().await.unwrap().len(), 12);
        // Peeked-but-unconsumed bytes are not part of the capture yet.
        assert!(take(&mut reader).is_empty());

        reader.consume(5);
        reader.consume(6);
        assert_eq!(take(&mut reader), b"Hello World");
    }
}