reqwest_client.rs

  1use std::error::Error;
  2use std::sync::{LazyLock, OnceLock};
  3use std::{any::type_name, borrow::Cow, mem, pin::Pin, task::Poll, time::Duration};
  4
  5use anyhow::anyhow;
  6use bytes::{BufMut, Bytes, BytesMut};
  7use futures::{AsyncRead, TryStreamExt as _};
  8use http_client::{RedirectPolicy, Url, http};
  9use regex::Regex;
 10use reqwest::{
 11    header::{HeaderMap, HeaderValue},
 12    redirect,
 13};
 14use smol::future::FutureExt;
 15
 16const DEFAULT_CAPACITY: usize = 4096;
 17static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
 18static REDACT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"key=[^&]+").unwrap());
 19
 20pub struct ReqwestClient {
 21    client: reqwest::Client,
 22    proxy: Option<Url>,
 23    handle: tokio::runtime::Handle,
 24}
 25
 26impl ReqwestClient {
 27    fn builder() -> reqwest::ClientBuilder {
 28        reqwest::Client::builder()
 29            .use_rustls_tls()
 30            .connect_timeout(Duration::from_secs(10))
 31    }
 32
 33    pub fn new() -> Self {
 34        Self::builder()
 35            .build()
 36            .expect("Failed to initialize HTTP client")
 37            .into()
 38    }
 39
 40    pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
 41        let mut map = HeaderMap::new();
 42        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
 43        let client = Self::builder().default_headers(map).build()?;
 44        Ok(client.into())
 45    }
 46
 47    pub fn proxy_and_user_agent(proxy: Option<Url>, agent: &str) -> anyhow::Result<Self> {
 48        let mut map = HeaderMap::new();
 49        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
 50        let mut client = Self::builder().default_headers(map);
 51        let client_has_proxy;
 52
 53        if let Some(proxy) = proxy.as_ref().and_then(|proxy_url| {
 54            reqwest::Proxy::all(proxy_url.clone())
 55                .inspect_err(|e| {
 56                    log::error!(
 57                        "Failed to parse proxy URL '{}': {}",
 58                        proxy_url,
 59                        e.source().unwrap_or(&e as &_)
 60                    )
 61                })
 62                .ok()
 63        }) {
 64            client = client.proxy(proxy);
 65            client_has_proxy = true;
 66        } else {
 67            client_has_proxy = false;
 68        };
 69
 70        let client = client
 71            .use_preconfigured_tls(http_client_tls::tls_config())
 72            .build()?;
 73        let mut client: ReqwestClient = client.into();
 74        client.proxy = client_has_proxy.then_some(proxy).flatten();
 75        Ok(client)
 76    }
 77}
 78
 79impl From<reqwest::Client> for ReqwestClient {
 80    fn from(client: reqwest::Client) -> Self {
 81        let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
 82            log::debug!("no tokio runtime found, creating one for Reqwest...");
 83            let runtime = RUNTIME.get_or_init(|| {
 84                tokio::runtime::Builder::new_multi_thread()
 85                    // Since we now have two executors, let's try to keep our footprint small
 86                    .worker_threads(1)
 87                    .enable_all()
 88                    .build()
 89                    .expect("Failed to initialize HTTP client")
 90            });
 91
 92            runtime.handle().clone()
 93        });
 94        Self {
 95            client,
 96            handle,
 97            proxy: None,
 98        }
 99    }
