reqwest_client.rs

  1use std::{any::type_name, mem, pin::Pin, sync::OnceLock, task::Poll, time::Duration};
  2
  3use anyhow::anyhow;
  4use bytes::{BufMut, Bytes, BytesMut};
  5use futures::{AsyncRead, TryStreamExt as _};
  6use http_client::{http, RedirectPolicy};
  7use reqwest::{
  8    header::{HeaderMap, HeaderValue},
  9    redirect,
 10};
 11use smol::future::FutureExt;
 12
 13const DEFAULT_CAPACITY: usize = 4096;
/// Fallback Tokio runtime, lazily created by `From<reqwest::Client> for ReqwestClient`
/// when no Tokio runtime is current at construction time; shared by every such client.
static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
 15
/// [`http_client::HttpClient`] implementation backed by a `reqwest::Client`.
///
/// Requests are spawned onto the Tokio runtime referenced by `handle`.
pub struct ReqwestClient {
    /// The underlying reqwest client that actually performs requests.
    client: reqwest::Client,
    /// Proxy URI reported by `HttpClient::proxy`; `From<reqwest::Client>` leaves it `None`
    /// and only `proxy_and_user_agent` assigns it.
    proxy: Option<http::Uri>,
    /// Handle of the Tokio runtime that `send` spawns requests on.
    handle: tokio::runtime::Handle,
}
 21
 22impl ReqwestClient {
 23    fn builder() -> reqwest::ClientBuilder {
 24        reqwest::Client::builder()
 25            .use_rustls_tls()
 26            .connect_timeout(Duration::from_secs(10))
 27    }
 28
 29    pub fn new() -> Self {
 30        Self::builder()
 31            .build()
 32            .expect("Failed to initialize HTTP client")
 33            .into()
 34    }
 35
 36    pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
 37        let mut map = HeaderMap::new();
 38        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
 39        let client = Self::builder().default_headers(map).build()?;
 40        Ok(client.into())
 41    }
 42
 43    pub fn proxy_and_user_agent(proxy: Option<http::Uri>, agent: &str) -> anyhow::Result<Self> {
 44        let mut map = HeaderMap::new();
 45        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
 46        let mut client = Self::builder().default_headers(map);
 47        if let Some(proxy) = proxy.clone().and_then(|proxy_uri| {
 48            reqwest::Proxy::all(proxy_uri.to_string())
 49                .inspect_err(|e| log::error!("Failed to parse proxy URI {}: {}", proxy_uri, e))
 50                .ok()
 51        }) {
 52            client = client.proxy(proxy);
 53        }
 54
 55        let client = client
 56            .use_preconfigured_tls(http_client::tls_config())
 57            .build()?;
 58        let mut client: ReqwestClient = client.into();
 59        client.proxy = proxy;
 60        Ok(client)
 61    }
 62}
 63
 64impl From<reqwest::Client> for ReqwestClient {
 65    fn from(client: reqwest::Client) -> Self {
 66        let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
 67            log::debug!("no tokio runtime found, creating one for Reqwest...");
 68            let runtime = RUNTIME.get_or_init(|| {
 69                tokio::runtime::Builder::new_multi_thread()
 70                    // Since we now have two executors, let's try to keep our footprint small
 71                    .worker_threads(1)
 72                    .enable_all()
 73                    .build()
 74                    .expect("Failed to initialize HTTP client")
 75            });
 76
 77            runtime.handle().clone()
 78        });
 79        Self {
 80            client,
 81            handle,
 82            proxy: None,
 83        }
 84    }
 85}
 86
 87// This struct is essentially a re-implementation of
 88// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
 89// except outside of Tokio's aegis
 90struct StreamReader {
 91    reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
 92    buf: BytesMut,
 93    capacity: usize,
 94}
 95
 96impl StreamReader {
 97    fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
 98        Self {
 99            reader: Some(reader),
100            buf: BytesMut::new(),
101            capacity: DEFAULT_CAPACITY,
102        }
103    }
104}
105
106impl futures::Stream for StreamReader {
107    type Item = std::io::Result<Bytes>;
108
109    fn poll_next(
110        mut self: Pin<&mut Self>,
111        cx: &mut std::task::Context<'_>,
112    ) -> Poll<Option<Self::Item>> {
113        let mut this = self.as_mut();
114
115        let mut reader = match this.reader.take() {
116            Some(r) => r,
117            None => return Poll::Ready(None),
118        };
119
120        if this.buf.capacity() == 0 {
121            let capacity = this.capacity;
122            this.buf.reserve(capacity);
123        }
124
125        match poll_read_buf(&mut reader, cx, &mut this.buf) {
126            Poll::Pending => Poll::Pending,
127            Poll::Ready(Err(err)) => {
128                self.reader = None;
129
130                Poll::Ready(Some(Err(err)))
131            }
132            Poll::Ready(Ok(0)) => {
133                self.reader = None;
134                Poll::Ready(None)
135            }
136            Poll::Ready(Ok(_)) => {
137                let chunk = this.buf.split();
138                self.reader = Some(reader);
139                Poll::Ready(Some(Ok(chunk.freeze())))
140            }
141        }
142    }
143}
144
145/// Implementation from <https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html>
146/// Specialized for this use case
147pub fn poll_read_buf(
148    io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
149    cx: &mut std::task::Context<'_>,
150    buf: &mut BytesMut,
151) -> Poll<std::io::Result<usize>> {
152    if !buf.has_remaining_mut() {
153        return Poll::Ready(Ok(0));
154    }
155
156    let n = {
157        let dst = buf.chunk_mut();
158
159        // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
160        // transparent wrapper around `[MaybeUninit<u8>]`.
161        let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
162        let mut buf = tokio::io::ReadBuf::uninit(dst);
163        let ptr = buf.filled().as_ptr();
164        let unfilled_portion = buf.initialize_unfilled();
165        // SAFETY: Pin projection
166        let io_pin = unsafe { Pin::new_unchecked(io) };
167        std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
168
169        // Ensure the pointer does not change from under us
170        assert_eq!(ptr, buf.filled().as_ptr());
171        buf.filled().len()
172    };
173
174    // Safety: This is guaranteed to be the number of initialized (and read)
175    // bytes due to the invariants provided by `ReadBuf::filled`.
176    unsafe {
177        buf.advance_mut(n);
178    }
179
180    Poll::Ready(Ok(n))
181}
182
183impl http_client::HttpClient for ReqwestClient {
184    fn proxy(&self) -> Option<&http::Uri> {
185        self.proxy.as_ref()
186    }
187
188    fn type_name(&self) -> &'static str {
189        type_name::<Self>()
190    }
191
192    fn send(
193        &self,
194        req: http::Request<http_client::AsyncBody>,
195    ) -> futures::future::BoxFuture<
196        'static,
197        Result<http_client::Response<http_client::AsyncBody>, anyhow::Error>,
198    > {
199        let (parts, body) = req.into_parts();
200
201        let mut request = self.client.request(parts.method, parts.uri.to_string());
202        request = request.headers(parts.headers);
203        if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
204            request = request.redirect_policy(match redirect_policy {
205                RedirectPolicy::NoFollow => redirect::Policy::none(),
206                RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
207                RedirectPolicy::FollowAll => redirect::Policy::limited(100),
208            });
209        }
210        let request = request.body(match body.0 {
211            http_client::Inner::Empty => reqwest::Body::default(),
212            http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
213            http_client::Inner::AsyncReader(stream) => {
214                reqwest::Body::wrap_stream(StreamReader::new(stream))
215            }
216        });
217
218        let handle = self.handle.clone();
219        async move {
220            let mut response = handle.spawn(async { request.send().await }).await??;
221
222            let headers = mem::take(response.headers_mut());
223            let mut builder = http::Response::builder()
224                .status(response.status().as_u16())
225                .version(response.version());
226            *builder.headers_mut().unwrap() = headers;
227
228            let bytes = response
229                .bytes_stream()
230                .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
231                .into_async_read();
232            let body = http_client::AsyncBody::from_reader(bytes);
233
234            builder.body(body).map_err(|e| anyhow!(e))
235        }
236        .boxed()
237    }
238}
239
#[cfg(test)]
mod tests {
    use http_client::{http, HttpClient};

    use crate::ReqwestClient;

    #[test]
    fn test_proxy_uri() {
        // A client built without a proxy must not report one.
        let client = ReqwestClient::new();
        assert_eq!(client.proxy(), None);

        // Each proxy scheme below should be reported back exactly as it was supplied.
        let proxy_uris = [
            "http://localhost:10809",
            "https://localhost:10809",
            "socks4://localhost:10808",
            "socks4a://localhost:10808",
            "socks5://localhost:10808",
            "socks5h://localhost:10808",
        ];
        for uri in proxy_uris {
            let proxy = http::Uri::from_static(uri);
            let client = ReqwestClient::proxy_and_user_agent(Some(proxy.clone()), "test").unwrap();
            assert_eq!(client.proxy(), Some(&proxy), "unexpected proxy for {uri}");
        }
    }

    #[test]
    #[should_panic]
    fn test_invalid_proxy_uri() {
        // NOTE(review): the panic presumably comes from `Uri::from_static` rejecting this
        // authority-less `file:` URI, before the client is ever built — confirm.
        let proxy = http::Uri::from_static("file:///etc/hosts");
        ReqwestClient::proxy_and_user_agent(Some(proxy), "test").unwrap();
    }
}
282}