reqwest_client.rs

  1use std::{borrow::Cow, io::Read, pin::Pin, task::Poll};
  2
  3use anyhow::anyhow;
  4use bytes::{BufMut, Bytes, BytesMut};
  5use futures::{AsyncRead, TryStreamExt};
  6use http_client::{http, AsyncBody, ReadTimeout};
  7use reqwest::header::{HeaderMap, HeaderValue};
  8use smol::future::FutureExt;
  9
 10const DEFAULT_CAPACITY: usize = 4096;
 11
 12pub struct ReqwestClient {
 13    client: reqwest::Client,
 14}
 15
 16impl ReqwestClient {
 17    pub fn new() -> Self {
 18        Self {
 19            client: reqwest::Client::new(),
 20        }
 21    }
 22
 23    pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
 24        let mut map = HeaderMap::new();
 25        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
 26        Ok(Self {
 27            client: reqwest::Client::builder().default_headers(map).build()?,
 28        })
 29    }
 30}
 31
 32impl From<reqwest::Client> for ReqwestClient {
 33    fn from(client: reqwest::Client) -> Self {
 34        Self { client }
 35    }
 36}
 37
 38// This struct is essentially a re-implementation of
 39// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
 40// except outside of Tokio's aegis
 41struct ReaderStream {
 42    reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
 43    buf: BytesMut,
 44    capacity: usize,
 45}
 46
 47impl ReaderStream {
 48    fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
 49        Self {
 50            reader: Some(reader),
 51            buf: BytesMut::new(),
 52            capacity: DEFAULT_CAPACITY,
 53        }
 54    }
 55}
 56
 57impl futures::Stream for ReaderStream {
 58    type Item = std::io::Result<Bytes>;
 59
 60    fn poll_next(
 61        mut self: Pin<&mut Self>,
 62        cx: &mut std::task::Context<'_>,
 63    ) -> Poll<Option<Self::Item>> {
 64        let mut this = self.as_mut();
 65
 66        let mut reader = match this.reader.take() {
 67            Some(r) => r,
 68            None => return Poll::Ready(None),
 69        };
 70
 71        if this.buf.capacity() == 0 {
 72            let capacity = this.capacity;
 73            this.buf.reserve(capacity);
 74        }
 75
 76        match poll_read_buf(&mut reader, cx, &mut this.buf) {
 77            Poll::Pending => Poll::Pending,
 78            Poll::Ready(Err(err)) => {
 79                self.reader = None;
 80
 81                Poll::Ready(Some(Err(err)))
 82            }
 83            Poll::Ready(Ok(0)) => {
 84                self.reader = None;
 85                Poll::Ready(None)
 86            }
 87            Poll::Ready(Ok(_)) => {
 88                let chunk = this.buf.split();
 89                self.reader = Some(reader);
 90                Poll::Ready(Some(Ok(chunk.freeze())))
 91            }
 92        }
 93    }
 94}
 95
 96/// Implementation from https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html
 97/// Specialized for this use case
 98pub fn poll_read_buf(
 99    io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
100    cx: &mut std::task::Context<'_>,
101    buf: &mut BytesMut,
102) -> Poll<std::io::Result<usize>> {
103    if !buf.has_remaining_mut() {
104        return Poll::Ready(Ok(0));
105    }
106
107    let n = {
108        let dst = buf.chunk_mut();
109
110        // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
111        // transparent wrapper around `[MaybeUninit<u8>]`.
112        let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
113        let mut buf = tokio::io::ReadBuf::uninit(dst);
114        let ptr = buf.filled().as_ptr();
115        let unfilled_portion = buf.initialize_unfilled();
116        // SAFETY: Pin projection
117        let io_pin = unsafe { Pin::new_unchecked(io) };
118        std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
119
120        // Ensure the pointer does not change from under us
121        assert_eq!(ptr, buf.filled().as_ptr());
122        buf.filled().len()
123    };
124
125    // Safety: This is guaranteed to be the number of initialized (and read)
126    // bytes due to the invariants provided by `ReadBuf::filled`.
127    unsafe {
128        buf.advance_mut(n);
129    }
130
131    Poll::Ready(Ok(n))
132}
133
134enum WrappedBodyInner {
135    None,
136    SyncReader(std::io::Cursor<Cow<'static, [u8]>>),
137    Stream(ReaderStream),
138}
139
140struct WrappedBody(WrappedBodyInner);
141
142impl WrappedBody {
143    fn new(body: AsyncBody) -> Self {
144        match body.0 {
145            http_client::Inner::Empty => Self(WrappedBodyInner::None),
146            http_client::Inner::SyncReader(cursor) => Self(WrappedBodyInner::SyncReader(cursor)),
147            http_client::Inner::AsyncReader(pin) => {
148                Self(WrappedBodyInner::Stream(ReaderStream::new(pin)))
149            }
150        }
151    }
152}
153
154impl futures::stream::Stream for WrappedBody {
155    type Item = Result<Bytes, std::io::Error>;
156
157    fn poll_next(
158        mut self: std::pin::Pin<&mut Self>,
159        cx: &mut std::task::Context<'_>,
160    ) -> std::task::Poll<Option<Self::Item>> {
161        match &mut self.0 {
162            WrappedBodyInner::None => Poll::Ready(None),
163            WrappedBodyInner::SyncReader(cursor) => {
164                let mut buf = Vec::new();
165                match cursor.read_to_end(&mut buf) {
166                    Ok(_) => {
167                        return Poll::Ready(Some(Ok(Bytes::from(buf))));
168                    }
169                    Err(e) => return Poll::Ready(Some(Err(e))),
170                }
171            }
172            WrappedBodyInner::Stream(stream) => {
173                // SAFETY: Pin projection
174                let stream = unsafe { Pin::new_unchecked(stream) };
175                futures::Stream::poll_next(stream, cx)
176            }
177        }
178    }
179}
180
181impl http_client::HttpClient for ReqwestClient {
182    fn proxy(&self) -> Option<&http::Uri> {
183        None
184    }
185
186    fn send(
187        &self,
188        req: http::Request<http_client::AsyncBody>,
189    ) -> futures::future::BoxFuture<
190        'static,
191        Result<http_client::Response<http_client::AsyncBody>, anyhow::Error>,
192    > {
193        let (parts, body) = req.into_parts();
194
195        let mut request = self.client.request(parts.method, parts.uri.to_string());
196
197        request = request.headers(parts.headers);
198
199        if let Some(redirect_policy) = parts.extensions.get::<http_client::RedirectPolicy>() {
200            request = request.redirect_policy(match redirect_policy {
201                http_client::RedirectPolicy::NoFollow => reqwest::redirect::Policy::none(),
202                http_client::RedirectPolicy::FollowLimit(limit) => {
203                    reqwest::redirect::Policy::limited(*limit as usize)
204                }
205                http_client::RedirectPolicy::FollowAll => reqwest::redirect::Policy::limited(100),
206            });
207        }
208
209        if let Some(ReadTimeout(timeout)) = parts.extensions.get::<ReadTimeout>() {
210            request = request.timeout(*timeout);
211        }
212
213        let body = WrappedBody::new(body);
214        let request = request.body(reqwest::Body::wrap_stream(body));
215
216        async move {
217            let response = request.send().await.map_err(|e| anyhow!(e))?;
218            let status = response.status();
219            let mut builder = http::Response::builder().status(status.as_u16());
220            for (name, value) in response.headers() {
221                builder = builder.header(name, value);
222            }
223            let bytes = response.bytes_stream();
224            let bytes = bytes
225                .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
226                .into_async_read();
227            let body = http_client::AsyncBody::from_reader(bytes);
228            builder.body(body).map_err(|e| anyhow!(e))
229        }
230        .boxed()
231    }
232}