reqwest_client.rs

  1use std::sync::{LazyLock, OnceLock};
  2use std::{any::type_name, borrow::Cow, mem, pin::Pin, task::Poll, time::Duration};
  3
  4use anyhow::anyhow;
  5use bytes::{BufMut, Bytes, BytesMut};
  6use futures::{AsyncRead, TryStreamExt as _};
  7use http_client::{RedirectPolicy, http};
  8use regex::Regex;
  9use reqwest::{
 10    header::{HeaderMap, HeaderValue},
 11    redirect,
 12};
 13use smol::future::FutureExt;
 14
/// Number of bytes a `StreamReader` reserves for its buffer when the buffer has no capacity.
const DEFAULT_CAPACITY: usize = 4096;
/// Fallback tokio runtime, created on first use when a `ReqwestClient` is built
/// outside of any tokio runtime.
static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
/// Matches `key=` followed by everything up to the next `&`; used to scrub
/// such values out of URLs attached to errors.
static REDACT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"key=[^&]+").unwrap());
 18
/// An HTTP client backed by `reqwest` whose requests are spawned onto a tokio runtime.
pub struct ReqwestClient {
    /// The underlying `reqwest` client used to build every request.
    client: reqwest::Client,
    /// The proxy URI this client was constructed with, if any; reported by `proxy()`.
    proxy: Option<http::Uri>,
    /// Handle to the tokio runtime on which requests are spawned.
    handle: tokio::runtime::Handle,
}
 24
 25impl ReqwestClient {
 26    fn builder() -> reqwest::ClientBuilder {
 27        reqwest::Client::builder()
 28            .use_rustls_tls()
 29            .connect_timeout(Duration::from_secs(10))
 30    }
 31
 32    pub fn new() -> Self {
 33        Self::builder()
 34            .build()
 35            .expect("Failed to initialize HTTP client")
 36            .into()
 37    }
 38
 39    pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
 40        let mut map = HeaderMap::new();
 41        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
 42        let client = Self::builder().default_headers(map).build()?;
 43        Ok(client.into())
 44    }
 45
 46    pub fn proxy_and_user_agent(proxy: Option<http::Uri>, agent: &str) -> anyhow::Result<Self> {
 47        let mut map = HeaderMap::new();
 48        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
 49        let mut client = Self::builder().default_headers(map);
 50        if let Some(proxy) = proxy.clone().and_then(|proxy_uri| {
 51            reqwest::Proxy::all(proxy_uri.to_string())
 52                .inspect_err(|e| log::error!("Failed to parse proxy URI {}: {}", proxy_uri, e))
 53                .ok()
 54        }) {
 55            client = client.proxy(proxy);
 56        }
 57
 58        let client = client
 59            .use_preconfigured_tls(http_client_tls::tls_config())
 60            .build()?;
 61        let mut client: ReqwestClient = client.into();
 62        client.proxy = proxy;
 63        Ok(client)
 64    }
 65}
 66
 67impl From<reqwest::Client> for ReqwestClient {
 68    fn from(client: reqwest::Client) -> Self {
 69        let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
 70            log::debug!("no tokio runtime found, creating one for Reqwest...");
 71            let runtime = RUNTIME.get_or_init(|| {
 72                tokio::runtime::Builder::new_multi_thread()
 73                    // Since we now have two executors, let's try to keep our footprint small
 74                    .worker_threads(1)
 75                    .enable_all()
 76                    .build()
 77                    .expect("Failed to initialize HTTP client")
 78            });
 79
 80            runtime.handle().clone()
 81        });
 82        Self {
 83            client,
 84            handle,
 85            proxy: None,
 86        }
 87    }
 88}
 89
