1use std::sync::{LazyLock, OnceLock};
2use std::{any::type_name, borrow::Cow, mem, pin::Pin, task::Poll, time::Duration};
3
4use anyhow::anyhow;
5use bytes::{BufMut, Bytes, BytesMut};
6use futures::{AsyncRead, TryStreamExt as _};
7use http_client::{RedirectPolicy, http};
8use regex::Regex;
9use reqwest::{
10 header::{HeaderMap, HeaderValue},
11 redirect,
12};
13use smol::future::FutureExt;
14
/// Number of bytes `StreamReader` reserves whenever its buffer has no capacity left.
const DEFAULT_CAPACITY: usize = 4 * 1024;
/// Fallback Tokio runtime, lazily created (at most once) when a `ReqwestClient` is built
/// while no Tokio runtime is current.
static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
/// Matches `key=<value>` in a query string, where the value runs up to the next `&`.
/// Used to scrub API keys from URLs attached to errors. The pattern is unanchored, so it
/// also matches parameters whose name merely ends in `key` (e.g. `api_key=`).
static REDACT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"key=[^&]+").unwrap());
18
/// An `http_client::HttpClient` implementation backed by `reqwest`.
pub struct ReqwestClient {
    /// The underlying reqwest client used to build and send requests.
    client: reqwest::Client,
    /// Proxy URI reported by `HttpClient::proxy`; only populated by `proxy_and_user_agent`.
    proxy: Option<http::Uri>,
    /// Handle of the Tokio runtime on which outgoing requests are spawned.
    handle: tokio::runtime::Handle,
}
24
25impl ReqwestClient {
26 fn builder() -> reqwest::ClientBuilder {
27 reqwest::Client::builder()
28 .use_rustls_tls()
29 .connect_timeout(Duration::from_secs(10))
30 }
31
32 pub fn new() -> Self {
33 Self::builder()
34 .build()
35 .expect("Failed to initialize HTTP client")
36 .into()
37 }
38
39 pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
40 let mut map = HeaderMap::new();
41 map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
42 let client = Self::builder().default_headers(map).build()?;
43 Ok(client.into())
44 }
45
46 pub fn proxy_and_user_agent(proxy: Option<http::Uri>, agent: &str) -> anyhow::Result<Self> {
47 let mut map = HeaderMap::new();
48 map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
49 let mut client = Self::builder().default_headers(map);
50 if let Some(proxy) = proxy.clone().and_then(|proxy_uri| {
51 reqwest::Proxy::all(proxy_uri.to_string())
52 .inspect_err(|e| log::error!("Failed to parse proxy URI {}: {}", proxy_uri, e))
53 .ok()
54 }) {
55 client = client.proxy(proxy);
56 }
57
58 let client = client
59 .use_preconfigured_tls(http_client_tls::tls_config())
60 .build()?;
61 let mut client: ReqwestClient = client.into();
62 client.proxy = proxy;
63 Ok(client)
64 }
65}
66
67impl From<reqwest::Client> for ReqwestClient {
68 fn from(client: reqwest::Client) -> Self {
69 let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
70 log::debug!("no tokio runtime found, creating one for Reqwest...");
71 let runtime = RUNTIME.get_or_init(|| {
72 tokio::runtime::Builder::new_multi_thread()
73 // Since we now have two executors, let's try to keep our footprint small
74 .worker_threads(1)
75 .enable_all()
76 .build()
77 .expect("Failed to initialize HTTP client")
78 });
79
80 runtime.handle().clone()
81 });
82 Self {
83 client,
84 handle,
85 proxy: None,
86 }
87 }
88}
89
90// This struct is essentially a re-implementation of
91// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
92// except outside of Tokio's aegis
93struct StreamReader {
94 reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
95 buf: BytesMut,
96 capacity: usize,
97}
98
99impl StreamReader {
100 fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
101 Self {
102 reader: Some(reader),
103 buf: BytesMut::new(),
104 capacity: DEFAULT_CAPACITY,
105 }
106 }
107}
108
109impl futures::Stream for StreamReader {
110 type Item = std::io::Result<Bytes>;
111
112 fn poll_next(
113 mut self: Pin<&mut Self>,
114 cx: &mut std::task::Context<'_>,
115 ) -> Poll<Option<Self::Item>> {
116 let mut this = self.as_mut();
117
118 let mut reader = match this.reader.take() {
119 Some(r) => r,
120 None => return Poll::Ready(None),
121 };
122
123 if this.buf.capacity() == 0 {
124 let capacity = this.capacity;
125 this.buf.reserve(capacity);
126 }
127
128 match poll_read_buf(&mut reader, cx, &mut this.buf) {
129 Poll::Pending => Poll::Pending,
130 Poll::Ready(Err(err)) => {
131 self.reader = None;
132
133 Poll::Ready(Some(Err(err)))
134 }
135 Poll::Ready(Ok(0)) => {
136 self.reader = None;
137 Poll::Ready(None)
138 }
139 Poll::Ready(Ok(_)) => {
140 let chunk = this.buf.split();
141 self.reader = Some(reader);
142 Poll::Ready(Some(Ok(chunk.freeze())))
143 }
144 }
145 }
146}
147
148/// Implementation from <https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html>
149/// Specialized for this use case
150pub fn poll_read_buf(
151 io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
152 cx: &mut std::task::Context<'_>,
153 buf: &mut BytesMut,
154) -> Poll<std::io::Result<usize>> {
155 if !buf.has_remaining_mut() {
156 return Poll::Ready(Ok(0));
157 }
158
159 let n = {
160 let dst = buf.chunk_mut();
161
162 // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
163 // transparent wrapper around `[MaybeUninit<u8>]`.
164 let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
165 let mut buf = tokio::io::ReadBuf::uninit(dst);
166 let ptr = buf.filled().as_ptr();
167 let unfilled_portion = buf.initialize_unfilled();
168 // SAFETY: Pin projection
169 let io_pin = unsafe { Pin::new_unchecked(io) };
170 std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
171
172 // Ensure the pointer does not change from under us
173 assert_eq!(ptr, buf.filled().as_ptr());
174 buf.filled().len()
175 };
176
177 // Safety: This is guaranteed to be the number of initialized (and read)
178 // bytes due to the invariants provided by `ReadBuf::filled`.
179 unsafe {
180 buf.advance_mut(n);
181 }
182
183 Poll::Ready(Ok(n))
184}
185
186fn redact_error(mut error: reqwest::Error) -> reqwest::Error {
187 if let Some(url) = error.url_mut() {
188 if let Some(query) = url.query() {
189 if let Cow::Owned(redacted) = REDACT_REGEX.replace_all(query, "key=REDACTED") {
190 url.set_query(Some(redacted.as_str()));
191 }
192 }
193 }
194 error
195}
196
197impl http_client::HttpClient for ReqwestClient {
198 fn proxy(&self) -> Option<&http::Uri> {
199 self.proxy.as_ref()
200 }
201
202 fn type_name(&self) -> &'static str {
203 type_name::<Self>()
204 }
205
206 fn send(
207 &self,
208 req: http::Request<http_client::AsyncBody>,
209 ) -> futures::future::BoxFuture<
210 'static,
211 Result<http_client::Response<http_client::AsyncBody>, anyhow::Error>,
212 > {
213 let (parts, body) = req.into_parts();
214
215 let mut request = self.client.request(parts.method, parts.uri.to_string());
216 request = request.headers(parts.headers);
217 if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
218 request = request.redirect_policy(match redirect_policy {
219 RedirectPolicy::NoFollow => redirect::Policy::none(),
220 RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
221 RedirectPolicy::FollowAll => redirect::Policy::limited(100),
222 });
223 }
224 let request = request.body(match body.0 {
225 http_client::Inner::Empty => reqwest::Body::default(),
226 http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
227 http_client::Inner::AsyncReader(stream) => {
228 reqwest::Body::wrap_stream(StreamReader::new(stream))
229 }
230 });
231
232 let handle = self.handle.clone();
233 async move {
234 let mut response = handle
235 .spawn(async { request.send().await })
236 .await?
237 .map_err(redact_error)?;
238
239 let headers = mem::take(response.headers_mut());
240 let mut builder = http::Response::builder()
241 .status(response.status().as_u16())
242 .version(response.version());
243 *builder.headers_mut().unwrap() = headers;
244
245 let bytes = response
246 .bytes_stream()
247 .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
248 .into_async_read();
249 let body = http_client::AsyncBody::from_reader(bytes);
250
251 builder.body(body).map_err(|e| anyhow!(e))
252 }
253 .boxed()
254 }
255}
256
#[cfg(test)]
mod tests {
    use http_client::{HttpClient, http};

    use crate::ReqwestClient;

    #[test]
    fn test_proxy_uri() {
        // A client built without a proxy must not report one.
        assert_eq!(ReqwestClient::new().proxy(), None);

        // Each configured proxy URI must be reported back verbatim by `proxy()`.
        for uri in [
            "http://localhost:10809",
            "https://localhost:10809",
            "socks4://localhost:10808",
            "socks4a://localhost:10808",
            "socks5://localhost:10808",
            "socks5h://localhost:10808",
        ] {
            let proxy = http::Uri::from_static(uri);
            let client =
                ReqwestClient::proxy_and_user_agent(Some(proxy.clone()), "test").unwrap();
            assert_eq!(client.proxy(), Some(&proxy));
        }
    }

    #[test]
    #[should_panic]
    fn test_invalid_proxy_uri() {
        // NOTE(review): `Uri::from_static` appears to reject this URI itself (it has an
        // empty authority), so the panic likely comes from here rather than from
        // `proxy_and_user_agent`, which only logs unparseable proxies. Confirm.
        let proxy = http::Uri::from_static("file:///etc/hosts");
        ReqwestClient::proxy_and_user_agent(Some(proxy), "test").unwrap();
    }
}