1use std::error::Error;
2use std::sync::{LazyLock, OnceLock};
3use std::{any::type_name, borrow::Cow, mem, pin::Pin, task::Poll, time::Duration};
4
5use anyhow::anyhow;
6use bytes::{BufMut, Bytes, BytesMut};
7use futures::{AsyncRead, FutureExt as _, TryStreamExt as _};
8use http_client::{RedirectPolicy, Url, http};
9use regex::Regex;
10use reqwest::{
11 header::{HeaderMap, HeaderValue},
12 redirect,
13};
14
15const DEFAULT_CAPACITY: usize = 4096;
16static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
17static REDACT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"key=[^&]+").unwrap());
18
19pub struct ReqwestClient {
20 client: reqwest::Client,
21 proxy: Option<Url>,
22 user_agent: Option<HeaderValue>,
23 handle: tokio::runtime::Handle,
24}
25
26impl ReqwestClient {
27 fn builder() -> reqwest::ClientBuilder {
28 reqwest::Client::builder()
29 .use_rustls_tls()
30 .connect_timeout(Duration::from_secs(10))
31 }
32
33 pub fn new() -> Self {
34 Self::builder()
35 .build()
36 .expect("Failed to initialize HTTP client")
37 .into()
38 }
39
40 pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
41 let mut map = HeaderMap::new();
42 map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
43 let client = Self::builder().default_headers(map).build()?;
44 Ok(client.into())
45 }
46
47 pub fn proxy_and_user_agent(proxy: Option<Url>, user_agent: &str) -> anyhow::Result<Self> {
48 let user_agent = HeaderValue::from_str(user_agent)?;
49
50 let mut map = HeaderMap::new();
51 map.insert(http::header::USER_AGENT, user_agent.clone());
52 let mut client = Self::builder().default_headers(map);
53 let client_has_proxy;
54
55 if let Some(proxy) = proxy.as_ref().and_then(|proxy_url| {
56 reqwest::Proxy::all(proxy_url.clone())
57 .inspect_err(|e| {
58 log::error!(
59 "Failed to parse proxy URL '{}': {}",
60 proxy_url,
61 e.source().unwrap_or(&e as &_)
62 )
63 })
64 .ok()
65 }) {
66 // Respect NO_PROXY env var
67 client = client.proxy(proxy.no_proxy(reqwest::NoProxy::from_env()));
68 client_has_proxy = true;
69 } else {
70 client_has_proxy = false;
71 };
72
73 let client = client
74 .use_preconfigured_tls(http_client_tls::tls_config())
75 .build()?;
76 let mut client: ReqwestClient = client.into();
77 client.proxy = client_has_proxy.then_some(proxy).flatten();
78 client.user_agent = Some(user_agent);
79 Ok(client)
80 }
81}
82
83impl From<reqwest::Client> for ReqwestClient {
84 fn from(client: reqwest::Client) -> Self {
85 let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
86 log::debug!("no tokio runtime found, creating one for Reqwest...");
87 let runtime = RUNTIME.get_or_init(|| {
88 tokio::runtime::Builder::new_multi_thread()
89 // Since we now have two executors, let's try to keep our footprint small
90 .worker_threads(1)
91 .enable_all()
92 .build()
93 .expect("Failed to initialize HTTP client")
94 });
95
96 runtime.handle().clone()
97 });
98 Self {
99 client,
100 handle,
101 proxy: None,
102 user_agent: None,
103 }
104 }
105}
106
107// This struct is essentially a re-implementation of
108// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
109// except outside of Tokio's aegis
110struct StreamReader {
111 reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
112 buf: BytesMut,
113 capacity: usize,
114}
115
116impl StreamReader {
117 fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
118 Self {
119 reader: Some(reader),
120 buf: BytesMut::new(),
121 capacity: DEFAULT_CAPACITY,
122 }
123 }
124}
125
126impl futures::Stream for StreamReader {
127 type Item = std::io::Result<Bytes>;
128
129 fn poll_next(
130 mut self: Pin<&mut Self>,
131 cx: &mut std::task::Context<'_>,
132 ) -> Poll<Option<Self::Item>> {
133 let mut this = self.as_mut();
134
135 let mut reader = match this.reader.take() {
136 Some(r) => r,
137 None => return Poll::Ready(None),
138 };
139
140 if this.buf.capacity() == 0 {
141 let capacity = this.capacity;
142 this.buf.reserve(capacity);
143 }
144
145 match poll_read_buf(&mut reader, cx, &mut this.buf) {
146 Poll::Pending => Poll::Pending,
147 Poll::Ready(Err(err)) => {
148 self.reader = None;
149
150 Poll::Ready(Some(Err(err)))
151 }
152 Poll::Ready(Ok(0)) => {
153 self.reader = None;
154 Poll::Ready(None)
155 }
156 Poll::Ready(Ok(_)) => {
157 let chunk = this.buf.split();
158 self.reader = Some(reader);
159 Poll::Ready(Some(Ok(chunk.freeze())))
160 }
161 }
162 }
163}
164
165/// Implementation from <https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html>
166/// Specialized for this use case
167pub fn poll_read_buf(
168 io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
169 cx: &mut std::task::Context<'_>,
170 buf: &mut BytesMut,
171) -> Poll<std::io::Result<usize>> {
172 if !buf.has_remaining_mut() {
173 return Poll::Ready(Ok(0));
174 }
175
176 let n = {
177 let dst = buf.chunk_mut();
178
179 // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
180 // transparent wrapper around `[MaybeUninit<u8>]`.
181 let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
182 let mut buf = tokio::io::ReadBuf::uninit(dst);
183 let ptr = buf.filled().as_ptr();
184 let unfilled_portion = buf.initialize_unfilled();
185 // SAFETY: Pin projection
186 let io_pin = unsafe { Pin::new_unchecked(io) };
187 std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
188
189 // Ensure the pointer does not change from under us
190 assert_eq!(ptr, buf.filled().as_ptr());
191 buf.filled().len()
192 };
193
194 // Safety: This is guaranteed to be the number of initialized (and read)
195 // bytes due to the invariants provided by `ReadBuf::filled`.
196 unsafe {
197 buf.advance_mut(n);
198 }
199
200 Poll::Ready(Ok(n))
201}
202
203fn redact_error(mut error: reqwest::Error) -> reqwest::Error {
204 if let Some(url) = error.url_mut()
205 && let Some(query) = url.query()
206 && let Cow::Owned(redacted) = REDACT_REGEX.replace_all(query, "key=REDACTED")
207 {
208 url.set_query(Some(redacted.as_str()));
209 }
210 error
211}
212
213impl http_client::HttpClient for ReqwestClient {
214 fn proxy(&self) -> Option<&Url> {
215 self.proxy.as_ref()
216 }
217
218 fn type_name(&self) -> &'static str {
219 type_name::<Self>()
220 }
221
222 fn user_agent(&self) -> Option<&HeaderValue> {
223 self.user_agent.as_ref()
224 }
225
226 fn send(
227 &self,
228 req: http::Request<http_client::AsyncBody>,
229 ) -> futures::future::BoxFuture<
230 'static,
231 anyhow::Result<http_client::Response<http_client::AsyncBody>>,
232 > {
233 let (parts, body) = req.into_parts();
234
235 let mut request = self.client.request(parts.method, parts.uri.to_string());
236 request = request.headers(parts.headers);
237 if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
238 request = request.redirect_policy(match redirect_policy {
239 RedirectPolicy::NoFollow => redirect::Policy::none(),
240 RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
241 RedirectPolicy::FollowAll => redirect::Policy::limited(100),
242 });
243 }
244 let request = request.body(match body.0 {
245 http_client::Inner::Empty => reqwest::Body::default(),
246 http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
247 http_client::Inner::AsyncReader(stream) => {
248 reqwest::Body::wrap_stream(StreamReader::new(stream))
249 }
250 });
251
252 let handle = self.handle.clone();
253 async move {
254 let mut response = handle
255 .spawn(async { request.send().await })
256 .await?
257 .map_err(redact_error)?;
258
259 let headers = mem::take(response.headers_mut());
260 let mut builder = http::Response::builder()
261 .status(response.status().as_u16())
262 .version(response.version());
263 *builder.headers_mut().unwrap() = headers;
264
265 let bytes = response
266 .bytes_stream()
267 .map_err(futures::io::Error::other)
268 .into_async_read();
269 let body = http_client::AsyncBody::from_reader(bytes);
270
271 builder.body(body).map_err(|e| anyhow!(e))
272 }
273 .boxed()
274 }
275
276 fn send_multipart_form<'a>(
277 &'a self,
278 url: &str,
279 form: reqwest::multipart::Form,
280 ) -> futures::future::BoxFuture<'a, anyhow::Result<http_client::Response<http_client::AsyncBody>>>
281 {
282 let response = self.client.post(url).multipart(form).send();
283 self.handle
284 .spawn(async move {
285 let response = response.await?;
286 let mut builder = http::response::Builder::new().status(response.status());
287 for (k, v) in response.headers() {
288 builder = builder.header(k, v)
289 }
290 Ok(builder.body(response.bytes().await?.into())?)
291 })
292 .map(|e| e?)
293 .boxed()
294 }
295}
296
#[cfg(test)]
mod tests {
    use http_client::{HttpClient, Url};

    use crate::ReqwestClient;

    /// Every supported proxy scheme should be reported back verbatim by `HttpClient::proxy`.
    #[test]
    fn test_proxy_uri() {
        let plain = ReqwestClient::new();
        assert_eq!(plain.proxy(), None);

        let supported = [
            "http://localhost:10809",
            "https://localhost:10809",
            "socks4://localhost:10808",
            "socks4a://localhost:10808",
            "socks5://localhost:10808",
            "socks5h://localhost:10808",
        ];
        for raw in supported {
            let expected = Url::parse(raw).unwrap();
            let proxied =
                ReqwestClient::proxy_and_user_agent(Some(expected.clone()), "test").unwrap();
            assert_eq!(proxied.proxy(), Some(&expected));
        }
    }

    /// A proxy URL that reqwest rejects must leave the client without a proxy.
    #[test]
    fn test_invalid_proxy_uri() {
        let unsupported = Url::parse("socks://127.0.0.1:20170").unwrap();
        let client = ReqwestClient::proxy_and_user_agent(Some(unsupported), "test").unwrap();
        assert!(
            client.proxy.is_none(),
            "An invalid proxy URL should add no proxy to the client!"
        )
    }
}