100}
101
102// This struct is essentially a re-implementation of
103// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
104// except outside of Tokio's aegis
105struct StreamReader {
106    reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
107    buf: BytesMut,
108    capacity: usize,
109}
110
111impl StreamReader {
112    fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
113        Self {
114            reader: Some(reader),
115            buf: BytesMut::new(),
116            capacity: DEFAULT_CAPACITY,
117        }
118    }
119}
120
121impl futures::Stream for StreamReader {
122    type Item = std::io::Result<Bytes>;
123
124    fn poll_next(
125        mut self: Pin<&mut Self>,
126        cx: &mut std::task::Context<'_>,
127    ) -> Poll<Option<Self::Item>> {
128        let mut this = self.as_mut();
129
130        let mut reader = match this.reader.take() {
131            Some(r) => r,
132            None => return Poll::Ready(None),
133        };
134
135        if this.buf.capacity() == 0 {
136            let capacity = this.capacity;
137            this.buf.reserve(capacity);
138        }
139
140        match poll_read_buf(&mut reader, cx, &mut this.buf) {
141            Poll::Pending => Poll::Pending,
142            Poll::Ready(Err(err)) => {
143                self.reader = None;
144
145                Poll::Ready(Some(Err(err)))
146            }
147            Poll::Ready(Ok(0)) => {
148                self.reader = None;
149                Poll::Ready(None)
150            }
151            Poll::Ready(Ok(_)) => {
152                let chunk = this.buf.split();
153                self.reader = Some(reader);
154                Poll::Ready(Some(Ok(chunk.freeze())))
155            }
156        }
157    }
158}
159
160/// Implementation from <https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html>
161/// Specialized for this use case
162pub fn poll_read_buf(
163    io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
164    cx: &mut std::task::Context<'_>,
165    buf: &mut BytesMut,
166) -> Poll<std::io::Result<usize>> {
167    if !buf.has_remaining_mut() {
168        return Poll::Ready(Ok(0));
169    }
170
171    let n = {
172        let dst = buf.chunk_mut();
173
174        // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
175        // transparent wrapper around `[MaybeUninit<u8>]`.
176        let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
177        let mut buf = tokio::io::ReadBuf::uninit(dst);
178        let ptr = buf.filled().as_ptr();
179        let unfilled_portion = buf.initialize_unfilled();
180        // SAFETY: Pin projection
181        let io_pin = unsafe { Pin::new_unchecked(io) };
182        std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
183
184        // Ensure the pointer does not change from under us
185        assert_eq!(ptr, buf.filled().as_ptr());
186        buf.filled().len()
187    };
188
189    // Safety: This is guaranteed to be the number of initialized (and read)
190    // bytes due to the invariants provided by `ReadBuf::filled`.
191    unsafe {
192        buf.advance_mut(n);
193    }
194
195    Poll::Ready(Ok(n))
196}
197
198fn redact_error(mut error: reqwest::Error) -> reqwest::Error {
199    if let Some(url) = error.url_mut() {
200        if let Some(query) = url.query() {
201            if let Cow::Owned(redacted) = REDACT_REGEX.replace_all(query, "key=REDACTED") {
202                url.set_query(Some(redacted.as_str()));
203            }
204        }
205    }
206    error
207}
208
209impl http_client::HttpClient for ReqwestClient {
210    fn proxy(&self) -> Option<&Url> {
211        self.proxy.as_ref()
212    }
213
214    fn type_name(&self) -> &'static str {
215        type_name::<Self>()
216    }
217
218    fn send(
219        &self,
220        req: http::Request<http_client::AsyncBody>,
221    ) -> futures::future::BoxFuture<
222        'static,
223        anyhow::Result<http_client::Response<http_client::AsyncBody>>,
224    > {
225        let (parts, body) = req.into_parts();
226
227        let mut request = self.client.request(parts.method, parts.uri.to_string());
228        request = request.headers(parts.headers);
229        if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
230            request = request.redirect_policy(match redirect_policy {
231                RedirectPolicy::NoFollow => redirect::Policy::none(),
232                RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
233                RedirectPolicy::FollowAll => redirect::Policy::limited(100),
234            });
235        }
236        let request = request.body(match body.0 {
237            http_client::Inner::Empty => reqwest::Body::default(),
238            http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
239            http_client::Inner::AsyncReader(stream) => {
240                reqwest::Body::wrap_stream(StreamReader::new(stream))
241            }
242        });
243
244        let handle = self.handle.clone();
245        async move {
246            let mut response = handle
247                .spawn(async { request.send().await })
248                .await?
249                .map_err(redact_error)?;
250
251            let headers = mem::take(response.headers_mut());
252            let mut builder = http::Response::builder()
253                .status(response.status().as_u16())
254                .version(response.version());
255            *builder.headers_mut().unwrap() = headers;
256
257            let bytes = response
258                .bytes_stream()
259                .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
260                .into_async_read();
261            let body = http_client::AsyncBody::from_reader(bytes);
262
263            builder.body(body).map_err(|e| anyhow!(e))
264        }
265        .boxed()
266    }
267}
268
#[cfg(test)]
mod tests {
    use http_client::{HttpClient, Url};

    use crate::ReqwestClient;

    /// Every supported proxy scheme must be recorded and reported back by `proxy()`.
    #[test]
    fn test_proxy_uri() {
        let client = ReqwestClient::new();
        assert_eq!(client.proxy(), None);

        let proxy_urls = [
            "http://localhost:10809",
            "https://localhost:10809",
            "socks4://localhost:10808",
            "socks4a://localhost:10808",
            "socks5://localhost:10808",
            "socks5h://localhost:10808",
        ];
        for raw in proxy_urls {
            let proxy = Url::parse(raw).unwrap();
            let client = ReqwestClient::proxy_and_user_agent(Some(proxy.clone()), "test").unwrap();
            assert_eq!(client.proxy(), Some(&proxy));
        }
    }

    /// An unsupported scheme is ignored instead of failing client construction.
    #[test]
    fn test_invalid_proxy_uri() {
        let bad_proxy = Url::parse("socks://127.0.0.1:20170").unwrap();
        let client = ReqwestClient::proxy_and_user_agent(Some(bad_proxy), "test").unwrap();
        assert!(
            client.proxy.is_none(),
            "An invalid proxy URL should add no proxy to the client!"
        )
    }
}