// This struct is essentially a re-implementation of
// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
// except outside of Tokio's aegis
//
// It adapts a `futures::AsyncRead` into a `futures::Stream` of `Bytes` chunks.
struct StreamReader {
    // The wrapped reader; `None` once the stream has terminated, after which
    // `poll_next` yields `None`.
    reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
    // Scratch buffer that reads land in before being split off as a chunk.
    buf: BytesMut,
    // Number of bytes to reserve whenever `buf` has no capacity.
    capacity: usize,
}
 98
 99impl StreamReader {
100    fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
101        Self {
102            reader: Some(reader),
103            buf: BytesMut::new(),
104            capacity: DEFAULT_CAPACITY,
105        }
106    }
107}
108
109impl futures::Stream for StreamReader {
110    type Item = std::io::Result<Bytes>;
111
112    fn poll_next(
113        mut self: Pin<&mut Self>,
114        cx: &mut std::task::Context<'_>,
115    ) -> Poll<Option<Self::Item>> {
116        let mut this = self.as_mut();
117
118        let mut reader = match this.reader.take() {
119            Some(r) => r,
120            None => return Poll::Ready(None),
121        };
122
123        if this.buf.capacity() == 0 {
124            let capacity = this.capacity;
125            this.buf.reserve(capacity);
126        }
127
128        match poll_read_buf(&mut reader, cx, &mut this.buf) {
129            Poll::Pending => Poll::Pending,
130            Poll::Ready(Err(err)) => {
131                self.reader = None;
132
133                Poll::Ready(Some(Err(err)))
134            }
135            Poll::Ready(Ok(0)) => {
136                self.reader = None;
137                Poll::Ready(None)
138            }
139            Poll::Ready(Ok(_)) => {
140                let chunk = this.buf.split();
141                self.reader = Some(reader);
142                Poll::Ready(Some(Ok(chunk.freeze())))
143            }
144        }
145    }
146}
147
148/// Implementation from <https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html>
149/// Specialized for this use case
150pub fn poll_read_buf(
151    io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
152    cx: &mut std::task::Context<'_>,
153    buf: &mut BytesMut,
154) -> Poll<std::io::Result<usize>> {
155    if !buf.has_remaining_mut() {
156        return Poll::Ready(Ok(0));
157    }
158
159    let n = {
160        let dst = buf.chunk_mut();
161
162        // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
163        // transparent wrapper around `[MaybeUninit<u8>]`.
164        let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
165        let mut buf = tokio::io::ReadBuf::uninit(dst);
166        let ptr = buf.filled().as_ptr();
167        let unfilled_portion = buf.initialize_unfilled();
168        // SAFETY: Pin projection
169        let io_pin = unsafe { Pin::new_unchecked(io) };
170        std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
171
172        // Ensure the pointer does not change from under us
173        assert_eq!(ptr, buf.filled().as_ptr());
174        buf.filled().len()
175    };
176
177    // Safety: This is guaranteed to be the number of initialized (and read)
178    // bytes due to the invariants provided by `ReadBuf::filled`.
179    unsafe {
180        buf.advance_mut(n);
181    }
182
183    Poll::Ready(Ok(n))
184}
185
186fn redact_error(mut error: reqwest::Error) -> reqwest::Error {
187    if let Some(url) = error.url_mut() {
188        if let Some(query) = url.query() {
189            if let Cow::Owned(redacted) = REDACT_REGEX.replace_all(query, "key=REDACTED") {
190                url.set_query(Some(redacted.as_str()));
191            }
192        }
193    }
194    error
195}
196
197impl http_client::HttpClient for ReqwestClient {
198    fn proxy(&self) -> Option<&http::Uri> {
199        self.proxy.as_ref()
200    }
201
202    fn type_name(&self) -> &'static str {
203        type_name::<Self>()
204    }
205
206    fn send(
207        &self,
208        req: http::Request<http_client::AsyncBody>,
209    ) -> futures::future::BoxFuture<
210        'static,
211        Result<http_client::Response<http_client::AsyncBody>, anyhow::Error>,
212    > {
213        let (parts, body) = req.into_parts();
214
215        let mut request = self.client.request(parts.method, parts.uri.to_string());
216        request = request.headers(parts.headers);
217        if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
218            request = request.redirect_policy(match redirect_policy {
219                RedirectPolicy::NoFollow => redirect::Policy::none(),
220                RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
221                RedirectPolicy::FollowAll => redirect::Policy::limited(100),
222            });
223        }
224        let request = request.body(match body.0 {
225            http_client::Inner::Empty => reqwest::Body::default(),
226            http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
227            http_client::Inner::AsyncReader(stream) => {
228                reqwest::Body::wrap_stream(StreamReader::new(stream))
229            }
230        });
231
232        let handle = self.handle.clone();
233        async move {
234            let mut response = handle
235                .spawn(async { request.send().await })
236                .await?
237                .map_err(redact_error)?;
238
239            let headers = mem::take(response.headers_mut());
240            let mut builder = http::Response::builder()
241                .status(response.status().as_u16())
242                .version(response.version());
243            *builder.headers_mut().unwrap() = headers;
244
245            let bytes = response
246                .bytes_stream()
247                .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
248                .into_async_read();
249            let body = http_client::AsyncBody::from_reader(bytes);
250
251            builder.body(body).map_err(|e| anyhow!(e))
252        }
253        .boxed()
254    }
255}
256
#[cfg(test)]
mod tests {
    use http_client::{HttpClient, http};

    use crate::ReqwestClient;

    /// A default client reports no proxy; a client built with an HTTP(S) or SOCKS
    /// proxy URI reports exactly that URI.
    #[test]
    fn test_proxy_uri() {
        let client = ReqwestClient::new();
        assert_eq!(client.proxy(), None);

        let proxy_uris = [
            "http://localhost:10809",
            "https://localhost:10809",
            "socks4://localhost:10808",
            "socks4a://localhost:10808",
            "socks5://localhost:10808",
            "socks5h://localhost:10808",
        ];
        for raw in proxy_uris {
            let proxy = http::Uri::from_static(raw);
            let client = ReqwestClient::proxy_and_user_agent(Some(proxy.clone()), "test").unwrap();
            assert_eq!(client.proxy(), Some(&proxy));
        }
    }

    /// Setting up a client with a `file://` proxy URI is expected to panic.
    #[test]
    #[should_panic]
    fn test_invalid_proxy_uri() {
        let proxy = http::Uri::from_static("file:///etc/hosts");
        ReqwestClient::proxy_and_user_agent(Some(proxy), "test").unwrap();
    }
}