reqwest_client.rs

  1use std::{any::type_name, borrow::Cow, io::Read, mem, pin::Pin, sync::OnceLock, task::Poll};
  2
  3use anyhow::anyhow;
  4use bytes::{BufMut, Bytes, BytesMut};
  5use futures::{AsyncRead, TryStreamExt as _};
  6use http_client::{http, ReadTimeout, RedirectPolicy};
  7use reqwest::{
  8    header::{HeaderMap, HeaderValue},
  9    redirect,
 10};
 11use smol::future::FutureExt;
 12
 13const DEFAULT_CAPACITY: usize = 4096;
 14static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
 15
 16pub struct ReqwestClient {
 17    client: reqwest::Client,
 18    proxy: Option<http::Uri>,
 19    handle: tokio::runtime::Handle,
 20}
 21
 22impl ReqwestClient {
 23    pub fn new() -> Self {
 24        reqwest::Client::builder()
 25            .use_rustls_tls()
 26            .build()
 27            .expect("Failed to initialize HTTP client")
 28            .into()
 29    }
 30
 31    pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
 32        let mut map = HeaderMap::new();
 33        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
 34        let client = reqwest::Client::builder()
 35            .default_headers(map)
 36            .use_rustls_tls()
 37            .build()?;
 38        Ok(client.into())
 39    }
 40
 41    pub fn proxy_and_user_agent(proxy: Option<http::Uri>, agent: &str) -> anyhow::Result<Self> {
 42        let mut map = HeaderMap::new();
 43        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
 44        let mut client = reqwest::Client::builder()
 45            .use_rustls_tls()
 46            .default_headers(map);
 47        if let Some(proxy) = proxy.clone() {
 48            client = client.proxy(reqwest::Proxy::all(proxy.to_string())?);
 49        }
 50        let client = client.build()?;
 51        let mut client: ReqwestClient = client.into();
 52        client.proxy = proxy;
 53        Ok(client)
 54    }
 55}
 56
 57impl From<reqwest::Client> for ReqwestClient {
 58    fn from(client: reqwest::Client) -> Self {
 59        let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
 60            log::info!("no tokio runtime found, creating one for Reqwest...");
 61            let runtime = RUNTIME.get_or_init(|| {
 62                tokio::runtime::Builder::new_multi_thread()
 63                    // Since we now have two executors, let's try to keep our footprint small
 64                    .worker_threads(1)
 65                    .enable_all()
 66                    .build()
 67                    .expect("Failed to initialize HTTP client")
 68            });
 69
 70            runtime.handle().clone()
 71        });
 72        Self {
 73            client,
 74            handle,
 75            proxy: None,
 76        }
 77    }
 78}
 79
 80// This struct is essentially a re-implementation of
 81// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
 82// except outside of Tokio's aegis
 83struct StreamReader {
 84    reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
 85    buf: BytesMut,
 86    capacity: usize,
 87}
 88
 89impl StreamReader {
 90    fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
 91        Self {
 92            reader: Some(reader),
 93            buf: BytesMut::new(),
 94            capacity: DEFAULT_CAPACITY,
 95        }
 96    }
 97}
 98
 99impl futures::Stream for StreamReader {
100    type Item = std::io::Result<Bytes>;
101
102    fn poll_next(
103        mut self: Pin<&mut Self>,
104        cx: &mut std::task::Context<'_>,
105    ) -> Poll<Option<Self::Item>> {
106        let mut this = self.as_mut();
107
108        let mut reader = match this.reader.take() {
109            Some(r) => r,
110            None => return Poll::Ready(None),
111        };
112
113        if this.buf.capacity() == 0 {
114            let capacity = this.capacity;
115            this.buf.reserve(capacity);
116        }
117
118        match poll_read_buf(&mut reader, cx, &mut this.buf) {
119            Poll::Pending => Poll::Pending,
120            Poll::Ready(Err(err)) => {
121                self.reader = None;
122
123                Poll::Ready(Some(Err(err)))
124            }
125            Poll::Ready(Ok(0)) => {
126                self.reader = None;
127                Poll::Ready(None)
128            }
129            Poll::Ready(Ok(_)) => {
130                let chunk = this.buf.split();
131                self.reader = Some(reader);
132                Poll::Ready(Some(Ok(chunk.freeze())))
133            }
134        }
135    }
136}
137
138/// Implementation from https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html
139/// Specialized for this use case
140pub fn poll_read_buf(
141    io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
142    cx: &mut std::task::Context<'_>,
143    buf: &mut BytesMut,
144) -> Poll<std::io::Result<usize>> {
145    if !buf.has_remaining_mut() {
146        return Poll::Ready(Ok(0));
147    }
148
149    let n = {
150        let dst = buf.chunk_mut();
151
152        // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
153        // transparent wrapper around `[MaybeUninit<u8>]`.
154        let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
155        let mut buf = tokio::io::ReadBuf::uninit(dst);
156        let ptr = buf.filled().as_ptr();
157        let unfilled_portion = buf.initialize_unfilled();
158        // SAFETY: Pin projection
159        let io_pin = unsafe { Pin::new_unchecked(io) };
160        std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
161
162        // Ensure the pointer does not change from under us
163        assert_eq!(ptr, buf.filled().as_ptr());
164        buf.filled().len()
165    };
166
167    // Safety: This is guaranteed to be the number of initialized (and read)
168    // bytes due to the invariants provided by `ReadBuf::filled`.
169    unsafe {
170        buf.advance_mut(n);
171    }
172
173    Poll::Ready(Ok(n))
174}
175
176struct SyncReader {
177    cursor: Option<std::io::Cursor<Cow<'static, [u8]>>>,
178}
179
180impl SyncReader {
181    fn new(cursor: std::io::Cursor<Cow<'static, [u8]>>) -> Self {
182        Self {
183            cursor: Some(cursor),
184        }
185    }
186}
187
188impl futures::stream::Stream for SyncReader {
189    type Item = Result<Bytes, std::io::Error>;
190
191    fn poll_next(
192        mut self: std::pin::Pin<&mut Self>,
193        _cx: &mut std::task::Context<'_>,
194    ) -> std::task::Poll<Option<Self::Item>> {
195        let Some(mut cursor) = self.cursor.take() else {
196            return Poll::Ready(None);
197        };
198
199        let mut buf = Vec::new();
200        match cursor.read_to_end(&mut buf) {
201            Ok(_) => {
202                return Poll::Ready(Some(Ok(Bytes::from(buf))));
203            }
204            Err(e) => return Poll::Ready(Some(Err(e))),
205        }
206    }
207}
208
209impl http_client::HttpClient for ReqwestClient {
210    fn proxy(&self) -> Option<&http::Uri> {
211        self.proxy.as_ref()
212    }
213
214    fn type_name(&self) -> &'static str {
215        type_name::<Self>()
216    }
217
218    fn send(
219        &self,
220        req: http::Request<http_client::AsyncBody>,
221    ) -> futures::future::BoxFuture<
222        'static,
223        Result<http_client::Response<http_client::AsyncBody>, anyhow::Error>,
224    > {
225        let (parts, body) = req.into_parts();
226
227        let mut request = self.client.request(parts.method, parts.uri.to_string());
228        request = request.headers(parts.headers);
229        if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
230            request = request.redirect_policy(match redirect_policy {
231                RedirectPolicy::NoFollow => redirect::Policy::none(),
232                RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
233                RedirectPolicy::FollowAll => redirect::Policy::limited(100),
234            });
235        }
236        if let Some(ReadTimeout(timeout)) = parts.extensions.get::<ReadTimeout>() {
237            request = request.timeout(*timeout);
238        }
239        let request = request.body(match body.0 {
240            http_client::Inner::Empty => reqwest::Body::default(),
241            http_client::Inner::SyncReader(cursor) => {
242                reqwest::Body::wrap_stream(SyncReader::new(cursor))
243            }
244            http_client::Inner::AsyncReader(stream) => {
245                reqwest::Body::wrap_stream(StreamReader::new(stream))
246            }
247        });
248
249        let handle = self.handle.clone();
250        async move {
251            let mut response = handle.spawn(async { request.send().await }).await??;
252
253            let headers = mem::take(response.headers_mut());
254            let mut builder = http::Response::builder()
255                .status(response.status().as_u16())
256                .version(response.version());
257            *builder.headers_mut().unwrap() = headers;
258
259            let bytes = response
260                .bytes_stream()
261                .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
262                .into_async_read();
263            let body = http_client::AsyncBody::from_reader(bytes);
264
265            builder.body(body).map_err(|e| anyhow!(e))
266        }
267        .boxed()
268    }
269}