reqwest_client.rs

  1use std::error::Error;
  2use std::sync::{LazyLock, OnceLock};
  3use std::{any::type_name, borrow::Cow, mem, pin::Pin, task::Poll, time::Duration};
  4
  5use anyhow::anyhow;
  6use bytes::{BufMut, Bytes, BytesMut};
  7use futures::{AsyncRead, TryStreamExt as _};
  8use http_client::{RedirectPolicy, Url, http};
  9use regex::Regex;
 10use reqwest::{
 11    header::{HeaderMap, HeaderValue},
 12    redirect,
 13};
 14use smol::future::FutureExt;
 15
 16const DEFAULT_CAPACITY: usize = 4096;
 17static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
 18static REDACT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"key=[^&]+").unwrap());
 19
 20pub struct ReqwestClient {
 21    client: reqwest::Client,
 22    proxy: Option<Url>,
 23    user_agent: Option<HeaderValue>,
 24    handle: tokio::runtime::Handle,
 25}
 26
 27impl ReqwestClient {
 28    fn builder() -> reqwest::ClientBuilder {
 29        reqwest::Client::builder()
 30            .use_rustls_tls()
 31            .connect_timeout(Duration::from_secs(10))
 32    }
 33
 34    pub fn new() -> Self {
 35        Self::builder()
 36            .build()
 37            .expect("Failed to initialize HTTP client")
 38            .into()
 39    }
 40
 41    pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
 42        let mut map = HeaderMap::new();
 43        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
 44        let client = Self::builder().default_headers(map).build()?;
 45        Ok(client.into())
 46    }
 47
 48    pub fn proxy_and_user_agent(proxy: Option<Url>, user_agent: &str) -> anyhow::Result<Self> {
 49        let user_agent = HeaderValue::from_str(user_agent)?;
 50
 51        let mut map = HeaderMap::new();
 52        map.insert(http::header::USER_AGENT, user_agent.clone());
 53        let mut client = Self::builder().default_headers(map);
 54        let client_has_proxy;
 55
 56        if let Some(proxy) = proxy.as_ref().and_then(|proxy_url| {
 57            reqwest::Proxy::all(proxy_url.clone())
 58                .inspect_err(|e| {
 59                    log::error!(
 60                        "Failed to parse proxy URL '{}': {}",
 61                        proxy_url,
 62                        e.source().unwrap_or(&e as &_)
 63                    )
 64                })
 65                .ok()
 66        }) {
 67            // Respect NO_PROXY env var
 68            client = client.proxy(proxy.no_proxy(reqwest::NoProxy::from_env()));
 69            client_has_proxy = true;
 70        } else {
 71            client_has_proxy = false;
 72        };
 73
 74        let client = client
 75            .use_preconfigured_tls(http_client_tls::tls_config())
 76            .build()?;
 77        let mut client: ReqwestClient = client.into();
 78        client.proxy = client_has_proxy.then_some(proxy).flatten();
 79        client.user_agent = Some(user_agent);
 80        Ok(client)
 81    }
 82}
 83
 84impl From<reqwest::Client> for ReqwestClient {
 85    fn from(client: reqwest::Client) -> Self {
 86        let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
 87            log::debug!("no tokio runtime found, creating one for Reqwest...");
 88            let runtime = RUNTIME.get_or_init(|| {
 89                tokio::runtime::Builder::new_multi_thread()
 90                    // Since we now have two executors, let's try to keep our footprint small
 91                    .worker_threads(1)
 92                    .enable_all()
 93                    .build()
 94                    .expect("Failed to initialize HTTP client")
 95            });
 96
 97            runtime.handle().clone()
 98        });
 99        Self {
100            client,
101            handle,
102            proxy: None,
103            user_agent: None,
104        }
105    }
106}
107
108// This struct is essentially a re-implementation of
109// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
110// except outside of Tokio's aegis
111struct StreamReader {
112    reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
113    buf: BytesMut,
114    capacity: usize,
115}
116
117impl StreamReader {
118    fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
119        Self {
120            reader: Some(reader),
121            buf: BytesMut::new(),
122            capacity: DEFAULT_CAPACITY,
123        }
124    }
125}
126
127impl futures::Stream for StreamReader {
128    type Item = std::io::Result<Bytes>;
129
130    fn poll_next(
131        mut self: Pin<&mut Self>,
132        cx: &mut std::task::Context<'_>,
133    ) -> Poll<Option<Self::Item>> {
134        let mut this = self.as_mut();
135
136        let mut reader = match this.reader.take() {
137            Some(r) => r,
138            None => return Poll::Ready(None),
139        };
140
141        if this.buf.capacity() == 0 {
142            let capacity = this.capacity;
143            this.buf.reserve(capacity);
144        }
145
146        match poll_read_buf(&mut reader, cx, &mut this.buf) {
147            Poll::Pending => Poll::Pending,
148            Poll::Ready(Err(err)) => {
149                self.reader = None;
150
151                Poll::Ready(Some(Err(err)))
152            }
153            Poll::Ready(Ok(0)) => {
154                self.reader = None;
155                Poll::Ready(None)
156            }
157            Poll::Ready(Ok(_)) => {
158                let chunk = this.buf.split();
159                self.reader = Some(reader);
160                Poll::Ready(Some(Ok(chunk.freeze())))
161            }
162        }
163    }
164}
165
166/// Implementation from <https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html>
167/// Specialized for this use case
168pub fn poll_read_buf(
169    io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
170    cx: &mut std::task::Context<'_>,
171    buf: &mut BytesMut,
172) -> Poll<std::io::Result<usize>> {
173    if !buf.has_remaining_mut() {
174        return Poll::Ready(Ok(0));
175    }
176
177    let n = {
178        let dst = buf.chunk_mut();
179
180        // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
181        // transparent wrapper around `[MaybeUninit<u8>]`.
182        let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
183        let mut buf = tokio::io::ReadBuf::uninit(dst);
184        let ptr = buf.filled().as_ptr();
185        let unfilled_portion = buf.initialize_unfilled();
186        // SAFETY: Pin projection
187        let io_pin = unsafe { Pin::new_unchecked(io) };
188        std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
189
190        // Ensure the pointer does not change from under us
191        assert_eq!(ptr, buf.filled().as_ptr());
192        buf.filled().len()
193    };
194
195    // Safety: This is guaranteed to be the number of initialized (and read)
196    // bytes due to the invariants provided by `ReadBuf::filled`.
197    unsafe {
198        buf.advance_mut(n);
199    }
200
201    Poll::Ready(Ok(n))
202}
203
204fn redact_error(mut error: reqwest::Error) -> reqwest::Error {
205    if let Some(url) = error.url_mut() {
206        if let Some(query) = url.query() {
207            if let Cow::Owned(redacted) = REDACT_REGEX.replace_all(query, "key=REDACTED") {
208                url.set_query(Some(redacted.as_str()));
209            }
210        }
211    }
212    error
213}
214
215impl http_client::HttpClient for ReqwestClient {
216    fn proxy(&self) -> Option<&Url> {
217        self.proxy.as_ref()
218    }
219
220    fn type_name(&self) -> &'static str {
221        type_name::<Self>()
222    }
223
224    fn user_agent(&self) -> Option<&HeaderValue> {
225        self.user_agent.as_ref()
226    }
227
228    fn send(
229        &self,
230        req: http::Request<http_client::AsyncBody>,
231    ) -> futures::future::BoxFuture<
232        'static,
233        anyhow::Result<http_client::Response<http_client::AsyncBody>>,
234    > {
235        let (parts, body) = req.into_parts();
236
237        let mut request = self.client.request(parts.method, parts.uri.to_string());
238        request = request.headers(parts.headers);
239        if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
240            request = request.redirect_policy(match redirect_policy {
241                RedirectPolicy::NoFollow => redirect::Policy::none(),
242                RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
243                RedirectPolicy::FollowAll => redirect::Policy::limited(100),
244            });
245        }
246        let request = request.body(match body.0 {
247            http_client::Inner::Empty => reqwest::Body::default(),
248            http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
249            http_client::Inner::AsyncReader(stream) => {
250                reqwest::Body::wrap_stream(StreamReader::new(stream))
251            }
252        });
253
254        let handle = self.handle.clone();
255        async move {
256            let mut response = handle
257                .spawn(async { request.send().await })
258                .await?
259                .map_err(redact_error)?;
260
261            let headers = mem::take(response.headers_mut());
262            let mut builder = http::Response::builder()
263                .status(response.status().as_u16())
264                .version(response.version());
265            *builder.headers_mut().unwrap() = headers;
266
267            let bytes = response
268                .bytes_stream()
269                .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
270                .into_async_read();
271            let body = http_client::AsyncBody::from_reader(bytes);
272
273            builder.body(body).map_err(|e| anyhow!(e))
274        }
275        .boxed()
276    }
277}
278
#[cfg(test)]
mod tests {
    use http_client::{HttpClient, Url};

    use crate::ReqwestClient;

    #[test]
    fn test_proxy_uri() {
        let client = ReqwestClient::new();
        assert_eq!(client.proxy(), None);

        // Each supported proxy scheme must be reported back unchanged.
        let proxy_urls = [
            "http://localhost:10809",
            "https://localhost:10809",
            "socks4://localhost:10808",
            "socks4a://localhost:10808",
            "socks5://localhost:10808",
            "socks5h://localhost:10808",
        ];
        for proxy_url in proxy_urls {
            let proxy = Url::parse(proxy_url).unwrap();
            let client =
                ReqwestClient::proxy_and_user_agent(Some(proxy.clone()), "test").unwrap();
            assert_eq!(client.proxy(), Some(&proxy));
        }
    }

    #[test]
    fn test_invalid_proxy_uri() {
        // A proxy URL that reqwest rejects must be dropped without failing construction.
        let proxy = Url::parse("socks://127.0.0.1:20170").unwrap();
        let client = ReqwestClient::proxy_and_user_agent(Some(proxy), "test").unwrap();
        assert!(
            client.proxy.is_none(),
            "An invalid proxy URL should add no proxy to the client!"
        )
    }
}