1use std::{any::type_name, borrow::Cow, io::Read, mem, pin::Pin, sync::OnceLock, task::Poll};
2
3use anyhow::anyhow;
4use bytes::{BufMut, Bytes, BytesMut};
5use futures::{AsyncRead, TryStreamExt as _};
6use http_client::{http, ReadTimeout, RedirectPolicy};
7use reqwest::{
8 header::{HeaderMap, HeaderValue},
9 redirect,
10};
11use smol::future::FutureExt;
12
/// Number of bytes `StreamReader` reserves for its read buffer whenever that
/// buffer has no capacity left.
const DEFAULT_CAPACITY: usize = 4096;
/// Fallback Tokio runtime, created lazily the first time a `ReqwestClient` is
/// built from a context in which no Tokio runtime is currently running.
static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
15
/// An `http_client::HttpClient` implementation backed by `reqwest`.
pub struct ReqwestClient {
    /// The underlying reqwest client used to build and send requests.
    client: reqwest::Client,
    /// The proxy this client was configured with, if any; reported via
    /// `HttpClient::proxy`.
    proxy: Option<http::Uri>,
    /// Handle to the Tokio runtime on which request futures are spawned.
    handle: tokio::runtime::Handle,
}
21
22impl ReqwestClient {
23 pub fn new() -> Self {
24 reqwest::Client::builder()
25 .use_rustls_tls()
26 .build()
27 .expect("Failed to initialize HTTP client")
28 .into()
29 }
30
31 pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
32 let mut map = HeaderMap::new();
33 map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
34 let client = reqwest::Client::builder()
35 .default_headers(map)
36 .use_rustls_tls()
37 .build()?;
38 Ok(client.into())
39 }
40
41 pub fn proxy_and_user_agent(proxy: Option<http::Uri>, agent: &str) -> anyhow::Result<Self> {
42 let mut map = HeaderMap::new();
43 map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
44 let mut client = reqwest::Client::builder()
45 .use_rustls_tls()
46 .default_headers(map);
47 if let Some(proxy) = proxy.clone() {
48 client = client.proxy(reqwest::Proxy::all(proxy.to_string())?);
49 }
50 let client = client.build()?;
51 let mut client: ReqwestClient = client.into();
52 client.proxy = proxy;
53 Ok(client)
54 }
55}
56
57impl From<reqwest::Client> for ReqwestClient {
58 fn from(client: reqwest::Client) -> Self {
59 let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
60 log::info!("no tokio runtime found, creating one for Reqwest...");
61 let runtime = RUNTIME.get_or_init(|| {
62 tokio::runtime::Builder::new_multi_thread()
63 // Since we now have two executors, let's try to keep our footprint small
64 .worker_threads(1)
65 .enable_all()
66 .build()
67 .expect("Failed to initialize HTTP client")
68 });
69
70 runtime.handle().clone()
71 });
72 Self {
73 client,
74 handle,
75 proxy: None,
76 }
77 }
78}
79
// This struct is essentially a re-implementation of
// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
// except outside of Tokio's aegis: it adapts a `futures::AsyncRead` into a
// `futures::Stream` of `Bytes` chunks.
struct StreamReader {
    // The reader being drained. When this is `None`, `poll_next` reports the
    // end of the stream.
    reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
    // Scratch buffer that reads land in; filled data is split off and yielded
    // as a chunk.
    buf: BytesMut,
    // Number of bytes to reserve in `buf` whenever it has no capacity left.
    capacity: usize,
}
88
89impl StreamReader {
90 fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
91 Self {
92 reader: Some(reader),
93 buf: BytesMut::new(),
94 capacity: DEFAULT_CAPACITY,
95 }
96 }
97}
98
99impl futures::Stream for StreamReader {
100 type Item = std::io::Result<Bytes>;
101
102 fn poll_next(
103 mut self: Pin<&mut Self>,
104 cx: &mut std::task::Context<'_>,
105 ) -> Poll<Option<Self::Item>> {
106 let mut this = self.as_mut();
107
108 let mut reader = match this.reader.take() {
109 Some(r) => r,
110 None => return Poll::Ready(None),
111 };
112
113 if this.buf.capacity() == 0 {
114 let capacity = this.capacity;
115 this.buf.reserve(capacity);
116 }
117
118 match poll_read_buf(&mut reader, cx, &mut this.buf) {
119 Poll::Pending => Poll::Pending,
120 Poll::Ready(Err(err)) => {
121 self.reader = None;
122
123 Poll::Ready(Some(Err(err)))
124 }
125 Poll::Ready(Ok(0)) => {
126 self.reader = None;
127 Poll::Ready(None)
128 }
129 Poll::Ready(Ok(_)) => {
130 let chunk = this.buf.split();
131 self.reader = Some(reader);
132 Poll::Ready(Some(Ok(chunk.freeze())))
133 }
134 }
135 }
136}
137
138/// Implementation from https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html
139/// Specialized for this use case
140pub fn poll_read_buf(
141 io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
142 cx: &mut std::task::Context<'_>,
143 buf: &mut BytesMut,
144) -> Poll<std::io::Result<usize>> {
145 if !buf.has_remaining_mut() {
146 return Poll::Ready(Ok(0));
147 }
148
149 let n = {
150 let dst = buf.chunk_mut();
151
152 // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
153 // transparent wrapper around `[MaybeUninit<u8>]`.
154 let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
155 let mut buf = tokio::io::ReadBuf::uninit(dst);
156 let ptr = buf.filled().as_ptr();
157 let unfilled_portion = buf.initialize_unfilled();
158 // SAFETY: Pin projection
159 let io_pin = unsafe { Pin::new_unchecked(io) };
160 std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
161
162 // Ensure the pointer does not change from under us
163 assert_eq!(ptr, buf.filled().as_ptr());
164 buf.filled().len()
165 };
166
167 // Safety: This is guaranteed to be the number of initialized (and read)
168 // bytes due to the invariants provided by `ReadBuf::filled`.
169 unsafe {
170 buf.advance_mut(n);
171 }
172
173 Poll::Ready(Ok(n))
174}
175
/// Adapts an in-memory cursor into a `futures::Stream` of `Bytes`.
struct SyncReader {
    // The data still to be emitted; taken (and left as `None`) on the first
    // poll, after which the stream reports its end.
    cursor: Option<std::io::Cursor<Cow<'static, [u8]>>>,
}
179
180impl SyncReader {
181 fn new(cursor: std::io::Cursor<Cow<'static, [u8]>>) -> Self {
182 Self {
183 cursor: Some(cursor),
184 }
185 }
186}
187
188impl futures::stream::Stream for SyncReader {
189 type Item = Result<Bytes, std::io::Error>;
190
191 fn poll_next(
192 mut self: std::pin::Pin<&mut Self>,
193 _cx: &mut std::task::Context<'_>,
194 ) -> std::task::Poll<Option<Self::Item>> {
195 let Some(mut cursor) = self.cursor.take() else {
196 return Poll::Ready(None);
197 };
198
199 let mut buf = Vec::new();
200 match cursor.read_to_end(&mut buf) {
201 Ok(_) => {
202 return Poll::Ready(Some(Ok(Bytes::from(buf))));
203 }
204 Err(e) => return Poll::Ready(Some(Err(e))),
205 }
206 }
207}
208
209impl http_client::HttpClient for ReqwestClient {
210 fn proxy(&self) -> Option<&http::Uri> {
211 self.proxy.as_ref()
212 }
213
214 fn type_name(&self) -> &'static str {
215 type_name::<Self>()
216 }
217
218 fn send(
219 &self,
220 req: http::Request<http_client::AsyncBody>,
221 ) -> futures::future::BoxFuture<
222 'static,
223 Result<http_client::Response<http_client::AsyncBody>, anyhow::Error>,
224 > {
225 let (parts, body) = req.into_parts();
226
227 let mut request = self.client.request(parts.method, parts.uri.to_string());
228 request = request.headers(parts.headers);
229 if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
230 request = request.redirect_policy(match redirect_policy {
231 RedirectPolicy::NoFollow => redirect::Policy::none(),
232 RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
233 RedirectPolicy::FollowAll => redirect::Policy::limited(100),
234 });
235 }
236 if let Some(ReadTimeout(timeout)) = parts.extensions.get::<ReadTimeout>() {
237 request = request.timeout(*timeout);
238 }
239 let request = request.body(match body.0 {
240 http_client::Inner::Empty => reqwest::Body::default(),
241 http_client::Inner::SyncReader(cursor) => {
242 reqwest::Body::wrap_stream(SyncReader::new(cursor))
243 }
244 http_client::Inner::AsyncReader(stream) => {
245 reqwest::Body::wrap_stream(StreamReader::new(stream))
246 }
247 });
248
249 let handle = self.handle.clone();
250 async move {
251 let mut response = handle.spawn(async { request.send().await }).await??;
252
253 let headers = mem::take(response.headers_mut());
254 let mut builder = http::Response::builder()
255 .status(response.status().as_u16())
256 .version(response.version());
257 *builder.headers_mut().unwrap() = headers;
258
259 let bytes = response
260 .bytes_stream()
261 .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
262 .into_async_read();
263 let body = http_client::AsyncBody::from_reader(bytes);
264
265 builder.body(body).map_err(|e| anyhow!(e))
266 }
267 .boxed()
268 }
269}