reqwest_client.rs

  1use std::error::Error;
  2use std::sync::{LazyLock, OnceLock};
  3use std::{any::type_name, borrow::Cow, mem, pin::Pin, task::Poll, time::Duration};
  4
  5use anyhow::anyhow;
  6use bytes::{BufMut, Bytes, BytesMut};
  7use futures::{AsyncRead, FutureExt as _, TryStreamExt as _};
  8use http_client::{RedirectPolicy, Url, http};
  9use regex::Regex;
 10use reqwest::{
 11    header::{HeaderMap, HeaderValue},
 12    redirect,
 13};
 14
 15const DEFAULT_CAPACITY: usize = 4096;
 16static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
 17static REDACT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"key=[^&]+").unwrap());
 18
 19pub struct ReqwestClient {
 20    client: reqwest::Client,
 21    proxy: Option<Url>,
 22    user_agent: Option<HeaderValue>,
 23    handle: tokio::runtime::Handle,
 24}
 25
 26impl ReqwestClient {
 27    fn builder() -> reqwest::ClientBuilder {
 28        reqwest::Client::builder()
 29            .use_rustls_tls()
 30            .connect_timeout(Duration::from_secs(10))
 31    }
 32
 33    pub fn new() -> Self {
 34        Self::builder()
 35            .build()
 36            .expect("Failed to initialize HTTP client")
 37            .into()
 38    }
 39
 40    pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
 41        let mut map = HeaderMap::new();
 42        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
 43        let client = Self::builder().default_headers(map).build()?;
 44        Ok(client.into())
 45    }
 46
 47    pub fn proxy_and_user_agent(proxy: Option<Url>, user_agent: &str) -> anyhow::Result<Self> {
 48        let user_agent = HeaderValue::from_str(user_agent)?;
 49
 50        let mut map = HeaderMap::new();
 51        map.insert(http::header::USER_AGENT, user_agent.clone());
 52        let mut client = Self::builder().default_headers(map);
 53        let client_has_proxy;
 54
 55        if let Some(proxy) = proxy.as_ref().and_then(|proxy_url| {
 56            reqwest::Proxy::all(proxy_url.clone())
 57                .inspect_err(|e| {
 58                    log::error!(
 59                        "Failed to parse proxy URL '{}': {}",
 60                        proxy_url,
 61                        e.source().unwrap_or(&e as &_)
 62                    )
 63                })
 64                .ok()
 65        }) {
 66            // Respect NO_PROXY env var
 67            client = client.proxy(proxy.no_proxy(reqwest::NoProxy::from_env()));
 68            client_has_proxy = true;
 69        } else {
 70            client_has_proxy = false;
 71        };
 72
 73        let client = client
 74            .use_preconfigured_tls(http_client_tls::tls_config())
 75            .build()?;
 76        let mut client: ReqwestClient = client.into();
 77        client.proxy = client_has_proxy.then_some(proxy).flatten();
 78        client.user_agent = Some(user_agent);
 79        Ok(client)
 80    }
 81}
 82
 83impl From<reqwest::Client> for ReqwestClient {
 84    fn from(client: reqwest::Client) -> Self {
 85        let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
 86            log::debug!("no tokio runtime found, creating one for Reqwest...");
 87            let runtime = RUNTIME.get_or_init(|| {
 88                tokio::runtime::Builder::new_multi_thread()
 89                    // Since we now have two executors, let's try to keep our footprint small
 90                    .worker_threads(1)
 91                    .enable_all()
 92                    .build()
 93                    .expect("Failed to initialize HTTP client")
 94            });
 95
 96            runtime.handle().clone()
 97        });
 98        Self {
 99            client,
100            handle,
101            proxy: None,
102            user_agent: None,
103        }
104    }
105}
106
107// This struct is essentially a re-implementation of
108// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
109// except outside of Tokio's aegis
110struct StreamReader {
111    reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
112    buf: BytesMut,
113    capacity: usize,
114}
115
116impl StreamReader {
117    fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
118        Self {
119            reader: Some(reader),
120            buf: BytesMut::new(),
121            capacity: DEFAULT_CAPACITY,
122        }
123    }
124}
125
126impl futures::Stream for StreamReader {
127    type Item = std::io::Result<Bytes>;
128
129    fn poll_next(
130        mut self: Pin<&mut Self>,
131        cx: &mut std::task::Context<'_>,
132    ) -> Poll<Option<Self::Item>> {
133        let mut this = self.as_mut();
134
135        let mut reader = match this.reader.take() {
136            Some(r) => r,
137            None => return Poll::Ready(None),
138        };
139
140        if this.buf.capacity() == 0 {
141            let capacity = this.capacity;
142            this.buf.reserve(capacity);
143        }
144
145        match poll_read_buf(&mut reader, cx, &mut this.buf) {
146            Poll::Pending => Poll::Pending,
147            Poll::Ready(Err(err)) => {
148                self.reader = None;
149
150                Poll::Ready(Some(Err(err)))
151            }
152            Poll::Ready(Ok(0)) => {
153                self.reader = None;
154                Poll::Ready(None)
155            }
156            Poll::Ready(Ok(_)) => {
157                let chunk = this.buf.split();
158                self.reader = Some(reader);
159                Poll::Ready(Some(Ok(chunk.freeze())))
160            }
161        }
162    }
163}
164
165/// Implementation from <https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html>
166/// Specialized for this use case
167pub fn poll_read_buf(
168    io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
169    cx: &mut std::task::Context<'_>,
170    buf: &mut BytesMut,
171) -> Poll<std::io::Result<usize>> {
172    if !buf.has_remaining_mut() {
173        return Poll::Ready(Ok(0));
174    }
175
176    let n = {
177        let dst = buf.chunk_mut();
178
179        // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
180        // transparent wrapper around `[MaybeUninit<u8>]`.
181        let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
182        let mut buf = tokio::io::ReadBuf::uninit(dst);
183        let ptr = buf.filled().as_ptr();
184        let unfilled_portion = buf.initialize_unfilled();
185        // SAFETY: Pin projection
186        let io_pin = unsafe { Pin::new_unchecked(io) };
187        std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
188
189        // Ensure the pointer does not change from under us
190        assert_eq!(ptr, buf.filled().as_ptr());
191        buf.filled().len()
192    };
193
194    // Safety: This is guaranteed to be the number of initialized (and read)
195    // bytes due to the invariants provided by `ReadBuf::filled`.
196    unsafe {
197        buf.advance_mut(n);
198    }
199
200    Poll::Ready(Ok(n))
201}
202
203fn redact_error(mut error: reqwest::Error) -> reqwest::Error {
204    if let Some(url) = error.url_mut()
205        && let Some(query) = url.query()
206            && let Cow::Owned(redacted) = REDACT_REGEX.replace_all(query, "key=REDACTED") {
207                url.set_query(Some(redacted.as_str()));
208            }
209    error
210}
211
212impl http_client::HttpClient for ReqwestClient {
213    fn proxy(&self) -> Option<&Url> {
214        self.proxy.as_ref()
215    }
216
217    fn type_name(&self) -> &'static str {
218        type_name::<Self>()
219    }
220
221    fn user_agent(&self) -> Option<&HeaderValue> {
222        self.user_agent.as_ref()
223    }
224
225    fn send(
226        &self,
227        req: http::Request<http_client::AsyncBody>,
228    ) -> futures::future::BoxFuture<
229        'static,
230        anyhow::Result<http_client::Response<http_client::AsyncBody>>,
231    > {
232        let (parts, body) = req.into_parts();
233
234        let mut request = self.client.request(parts.method, parts.uri.to_string());
235        request = request.headers(parts.headers);
236        if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
237            request = request.redirect_policy(match redirect_policy {
238                RedirectPolicy::NoFollow => redirect::Policy::none(),
239                RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
240                RedirectPolicy::FollowAll => redirect::Policy::limited(100),
241            });
242        }
243        let request = request.body(match body.0 {
244            http_client::Inner::Empty => reqwest::Body::default(),
245            http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
246            http_client::Inner::AsyncReader(stream) => {
247                reqwest::Body::wrap_stream(StreamReader::new(stream))
248            }
249        });
250
251        let handle = self.handle.clone();
252        async move {
253            let mut response = handle
254                .spawn(async { request.send().await })
255                .await?
256                .map_err(redact_error)?;
257
258            let headers = mem::take(response.headers_mut());
259            let mut builder = http::Response::builder()
260                .status(response.status().as_u16())
261                .version(response.version());
262            *builder.headers_mut().unwrap() = headers;
263
264            let bytes = response
265                .bytes_stream()
266                .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
267                .into_async_read();
268            let body = http_client::AsyncBody::from_reader(bytes);
269
270            builder.body(body).map_err(|e| anyhow!(e))
271        }
272        .boxed()
273    }
274
275    fn send_multipart_form<'a>(
276        &'a self,
277        url: &str,
278        form: reqwest::multipart::Form,
279    ) -> futures::future::BoxFuture<'a, anyhow::Result<http_client::Response<http_client::AsyncBody>>>
280    {
281        let response = self.client.post(url).multipart(form).send();
282        self.handle
283            .spawn(async move {
284                let response = response.await?;
285                let mut builder = http::response::Builder::new().status(response.status());
286                for (k, v) in response.headers() {
287                    builder = builder.header(k, v)
288                }
289                Ok(builder.body(response.bytes().await?.into())?)
290            })
291            .map(|e| e?)
292            .boxed()
293    }
294}
295
#[cfg(test)]
mod tests {
    use http_client::{HttpClient, Url};

    use crate::ReqwestClient;

    /// Every proxy scheme reqwest supports should be kept and reported back by `proxy()`.
    #[test]
    fn test_proxy_uri() {
        let client = ReqwestClient::new();
        assert_eq!(client.proxy(), None);

        let supported = [
            "http://localhost:10809",
            "https://localhost:10809",
            "socks4://localhost:10808",
            "socks4a://localhost:10808",
            "socks5://localhost:10808",
            "socks5h://localhost:10808",
        ];
        for raw in supported {
            let proxy = Url::parse(raw).unwrap();
            let client = ReqwestClient::proxy_and_user_agent(Some(proxy.clone()), "test").unwrap();
            assert_eq!(client.proxy(), Some(&proxy), "proxy {raw} should be kept");
        }
    }

    /// A scheme reqwest rejects must not be recorded as the client's proxy.
    #[test]
    fn test_invalid_proxy_uri() {
        let proxy = Url::parse("socks://127.0.0.1:20170").unwrap();
        let client = ReqwestClient::proxy_and_user_agent(Some(proxy), "test").unwrap();
        assert!(
            client.proxy.is_none(),
            "An invalid proxy URL should add no proxy to the client!"
        )
    }
}