1use std::{borrow::Cow, io::Read, pin::Pin, task::Poll};
2
3use anyhow::anyhow;
4use bytes::{BufMut, Bytes, BytesMut};
5use futures::{AsyncRead, TryStreamExt};
6use http_client::{http, AsyncBody, ReadTimeout};
7use reqwest::header::{HeaderMap, HeaderValue};
8use smol::future::FutureExt;
9
/// Number of bytes `ReaderStream` reserves for its read buffer whenever the buffer has no
/// capacity left (4 KiB).
const DEFAULT_CAPACITY: usize = 4 * 1024;
11
12pub struct ReqwestClient {
13 client: reqwest::Client,
14}
15
16impl ReqwestClient {
17 pub fn new() -> Self {
18 Self {
19 client: reqwest::Client::new(),
20 }
21 }
22
23 pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
24 let mut map = HeaderMap::new();
25 map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
26 Ok(Self {
27 client: reqwest::Client::builder().default_headers(map).build()?,
28 })
29 }
30}
31
32impl From<reqwest::Client> for ReqwestClient {
33 fn from(client: reqwest::Client) -> Self {
34 Self { client }
35 }
36}
37
38// This struct is essentially a re-implementation of
39// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
40// except outside of Tokio's aegis
41struct ReaderStream {
42 reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
43 buf: BytesMut,
44 capacity: usize,
45}
46
47impl ReaderStream {
48 fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
49 Self {
50 reader: Some(reader),
51 buf: BytesMut::new(),
52 capacity: DEFAULT_CAPACITY,
53 }
54 }
55}
56
57impl futures::Stream for ReaderStream {
58 type Item = std::io::Result<Bytes>;
59
60 fn poll_next(
61 mut self: Pin<&mut Self>,
62 cx: &mut std::task::Context<'_>,
63 ) -> Poll<Option<Self::Item>> {
64 let mut this = self.as_mut();
65
66 let mut reader = match this.reader.take() {
67 Some(r) => r,
68 None => return Poll::Ready(None),
69 };
70
71 if this.buf.capacity() == 0 {
72 let capacity = this.capacity;
73 this.buf.reserve(capacity);
74 }
75
76 match poll_read_buf(&mut reader, cx, &mut this.buf) {
77 Poll::Pending => Poll::Pending,
78 Poll::Ready(Err(err)) => {
79 self.reader = None;
80
81 Poll::Ready(Some(Err(err)))
82 }
83 Poll::Ready(Ok(0)) => {
84 self.reader = None;
85 Poll::Ready(None)
86 }
87 Poll::Ready(Ok(_)) => {
88 let chunk = this.buf.split();
89 self.reader = Some(reader);
90 Poll::Ready(Some(Ok(chunk.freeze())))
91 }
92 }
93 }
94}
95
96/// Implementation from https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html
97/// Specialized for this use case
98pub fn poll_read_buf(
99 io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
100 cx: &mut std::task::Context<'_>,
101 buf: &mut BytesMut,
102) -> Poll<std::io::Result<usize>> {
103 if !buf.has_remaining_mut() {
104 return Poll::Ready(Ok(0));
105 }
106
107 let n = {
108 let dst = buf.chunk_mut();
109
110 // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
111 // transparent wrapper around `[MaybeUninit<u8>]`.
112 let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
113 let mut buf = tokio::io::ReadBuf::uninit(dst);
114 let ptr = buf.filled().as_ptr();
115 let unfilled_portion = buf.initialize_unfilled();
116 // SAFETY: Pin projection
117 let io_pin = unsafe { Pin::new_unchecked(io) };
118 std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
119
120 // Ensure the pointer does not change from under us
121 assert_eq!(ptr, buf.filled().as_ptr());
122 buf.filled().len()
123 };
124
125 // Safety: This is guaranteed to be the number of initialized (and read)
126 // bytes due to the invariants provided by `ReadBuf::filled`.
127 unsafe {
128 buf.advance_mut(n);
129 }
130
131 Poll::Ready(Ok(n))
132}
133
134enum WrappedBodyInner {
135 None,
136 SyncReader(std::io::Cursor<Cow<'static, [u8]>>),
137 Stream(ReaderStream),
138}
139
140struct WrappedBody(WrappedBodyInner);
141
142impl WrappedBody {
143 fn new(body: AsyncBody) -> Self {
144 match body.0 {
145 http_client::Inner::Empty => Self(WrappedBodyInner::None),
146 http_client::Inner::SyncReader(cursor) => Self(WrappedBodyInner::SyncReader(cursor)),
147 http_client::Inner::AsyncReader(pin) => {
148 Self(WrappedBodyInner::Stream(ReaderStream::new(pin)))
149 }
150 }
151 }
152}
153
154impl futures::stream::Stream for WrappedBody {
155 type Item = Result<Bytes, std::io::Error>;
156
157 fn poll_next(
158 mut self: std::pin::Pin<&mut Self>,
159 cx: &mut std::task::Context<'_>,
160 ) -> std::task::Poll<Option<Self::Item>> {
161 match &mut self.0 {
162 WrappedBodyInner::None => Poll::Ready(None),
163 WrappedBodyInner::SyncReader(cursor) => {
164 let mut buf = Vec::new();
165 match cursor.read_to_end(&mut buf) {
166 Ok(_) => {
167 return Poll::Ready(Some(Ok(Bytes::from(buf))));
168 }
169 Err(e) => return Poll::Ready(Some(Err(e))),
170 }
171 }
172 WrappedBodyInner::Stream(stream) => {
173 // SAFETY: Pin projection
174 let stream = unsafe { Pin::new_unchecked(stream) };
175 futures::Stream::poll_next(stream, cx)
176 }
177 }
178 }
179}
180
181impl http_client::HttpClient for ReqwestClient {
182 fn proxy(&self) -> Option<&http::Uri> {
183 None
184 }
185
186 fn send(
187 &self,
188 req: http::Request<http_client::AsyncBody>,
189 ) -> futures::future::BoxFuture<
190 'static,
191 Result<http_client::Response<http_client::AsyncBody>, anyhow::Error>,
192 > {
193 let (parts, body) = req.into_parts();
194
195 let mut request = self.client.request(parts.method, parts.uri.to_string());
196
197 request = request.headers(parts.headers);
198
199 if let Some(redirect_policy) = parts.extensions.get::<http_client::RedirectPolicy>() {
200 request = request.redirect_policy(match redirect_policy {
201 http_client::RedirectPolicy::NoFollow => reqwest::redirect::Policy::none(),
202 http_client::RedirectPolicy::FollowLimit(limit) => {
203 reqwest::redirect::Policy::limited(*limit as usize)
204 }
205 http_client::RedirectPolicy::FollowAll => reqwest::redirect::Policy::limited(100),
206 });
207 }
208
209 if let Some(ReadTimeout(timeout)) = parts.extensions.get::<ReadTimeout>() {
210 request = request.timeout(*timeout);
211 }
212
213 let body = WrappedBody::new(body);
214 let request = request.body(reqwest::Body::wrap_stream(body));
215
216 async move {
217 let response = request.send().await.map_err(|e| anyhow!(e))?;
218 let status = response.status();
219 let mut builder = http::Response::builder().status(status.as_u16());
220 for (name, value) in response.headers() {
221 builder = builder.header(name, value);
222 }
223 let bytes = response.bytes_stream();
224 let bytes = bytes
225 .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
226 .into_async_read();
227 let body = http_client::AsyncBody::from_reader(bytes);
228 builder.body(body).map_err(|e| anyhow!(e))
229 }
230 .boxed()
231 }
232}