reqwest_client.rs

  1use std::{any::type_name, borrow::Cow, io::Read, pin::Pin, task::Poll};
  2
  3use anyhow::anyhow;
  4use bytes::{BufMut, Bytes, BytesMut};
  5use futures::{AsyncRead, TryStreamExt};
  6use http_client::{http, AsyncBody, ReadTimeout};
  7use reqwest::header::{HeaderMap, HeaderValue};
  8use smol::future::FutureExt;
  9
 10const DEFAULT_CAPACITY: usize = 4096;
 11
 12pub struct ReqwestClient {
 13    client: reqwest::Client,
 14}
 15
 16impl ReqwestClient {
 17    pub fn new() -> Self {
 18        Self {
 19            client: reqwest::Client::new(),
 20        }
 21    }
 22
 23    pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
 24        let mut map = HeaderMap::new();
 25        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
 26        Ok(Self {
 27            client: reqwest::Client::builder().default_headers(map).build()?,
 28        })
 29    }
 30}
 31
 32impl From<reqwest::Client> for ReqwestClient {
 33    fn from(client: reqwest::Client) -> Self {
 34        Self { client }
 35    }
 36}
 37
 38// This struct is essentially a re-implementation of
 39// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
 40// except outside of Tokio's aegis
 41struct ReaderStream {
 42    reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
 43    buf: BytesMut,
 44    capacity: usize,
 45}
 46
 47impl ReaderStream {
 48    fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
 49        Self {
 50            reader: Some(reader),
 51            buf: BytesMut::new(),
 52            capacity: DEFAULT_CAPACITY,
 53        }
 54    }
 55}
 56
 57impl futures::Stream for ReaderStream {
 58    type Item = std::io::Result<Bytes>;
 59
 60    fn poll_next(
 61        mut self: Pin<&mut Self>,
 62        cx: &mut std::task::Context<'_>,
 63    ) -> Poll<Option<Self::Item>> {
 64        let mut this = self.as_mut();
 65
 66        let mut reader = match this.reader.take() {
 67            Some(r) => r,
 68            None => return Poll::Ready(None),
 69        };
 70
 71        if this.buf.capacity() == 0 {
 72            let capacity = this.capacity;
 73            this.buf.reserve(capacity);
 74        }
 75
 76        match poll_read_buf(&mut reader, cx, &mut this.buf) {
 77            Poll::Pending => Poll::Pending,
 78            Poll::Ready(Err(err)) => {
 79                self.reader = None;
 80
 81                Poll::Ready(Some(Err(err)))
 82            }
 83            Poll::Ready(Ok(0)) => {
 84                self.reader = None;
 85                Poll::Ready(None)
 86            }
 87            Poll::Ready(Ok(_)) => {
 88                let chunk = this.buf.split();
 89                self.reader = Some(reader);
 90                Poll::Ready(Some(Ok(chunk.freeze())))
 91            }
 92        }
 93    }
 94}
 95
 96/// Implementation from https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html
 97/// Specialized for this use case
 98pub fn poll_read_buf(
 99    io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
100    cx: &mut std::task::Context<'_>,
101    buf: &mut BytesMut,
102) -> Poll<std::io::Result<usize>> {
103    if !buf.has_remaining_mut() {
104        return Poll::Ready(Ok(0));
105    }
106
107    let n = {
108        let dst = buf.chunk_mut();
109
110        // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
111        // transparent wrapper around `[MaybeUninit<u8>]`.
112        let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
113        let mut buf = tokio::io::ReadBuf::uninit(dst);
114        let ptr = buf.filled().as_ptr();
115        let unfilled_portion = buf.initialize_unfilled();
116        // SAFETY: Pin projection
117        let io_pin = unsafe { Pin::new_unchecked(io) };
118        std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
119
120        // Ensure the pointer does not change from under us
121        assert_eq!(ptr, buf.filled().as_ptr());
122        buf.filled().len()
123    };
124
125    // Safety: This is guaranteed to be the number of initialized (and read)
126    // bytes due to the invariants provided by `ReadBuf::filled`.
127    unsafe {
128        buf.advance_mut(n);
129    }
130
131    Poll::Ready(Ok(n))
132}
133
134enum WrappedBodyInner {
135    None,
136    SyncReader(std::io::Cursor<Cow<'static, [u8]>>),
137    Stream(ReaderStream),
138}
139
140struct WrappedBody(WrappedBodyInner);
141
142impl WrappedBody {
143    fn new(body: AsyncBody) -> Self {
144        match body.0 {
145            http_client::Inner::Empty => Self(WrappedBodyInner::None),
146            http_client::Inner::SyncReader(cursor) => Self(WrappedBodyInner::SyncReader(cursor)),
147            http_client::Inner::AsyncReader(pin) => {
148                Self(WrappedBodyInner::Stream(ReaderStream::new(pin)))
149            }
150        }
151    }
152}
153
154impl futures::stream::Stream for WrappedBody {
155    type Item = Result<Bytes, std::io::Error>;
156
157    fn poll_next(
158        mut self: std::pin::Pin<&mut Self>,
159        cx: &mut std::task::Context<'_>,
160    ) -> std::task::Poll<Option<Self::Item>> {
161        match &mut self.0 {
162            WrappedBodyInner::None => Poll::Ready(None),
163            WrappedBodyInner::SyncReader(cursor) => {
164                let mut buf = Vec::new();
165                match cursor.read_to_end(&mut buf) {
166                    Ok(bytes) => {
167                        if bytes == 0 {
168                            return Poll::Ready(None);
169                        } else {
170                            return Poll::Ready(Some(Ok(Bytes::from(buf))));
171                        }
172                    }
173                    Err(e) => return Poll::Ready(Some(Err(e))),
174                }
175            }
176            WrappedBodyInner::Stream(stream) => {
177                // SAFETY: Pin projection
178                let stream = unsafe { Pin::new_unchecked(stream) };
179                futures::Stream::poll_next(stream, cx)
180            }
181        }
182    }
183}
184
185impl http_client::HttpClient for ReqwestClient {
186    fn proxy(&self) -> Option<&http::Uri> {
187        None
188    }
189
190    fn type_name(&self) -> &'static str {
191        type_name::<Self>()
192    }
193
194    fn send(
195        &self,
196        req: http::Request<http_client::AsyncBody>,
197    ) -> futures::future::BoxFuture<
198        'static,
199        Result<http_client::Response<http_client::AsyncBody>, anyhow::Error>,
200    > {
201        let (parts, body) = req.into_parts();
202
203        let mut request = self.client.request(parts.method, parts.uri.to_string());
204
205        request = request.headers(parts.headers);
206
207        if let Some(redirect_policy) = parts.extensions.get::<http_client::RedirectPolicy>() {
208            request = request.redirect_policy(match redirect_policy {
209                http_client::RedirectPolicy::NoFollow => reqwest::redirect::Policy::none(),
210                http_client::RedirectPolicy::FollowLimit(limit) => {
211                    reqwest::redirect::Policy::limited(*limit as usize)
212                }
213                http_client::RedirectPolicy::FollowAll => reqwest::redirect::Policy::limited(100),
214            });
215        }
216
217        if let Some(ReadTimeout(timeout)) = parts.extensions.get::<ReadTimeout>() {
218            request = request.timeout(*timeout);
219        }
220
221        let body = WrappedBody::new(body);
222        let request = request.body(reqwest::Body::wrap_stream(body));
223
224        async move {
225            let response = request.send().await.map_err(|e| anyhow!(e))?;
226            let status = response.status();
227            let mut builder = http::Response::builder().status(status.as_u16());
228            for (name, value) in response.headers() {
229                builder = builder.header(name, value);
230            }
231            let bytes = response.bytes_stream();
232            let bytes = bytes
233                .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
234                .into_async_read();
235            let body = http_client::AsyncBody::from_reader(bytes);
236            builder.body(body).map_err(|e| anyhow!(e))
237        }
238        .boxed()
239    }
240}
241
#[cfg(test)]
mod test {
    use http_client::AsyncBody;
    use smol::stream::StreamExt;

    use crate::WrappedBody;

    /// An in-memory body should arrive as a single chunk, followed by end-of-stream.
    #[tokio::test]
    async fn test_sync_streaming_upload() {
        let mut stream = WrappedBody::new(AsyncBody::from("hello there".to_string())).fuse();

        let first = stream.next().await.unwrap().unwrap();
        let rest = stream.next().await;

        assert!(rest.is_none());
        assert_eq!(std::str::from_utf8(&first).unwrap(), "hello there");
    }
}