1use std::error::Error;
2use std::sync::{LazyLock, OnceLock};
3use std::{any::type_name, borrow::Cow, mem, pin::Pin, task::Poll, time::Duration};
4
5use anyhow::anyhow;
6use bytes::{BufMut, Bytes, BytesMut};
7use futures::{AsyncRead, TryStreamExt as _};
8use http_client::{RedirectPolicy, Url, http};
9use regex::Regex;
10use reqwest::{
11 header::{HeaderMap, HeaderValue},
12 redirect,
13};
14use smol::future::FutureExt;
15
/// Number of bytes a `StreamReader` reserves in its buffer whenever the buffer
/// has no capacity left.
const DEFAULT_CAPACITY: usize = 4096;
/// Fallback Tokio runtime, lazily created the first time a `ReqwestClient` is
/// constructed while no Tokio runtime is current.
static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
/// Matches `key=` followed by everything up to the next `&`; used to scrub
/// such values from the query string of URLs attached to request errors.
static REDACT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"key=[^&]+").unwrap());
19
/// An HTTP client backed by `reqwest` whose requests are spawned onto a Tokio
/// runtime through a stored runtime handle.
pub struct ReqwestClient {
    /// The underlying `reqwest` client used to build and send requests.
    client: reqwest::Client,
    /// Proxy URL reported through `HttpClient::proxy`, if one was recorded.
    proxy: Option<Url>,
    /// User agent reported through `HttpClient::user_agent`, if one was recorded.
    user_agent: Option<HeaderValue>,
    /// Handle to the Tokio runtime on which requests are spawned.
    handle: tokio::runtime::Handle,
}
26
27impl ReqwestClient {
28 fn builder() -> reqwest::ClientBuilder {
29 reqwest::Client::builder()
30 .use_rustls_tls()
31 .connect_timeout(Duration::from_secs(10))
32 }
33
34 pub fn new() -> Self {
35 Self::builder()
36 .build()
37 .expect("Failed to initialize HTTP client")
38 .into()
39 }
40
41 pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
42 let mut map = HeaderMap::new();
43 map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
44 let client = Self::builder().default_headers(map).build()?;
45 Ok(client.into())
46 }
47
48 pub fn proxy_and_user_agent(proxy: Option<Url>, user_agent: &str) -> anyhow::Result<Self> {
49 let user_agent = HeaderValue::from_str(user_agent)?;
50
51 let mut map = HeaderMap::new();
52 map.insert(http::header::USER_AGENT, user_agent.clone());
53 let mut client = Self::builder().default_headers(map);
54 let client_has_proxy;
55
56 if let Some(proxy) = proxy.as_ref().and_then(|proxy_url| {
57 reqwest::Proxy::all(proxy_url.clone())
58 .inspect_err(|e| {
59 log::error!(
60 "Failed to parse proxy URL '{}': {}",
61 proxy_url,
62 e.source().unwrap_or(&e as &_)
63 )
64 })
65 .ok()
66 }) {
67 // Respect NO_PROXY env var
68 client = client.proxy(proxy.no_proxy(reqwest::NoProxy::from_env()));
69 client_has_proxy = true;
70 } else {
71 client_has_proxy = false;
72 };
73
74 let client = client
75 .use_preconfigured_tls(http_client_tls::tls_config())
76 .build()?;
77 let mut client: ReqwestClient = client.into();
78 client.proxy = client_has_proxy.then_some(proxy).flatten();
79 client.user_agent = Some(user_agent);
80 Ok(client)
81 }
82}
83
84impl From<reqwest::Client> for ReqwestClient {
85 fn from(client: reqwest::Client) -> Self {
86 let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
87 log::debug!("no tokio runtime found, creating one for Reqwest...");
88 let runtime = RUNTIME.get_or_init(|| {
89 tokio::runtime::Builder::new_multi_thread()
90 // Since we now have two executors, let's try to keep our footprint small
91 .worker_threads(1)
92 .enable_all()
93 .build()
94 .expect("Failed to initialize HTTP client")
95 });
96
97 runtime.handle().clone()
98 });
99 Self {
100 client,
101 handle,
102 proxy: None,
103 user_agent: None,
104 }
105 }
106}
107
// This struct is essentially a re-implementation of
// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
// except that it wraps a `futures::AsyncRead` instead of Tokio's `AsyncRead`.
struct StreamReader {
    /// The wrapped reader; `None` once the stream has finished.
    reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
    /// Buffer that read bytes are written into before being split off as chunks.
    buf: BytesMut,
    /// Number of bytes reserved in `buf` whenever it has no capacity left.
    capacity: usize,
}
116
117impl StreamReader {
118 fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
119 Self {
120 reader: Some(reader),
121 buf: BytesMut::new(),
122 capacity: DEFAULT_CAPACITY,
123 }
124 }
125}
126
127impl futures::Stream for StreamReader {
128 type Item = std::io::Result<Bytes>;
129
130 fn poll_next(
131 mut self: Pin<&mut Self>,
132 cx: &mut std::task::Context<'_>,
133 ) -> Poll<Option<Self::Item>> {
134 let mut this = self.as_mut();
135
136 let mut reader = match this.reader.take() {
137 Some(r) => r,
138 None => return Poll::Ready(None),
139 };
140
141 if this.buf.capacity() == 0 {
142 let capacity = this.capacity;
143 this.buf.reserve(capacity);
144 }
145
146 match poll_read_buf(&mut reader, cx, &mut this.buf) {
147 Poll::Pending => Poll::Pending,
148 Poll::Ready(Err(err)) => {
149 self.reader = None;
150
151 Poll::Ready(Some(Err(err)))
152 }
153 Poll::Ready(Ok(0)) => {
154 self.reader = None;
155 Poll::Ready(None)
156 }
157 Poll::Ready(Ok(_)) => {
158 let chunk = this.buf.split();
159 self.reader = Some(reader);
160 Poll::Ready(Some(Ok(chunk.freeze())))
161 }
162 }
163 }
164}
165
166/// Implementation from <https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html>
167/// Specialized for this use case
168pub fn poll_read_buf(
169 io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
170 cx: &mut std::task::Context<'_>,
171 buf: &mut BytesMut,
172) -> Poll<std::io::Result<usize>> {
173 if !buf.has_remaining_mut() {
174 return Poll::Ready(Ok(0));
175 }
176
177 let n = {
178 let dst = buf.chunk_mut();
179
180 // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
181 // transparent wrapper around `[MaybeUninit<u8>]`.
182 let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
183 let mut buf = tokio::io::ReadBuf::uninit(dst);
184 let ptr = buf.filled().as_ptr();
185 let unfilled_portion = buf.initialize_unfilled();
186 // SAFETY: Pin projection
187 let io_pin = unsafe { Pin::new_unchecked(io) };
188 std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
189
190 // Ensure the pointer does not change from under us
191 assert_eq!(ptr, buf.filled().as_ptr());
192 buf.filled().len()
193 };
194
195 // Safety: This is guaranteed to be the number of initialized (and read)
196 // bytes due to the invariants provided by `ReadBuf::filled`.
197 unsafe {
198 buf.advance_mut(n);
199 }
200
201 Poll::Ready(Ok(n))
202}
203
204fn redact_error(mut error: reqwest::Error) -> reqwest::Error {
205 if let Some(url) = error.url_mut() {
206 if let Some(query) = url.query() {
207 if let Cow::Owned(redacted) = REDACT_REGEX.replace_all(query, "key=REDACTED") {
208 url.set_query(Some(redacted.as_str()));
209 }
210 }
211 }
212 error
213}
214
215impl http_client::HttpClient for ReqwestClient {
216 fn proxy(&self) -> Option<&Url> {
217 self.proxy.as_ref()
218 }
219
220 fn type_name(&self) -> &'static str {
221 type_name::<Self>()
222 }
223
224 fn user_agent(&self) -> Option<&HeaderValue> {
225 self.user_agent.as_ref()
226 }
227
228 fn send(
229 &self,
230 req: http::Request<http_client::AsyncBody>,
231 ) -> futures::future::BoxFuture<
232 'static,
233 anyhow::Result<http_client::Response<http_client::AsyncBody>>,
234 > {
235 let (parts, body) = req.into_parts();
236
237 let mut request = self.client.request(parts.method, parts.uri.to_string());
238 request = request.headers(parts.headers);
239 if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
240 request = request.redirect_policy(match redirect_policy {
241 RedirectPolicy::NoFollow => redirect::Policy::none(),
242 RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
243 RedirectPolicy::FollowAll => redirect::Policy::limited(100),
244 });
245 }
246 let request = request.body(match body.0 {
247 http_client::Inner::Empty => reqwest::Body::default(),
248 http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
249 http_client::Inner::AsyncReader(stream) => {
250 reqwest::Body::wrap_stream(StreamReader::new(stream))
251 }
252 });
253
254 let handle = self.handle.clone();
255 async move {
256 let mut response = handle
257 .spawn(async { request.send().await })
258 .await?
259 .map_err(redact_error)?;
260
261 let headers = mem::take(response.headers_mut());
262 let mut builder = http::Response::builder()
263 .status(response.status().as_u16())
264 .version(response.version());
265 *builder.headers_mut().unwrap() = headers;
266
267 let bytes = response
268 .bytes_stream()
269 .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e))
270 .into_async_read();
271 let body = http_client::AsyncBody::from_reader(bytes);
272
273 builder.body(body).map_err(|e| anyhow!(e))
274 }
275 .boxed()
276 }
277}
278
#[cfg(test)]
mod tests {
    use http_client::{HttpClient, Url};

    use crate::ReqwestClient;

    /// A client without a proxy reports none; every supported proxy scheme is
    /// reported back unchanged.
    #[test]
    fn test_proxy_uri() {
        let client = ReqwestClient::new();
        assert_eq!(client.proxy(), None);

        let supported = [
            "http://localhost:10809",
            "https://localhost:10809",
            "socks4://localhost:10808",
            "socks4a://localhost:10808",
            "socks5://localhost:10808",
            "socks5h://localhost:10808",
        ];
        for raw in supported {
            let proxy = Url::parse(raw).unwrap();
            let client = ReqwestClient::proxy_and_user_agent(Some(proxy.clone()), "test").unwrap();
            assert_eq!(client.proxy(), Some(&proxy));
        }
    }

    /// An unsupported proxy scheme still yields a client, but without a proxy.
    #[test]
    fn test_invalid_proxy_uri() {
        let proxy = Url::parse("socks://127.0.0.1:20170").unwrap();
        let client = ReqwestClient::proxy_and_user_agent(Some(proxy), "test").unwrap();
        assert!(
            client.proxy.is_none(),
            "An invalid proxy URL should add no proxy to the client!"
        )
    }
}