1use std::{any::type_name, mem, pin::Pin, sync::OnceLock, task::Poll, time::Duration};
2
3use anyhow::anyhow;
4use bytes::{BufMut, Bytes, BytesMut};
5use futures::{AsyncRead, TryStreamExt as _};
6use http_client::{http, RedirectPolicy};
7use reqwest::{
8 header::{HeaderMap, HeaderValue},
9 redirect,
10};
11use smol::future::FutureExt;
12
/// Number of bytes reserved for `StreamReader`'s chunk buffer whenever it runs out of capacity.
const DEFAULT_CAPACITY: usize = 4 * 1024;
14static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
15
16pub struct ReqwestClient {
17 client: reqwest::Client,
18 proxy: Option<http::Uri>,
19 handle: tokio::runtime::Handle,
20}
21
22impl ReqwestClient {
23 fn builder() -> reqwest::ClientBuilder {
24 reqwest::Client::builder()
25 .use_rustls_tls()
26 .connect_timeout(Duration::from_secs(10))
27 }
28
29 pub fn new() -> Self {
30 Self::builder()
31 .build()
32 .expect("Failed to initialize HTTP client")
33 .into()
34 }
35
36 pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
37 let mut map = HeaderMap::new();
38 map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
39 let client = Self::builder().default_headers(map).build()?;
40 Ok(client.into())
41 }
42
43 pub fn proxy_and_user_agent(proxy: Option<http::Uri>, agent: &str) -> anyhow::Result<Self> {
44 let mut map = HeaderMap::new();
45 map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
46 let mut client = Self::builder().default_headers(map);
47 if let Some(proxy) = proxy.clone().and_then(|proxy_uri| {
48 reqwest::Proxy::all(proxy_uri.to_string())
49 .inspect_err(|e| log::error!("Failed to parse proxy URI {}: {}", proxy_uri, e))
50 .ok()
51 }) {
52 client = client.proxy(proxy);
53 }
54
55 let client = client
56 .use_preconfigured_tls(http_client::tls_config())
57 .build()?;
58 let mut client: ReqwestClient = client.into();
59 client.proxy = proxy;
60 Ok(client)
61 }
62}
63
64impl From<reqwest::Client> for ReqwestClient {
65 fn from(client: reqwest::Client) -> Self {
66 let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
67 log::debug!("no tokio runtime found, creating one for Reqwest...");
68 let runtime = RUNTIME.get_or_init(|| {
69 tokio::runtime::Builder::new_multi_thread()
70 // Since we now have two executors, let's try to keep our footprint small
71 .worker_threads(1)
72 .enable_all()
73 .build()
74 .expect("Failed to initialize HTTP client")
75 });
76
77 runtime.handle().clone()
78 });
79 Self {
80 client,
81 handle,
82 proxy: None,
83 }
84 }
85}
86
87// This struct is essentially a re-implementation of
88// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
89// except outside of Tokio's aegis
90struct StreamReader {
91 reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
92 buf: BytesMut,
93 capacity: usize,
94}
95
96impl StreamReader {
97 fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
98 Self {
99 reader: Some(reader),
100 buf: BytesMut::new(),
101 capacity: DEFAULT_CAPACITY,
102 }
103 }
104}
105
106impl futures::Stream for StreamReader {
107 type Item = std::io::Result<Bytes>;
108
109 fn poll_next(
110 mut self: Pin<&mut Self>,
111 cx: &mut std::task::Context<'_>,
112 ) -> Poll<Option<Self::Item>> {
113 let mut this = self.as_mut();
114
115 let mut reader = match this.reader.take() {
116 Some(r) => r,
117 None => return Poll::Ready(None),
118 };
119
120 if this.buf.capacity() == 0 {
121 let capacity = this.capacity;
122 this.buf.reserve(capacity);
123 }
124
125 match poll_read_buf(&mut reader, cx, &mut this.buf) {
126 Poll::Pending => Poll::Pending,
127 Poll::Ready(Err(err)) => {
128 self.reader = None;
129
130 Poll::Ready(Some(Err(err)))
131 }
132 Poll::Ready(Ok(0)) => {
133 self.reader = None;
134 Poll::Ready(None)
135 }
136 Poll::Ready(Ok(_)) => {
137 let chunk = this.buf.split();
138 self.reader = Some(reader);
139 Poll::Ready(Some(Ok(chunk.freeze())))
140 }
141 }
142 }
143}
144
145/// Implementation from <https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html>
146/// Specialized for this use case
147pub fn poll_read_buf(
148 io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
149 cx: &mut std::task::Context<'_>,
150 buf: &mut BytesMut,
151) -> Poll<std::io::Result<usize>> {
152 if !buf.has_remaining_mut() {
153 return Poll::Ready(Ok(0));
154 }
155
156 let n = {
157 let dst = buf.chunk_mut();
158
159 // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
160 // transparent wrapper around `[MaybeUninit<u8>]`.
161 let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
162 let mut buf = tokio::io::ReadBuf::uninit(dst);
163 let ptr = buf.filled().as_ptr();
164 let unfilled_portion = buf.initialize_unfilled();
165 // SAFETY: Pin projection
166 let io_pin = unsafe { Pin::new_unchecked(io) };
167 std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
168
169 // Ensure the pointer does not change from under us
170 assert_eq!(ptr, buf.filled().as_ptr());
171 buf.filled().len()
172 };
173
174 // Safety: This is guaranteed to be the number of initialized (and read)
175 // bytes due to the invariants provided by `ReadBuf::filled`.
176 unsafe {
177 buf.advance_mut(n);
178 }
179
180 Poll::Ready(Ok(n))
181}
182
183impl http_client::HttpClient for ReqwestClient {
184 fn proxy(&self) -> Option<&http::Uri> {
185 self.proxy.as_ref()
186 }
187
188 fn type_name(&self) -> &'static str {
189 type_name::<Self>()
190 }
191
192 fn send(
193 &self,
194 req: http::Request<http_client::AsyncBody>,
195 ) -> futures::future::BoxFuture<
196 'static,
197 Result<http_client::Response<http_client::AsyncBody>, anyhow::Error>,
198 > {
199 let (parts, body) = req.into_parts();
200
201 let mut request = self.client.request(parts.method, parts.uri.to_string());
202 request = request.headers(parts.headers);
203 if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
204 request = request.redirect_policy(match redirect_policy {
205 RedirectPolicy::NoFollow => redirect::Policy::none(),
206 RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
207 RedirectPolicy::FollowAll => redirect::Policy::limited(100),
208 });
209 }
210 let request = request.body(match body.0 {
211 http_client::Inner::Empty => reqwest::Body::default(),
212 http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
213 http_client::Inner::AsyncReader(stream) => {
214 reqwest::Body::wrap_stream(StreamReader::new(stream))
215 }
216 });
217
218 let handle = self.handle.clone();
219 async move {
220 let mut response = handle.spawn(async { request.send().await }).await??;
221
222 let headers = mem::take(response.headers_mut());
223 let mut builder = http::Response::builder()
224 .status(response.status().as_u16())
225 .version(response.version());
226 *builder.headers_mut().unwrap() = headers;
227
228 let bytes = response
229 .bytes_stream()
230 .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
231 .into_async_read();
232 let body = http_client::AsyncBody::from_reader(bytes);
233
234 builder.body(body).map_err(|e| anyhow!(e))
235 }
236 .boxed()
237 }
238}
239
#[cfg(test)]
mod tests {
    use http_client::{http, HttpClient};

