1use std::error::Error;
2use std::sync::{LazyLock, OnceLock};
3use std::{any::type_name, borrow::Cow, mem, pin::Pin, task::Poll, time::Duration};
4
5use anyhow::anyhow;
6use bytes::{BufMut, Bytes, BytesMut};
7use futures::{AsyncRead, TryStreamExt as _};
8use http_client::{RedirectPolicy, Url, http};
9use regex::Regex;
10use reqwest::{
11 header::{HeaderMap, HeaderValue},
12 redirect,
13};
14use smol::future::FutureExt;
15
/// Number of bytes `StreamReader` reserves whenever its read buffer has no capacity left.
const DEFAULT_CAPACITY: usize = 4 * 1024;
17static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
18static REDACT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"key=[^&]+").unwrap());
19
20pub struct ReqwestClient {
21 client: reqwest::Client,
22 proxy: Option<Url>,
23 handle: tokio::runtime::Handle,
24}
25
26impl ReqwestClient {
27 fn builder() -> reqwest::ClientBuilder {
28 reqwest::Client::builder()
29 .use_rustls_tls()
30 .connect_timeout(Duration::from_secs(10))
31 }
32
33 pub fn new() -> Self {
34 Self::builder()
35 .build()
36 .expect("Failed to initialize HTTP client")
37 .into()
38 }
39
40 pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
41 let mut map = HeaderMap::new();
42 map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
43 let client = Self::builder().default_headers(map).build()?;
44 Ok(client.into())
45 }
46
47 pub fn proxy_and_user_agent(proxy: Option<Url>, agent: &str) -> anyhow::Result<Self> {
48 let mut map = HeaderMap::new();
49 map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
50 let mut client = Self::builder().default_headers(map);
51 let client_has_proxy;
52
53 if let Some(proxy) = proxy.as_ref().and_then(|proxy_url| {
54 reqwest::Proxy::all(proxy_url.clone())
55 .inspect_err(|e| {
56 log::error!(
57 "Failed to parse proxy URL '{}': {}",
58 proxy_url,
59 e.source().unwrap_or(&e as &_)
60 )
61 })
62 .ok()
63 }) {
64 client = client.proxy(proxy);
65 client_has_proxy = true;
66 } else {
67 client_has_proxy = false;
68 };
69
70 let client = client
71 .use_preconfigured_tls(http_client_tls::tls_config())
72 .build()?;
73 let mut client: ReqwestClient = client.into();
74 client.proxy = client_has_proxy.then_some(proxy).flatten();
75 Ok(client)
76 }
77}
78
79impl From<reqwest::Client> for ReqwestClient {
80 fn from(client: reqwest::Client) -> Self {
81 let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
82 log::debug!("no tokio runtime found, creating one for Reqwest...");
83 let runtime = RUNTIME.get_or_init(|| {
84 tokio::runtime::Builder::new_multi_thread()
85 // Since we now have two executors, let's try to keep our footprint small
86 .worker_threads(1)
87 .enable_all()
88 .build()
89 .expect("Failed to initialize HTTP client")
90 });
91
92 runtime.handle().clone()
93 });
94 Self {
95 client,
96 handle,
97 proxy: None,
98 }
99 }
100}
101
102// This struct is essentially a re-implementation of
103// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
104// except outside of Tokio's aegis
105struct StreamReader {
106 reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
107 buf: BytesMut,
108 capacity: usize,
109}
110
111impl StreamReader {
112 fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
113 Self {
114 reader: Some(reader),
115 buf: BytesMut::new(),
116 capacity: DEFAULT_CAPACITY,
117 }
118 }
119}
120
121impl futures::Stream for StreamReader {
122 type Item = std::io::Result<Bytes>;
123
124 fn poll_next(
125 mut self: Pin<&mut Self>,
126 cx: &mut std::task::Context<'_>,
127 ) -> Poll<Option<Self::Item>> {
128 let mut this = self.as_mut();
129
130 let mut reader = match this.reader.take() {
131 Some(r) => r,
132 None => return Poll::Ready(None),
133 };
134
135 if this.buf.capacity() == 0 {
136 let capacity = this.capacity;
137 this.buf.reserve(capacity);
138 }
139
140 match poll_read_buf(&mut reader, cx, &mut this.buf) {
141 Poll::Pending => Poll::Pending,
142 Poll::Ready(Err(err)) => {
143 self.reader = None;
144
145 Poll::Ready(Some(Err(err)))
146 }
147 Poll::Ready(Ok(0)) => {
148 self.reader = None;
149 Poll::Ready(None)
150 }
151 Poll::Ready(Ok(_)) => {
152 let chunk = this.buf.split();
153 self.reader = Some(reader);
154 Poll::Ready(Some(Ok(chunk.freeze())))
155 }
156 }
157 }
158}
159
160/// Implementation from <https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html>
161/// Specialized for this use case
162pub fn poll_read_buf(
163 io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
164 cx: &mut std::task::Context<'_>,
165 buf: &mut BytesMut,
166) -> Poll<std::io::Result<usize>> {
167 if !buf.has_remaining_mut() {
168 return Poll::Ready(Ok(0));
169 }
170
171 let n = {
172 let dst = buf.chunk_mut();
173
174 // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
175 // transparent wrapper around `[MaybeUninit<u8>]`.
176 let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
177 let mut buf = tokio::io::ReadBuf::uninit(dst);
178 let ptr = buf.filled().as_ptr();
179 let unfilled_portion = buf.initialize_unfilled();
180 // SAFETY: Pin projection
181 let io_pin = unsafe { Pin::new_unchecked(io) };
182 std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
183
184 // Ensure the pointer does not change from under us
185 assert_eq!(ptr, buf.filled().as_ptr());
186 buf.filled().len()
187 };
188
189 // Safety: This is guaranteed to be the number of initialized (and read)
190 // bytes due to the invariants provided by `ReadBuf::filled`.
191 unsafe {
192 buf.advance_mut(n);
193 }
194
195 Poll::Ready(Ok(n))
196}
197
198fn redact_error(mut error: reqwest::Error) -> reqwest::Error {
199 if let Some(url) = error.url_mut() {
200 if let Some(query) = url.query() {
201 if let Cow::Owned(redacted) = REDACT_REGEX.replace_all(query, "key=REDACTED") {
202 url.set_query(Some(redacted.as_str()));
203 }
204 }
205 }
206 error
207}
208
209impl http_client::HttpClient for ReqwestClient {
210 fn proxy(&self) -> Option<&Url> {
211 self.proxy.as_ref()
212 }
213
214 fn type_name(&self) -> &'static str {
215 type_name::<Self>()
216 }
217
218 fn send(
219 &self,
220 req: http::Request<http_client::AsyncBody>,
221 ) -> futures::future::BoxFuture<
222 'static,
223 anyhow::Result<http_client::Response<http_client::AsyncBody>>,
224 > {
225 let (parts, body) = req.into_parts();
226
227 let mut request = self.client.request(parts.method, parts.uri.to_string());
228 request = request.headers(parts.headers);
229 if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
230 request = request.redirect_policy(match redirect_policy {
231 RedirectPolicy::NoFollow => redirect::Policy::none(),
232 RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
233 RedirectPolicy::FollowAll => redirect::Policy::limited(100),
234 });
235 }
236 let request = request.body(match body.0 {
237 http_client::Inner::Empty => reqwest::Body::default(),
238 http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
239 http_client::Inner::AsyncReader(stream) => {
240 reqwest::Body::wrap_stream(StreamReader::new(stream))
241 }
242 });
243
244 let handle = self.handle.clone();
245 async move {
246 let mut response = handle
247 .spawn(async { request.send().await })
248 .await?
249 .map_err(redact_error)?;
250
251 let headers = mem::take(response.headers_mut());
252 let mut builder = http::Response::builder()
253 .status(response.status().as_u16())
254 .version(response.version());
255 *builder.headers_mut().unwrap() = headers;
256
257 let bytes = response
258 .bytes_stream()
259 .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
260 .into_async_read();
261 let body = http_client::AsyncBody::from_reader(bytes);
262
263 builder.body(body).map_err(|e| anyhow!(e))
264 }
265 .boxed()
266 }
267}
268
#[cfg(test)]
mod tests {
    use http_client::{HttpClient, Url};

    use crate::ReqwestClient;

    /// A plain client has no proxy, and every supported proxy scheme is kept as given.
    #[test]
    fn test_proxy_uri() {
        let client = ReqwestClient::new();
        assert_eq!(client.proxy(), None);

        let supported = [
            "http://localhost:10809",
            "https://localhost:10809",
            "socks4://localhost:10808",
            "socks4a://localhost:10808",
            "socks5://localhost:10808",
            "socks5h://localhost:10808",
        ];
        for raw in supported {
            let proxy = Url::parse(raw).unwrap();
            let client = ReqwestClient::proxy_and_user_agent(Some(proxy.clone()), "test").unwrap();
            assert_eq!(client.proxy(), Some(&proxy), "proxy {raw} was not kept");
        }
    }

    /// A proxy URL that reqwest rejects must still yield a client, just without a proxy.
    #[test]
    fn test_invalid_proxy_uri() {
        let proxy = Url::parse("socks://127.0.0.1:20170").unwrap();
        let client = ReqwestClient::proxy_and_user_agent(Some(proxy), "test").unwrap();
        assert!(
            client.proxy.is_none(),
            "An invalid proxy URL should add no proxy to the client!"
        )
    }
}