    use crate::ReqwestClient;

    /// A client without a proxy reports none; each supported proxy scheme is reported back
    /// unchanged through `HttpClient::proxy`.
    #[test]
    fn test_proxy_uri() {
        let client = ReqwestClient::new();
        assert_eq!(client.proxy(), None);

        for uri in [
            "http://localhost:10809",
            "https://localhost:10809",
            "socks4://localhost:10808",
            "socks4a://localhost:10808",
            "socks5://localhost:10808",
            "socks5h://localhost:10808",
        ] {
            let proxy = http::Uri::from_static(uri);
            let client = ReqwestClient::proxy_and_user_agent(Some(proxy.clone()), "test").unwrap();
            assert_eq!(client.proxy(), Some(&proxy), "proxy {uri} should be reported");
        }
    }

    /// A proxy URI that `http::Uri` accepts but reqwest cannot use is logged and skipped, and
    /// building the client still succeeds.
    ///
    /// The URI deliberately has an authority. `Uri::from_static` panics on authority-less
    /// URIs such as `file:///etc/hosts`, so a `#[should_panic]` test built on one never
    /// reaches `proxy_and_user_agent` at all.
    #[test]
    fn test_invalid_proxy_uri() {
        let proxy = http::Uri::from_static("aaa://localhost:8080");
        let result = ReqwestClient::proxy_and_user_agent(Some(proxy), "test");
        assert!(
            result.is_ok(),
            "an unusable proxy should not prevent building the client"
        );
    }
}