1use std::error::Error;
2use std::sync::{LazyLock, OnceLock};
3use std::{borrow::Cow, mem, pin::Pin, task::Poll, time::Duration};
4
5use gpui_util::defer;
6
7use anyhow::anyhow;
8use bytes::{BufMut, Bytes, BytesMut};
9use futures::{AsyncRead, FutureExt as _, TryStreamExt as _};
10use http_client::{RedirectPolicy, Url, http};
11use regex::Regex;
12use reqwest::{
13 header::{HeaderMap, HeaderValue},
14 redirect,
15};
16
17const DEFAULT_CAPACITY: usize = 4096;
18static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
19static REDACT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"key=[^&]+").unwrap());
20
/// An [`http_client::HttpClient`] backed by `reqwest`.
///
/// Requests are spawned onto the stored tokio runtime `handle` (see `send`).
pub struct ReqwestClient {
    /// The underlying reqwest client that builds and executes requests.
    client: reqwest::Client,
    /// Proxy reported through `HttpClient::proxy`; `proxy_and_user_agent` only sets
    /// this when the proxy URL was successfully applied to `client`.
    proxy: Option<Url>,
    /// User agent reported through `HttpClient::user_agent`.
    user_agent: Option<HeaderValue>,
    /// Tokio runtime handle that requests are spawned onto.
    handle: tokio::runtime::Handle,
}
27
28impl ReqwestClient {
29 fn builder() -> reqwest::ClientBuilder {
30 reqwest::Client::builder()
31 .use_rustls_tls()
32 .connect_timeout(Duration::from_secs(10))
33 }
34
35 pub fn new() -> Self {
36 Self::builder()
37 .build()
38 .expect("Failed to initialize HTTP client")
39 .into()
40 }
41
42 pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
43 let mut map = HeaderMap::new();
44 map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
45 let client = Self::builder().default_headers(map).build()?;
46 Ok(client.into())
47 }
48
49 pub fn proxy_and_user_agent(proxy: Option<Url>, user_agent: &str) -> anyhow::Result<Self> {
50 let user_agent = HeaderValue::from_str(user_agent)?;
51
52 let mut map = HeaderMap::new();
53 map.insert(http::header::USER_AGENT, user_agent.clone());
54 let mut client = Self::builder().default_headers(map);
55 let client_has_proxy;
56
57 if let Some(proxy) = proxy.as_ref().and_then(|proxy_url| {
58 reqwest::Proxy::all(proxy_url.clone())
59 .inspect_err(|e| {
60 log::error!(
61 "Failed to parse proxy URL '{}': {}",
62 proxy_url,
63 e.source().unwrap_or(&e as &_)
64 )
65 })
66 .ok()
67 }) {
68 // Respect NO_PROXY env var
69 client = client.proxy(proxy.no_proxy(reqwest::NoProxy::from_env()));
70 client_has_proxy = true;
71 } else {
72 client_has_proxy = false;
73 };
74
75 let client = client
76 .use_preconfigured_tls(http_client_tls::tls_config())
77 .build()?;
78 let mut client: ReqwestClient = client.into();
79 client.proxy = client_has_proxy.then_some(proxy).flatten();
80 client.user_agent = Some(user_agent);
81 Ok(client)
82 }
83}
84
85pub fn runtime() -> &'static tokio::runtime::Runtime {
86 RUNTIME.get_or_init(|| {
87 tokio::runtime::Builder::new_multi_thread()
88 // Since we now have two executors, let's try to keep our footprint small
89 .worker_threads(1)
90 .enable_all()
91 .build()
92 .expect("Failed to initialize HTTP client")
93 })
94}
95
96impl From<reqwest::Client> for ReqwestClient {
97 fn from(client: reqwest::Client) -> Self {
98 let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
99 log::debug!("no tokio runtime found, creating one for Reqwest...");
100 runtime().handle().clone()
101 });
102 Self {
103 client,
104 handle,
105 proxy: None,
106 user_agent: None,
107 }
108 }
109}
110
/// A `futures::Stream` of `Bytes` chunks read from a `futures::AsyncRead`.
///
/// This struct is essentially a re-implementation of tokio-util's
/// [`ReaderStream`](https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html),
/// except outside of Tokio's aegis. `send` uses it to turn an
/// `AsyncReader` request body into a `reqwest::Body` via `reqwest::Body::wrap_stream`.
struct StreamReader {
    /// The wrapped reader; `poll_next` sets this to `None` to end the stream.
    reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
    /// Buffer that bytes are read into before being split off as a chunk.
    buf: BytesMut,
    /// Number of bytes reserved in `buf` whenever `buf.capacity()` is zero.
    capacity: usize,
}
119
120impl StreamReader {
121 fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
122 Self {
123 reader: Some(reader),
124 buf: BytesMut::new(),
125 capacity: DEFAULT_CAPACITY,
126 }
127 }
128}
129
130impl futures::Stream for StreamReader {
131 type Item = std::io::Result<Bytes>;
132
133 fn poll_next(
134 mut self: Pin<&mut Self>,
135 cx: &mut std::task::Context<'_>,
136 ) -> Poll<Option<Self::Item>> {
137 let mut this = self.as_mut();
138
139 let mut reader = match this.reader.take() {
140 Some(r) => r,
141 None => return Poll::Ready(None),
142 };
143
144 if this.buf.capacity() == 0 {
145 let capacity = this.capacity;
146 this.buf.reserve(capacity);
147 }
148
149 match poll_read_buf(&mut reader, cx, &mut this.buf) {
150 Poll::Pending => Poll::Pending,
151 Poll::Ready(Err(err)) => {
152 self.reader = None;
153
154 Poll::Ready(Some(Err(err)))
155 }
156 Poll::Ready(Ok(0)) => {
157 self.reader = None;
158 Poll::Ready(None)
159 }
160 Poll::Ready(Ok(_)) => {
161 let chunk = this.buf.split();
162 self.reader = Some(reader);
163 Poll::Ready(Some(Ok(chunk.freeze())))
164 }
165 }
166 }
167}
168
169/// Implementation from <https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html>
170/// Specialized for this use case
171pub fn poll_read_buf(
172 io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
173 cx: &mut std::task::Context<'_>,
174 buf: &mut BytesMut,
175) -> Poll<std::io::Result<usize>> {
176 if !buf.has_remaining_mut() {
177 return Poll::Ready(Ok(0));
178 }
179
180 let n = {
181 let dst = buf.chunk_mut();
182
183 // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
184 // transparent wrapper around `[MaybeUninit<u8>]`.
185 let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
186 let mut buf = tokio::io::ReadBuf::uninit(dst);
187 let ptr = buf.filled().as_ptr();
188 let unfilled_portion = buf.initialize_unfilled();
189 // SAFETY: Pin projection
190 let io_pin = unsafe { Pin::new_unchecked(io) };
191 std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
192
193 // Ensure the pointer does not change from under us
194 assert_eq!(ptr, buf.filled().as_ptr());
195 buf.filled().len()
196 };
197
198 // Safety: This is guaranteed to be the number of initialized (and read)
199 // bytes due to the invariants provided by `ReadBuf::filled`.
200 unsafe {
201 buf.advance_mut(n);
202 }
203
204 Poll::Ready(Ok(n))
205}
206
207fn redact_error(mut error: reqwest::Error) -> reqwest::Error {
208 if let Some(url) = error.url_mut()
209 && let Some(query) = url.query()
210 && let Cow::Owned(redacted) = REDACT_REGEX.replace_all(query, "key=REDACTED")
211 {
212 url.set_query(Some(redacted.as_str()));
213 }
214 error
215}
216
217impl http_client::HttpClient for ReqwestClient {
218 fn proxy(&self) -> Option<&Url> {
219 self.proxy.as_ref()
220 }
221
222 fn user_agent(&self) -> Option<&HeaderValue> {
223 self.user_agent.as_ref()
224 }
225
226 fn send(
227 &self,
228 req: http::Request<http_client::AsyncBody>,
229 ) -> futures::future::BoxFuture<
230 'static,
231 anyhow::Result<http_client::Response<http_client::AsyncBody>>,
232 > {
233 let (parts, body) = req.into_parts();
234
235 let mut request = self.client.request(parts.method, parts.uri.to_string());
236 request = request.headers(parts.headers);
237 if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
238 request = request.redirect_policy(match redirect_policy {
239 RedirectPolicy::NoFollow => redirect::Policy::none(),
240 RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
241 RedirectPolicy::FollowAll => redirect::Policy::limited(100),
242 });
243 }
244 let request = request.body(match body.0 {
245 http_client::Inner::Empty => reqwest::Body::default(),
246 http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
247 http_client::Inner::AsyncReader(stream) => {
248 reqwest::Body::wrap_stream(StreamReader::new(stream))
249 }
250 });
251
252 let handle = self.handle.clone();
253 async move {
254 let join_handle = handle.spawn(async { request.send().await });
255 let abort_handle = join_handle.abort_handle();
256 let _abort_on_drop = defer(move || abort_handle.abort());
257
258 let mut response = join_handle.await?.map_err(redact_error)?;
259
260 let headers = mem::take(response.headers_mut());
261 let mut builder = http::Response::builder()
262 .status(response.status().as_u16())
263 .version(response.version());
264 *builder.headers_mut().unwrap() = headers;
265
266 let bytes = response
267 .bytes_stream()
268 .map_err(futures::io::Error::other)
269 .into_async_read();
270 let body = http_client::AsyncBody::from_reader(bytes);
271
272 builder.body(body).map_err(|e| anyhow!(e))
273 }
274 .boxed()
275 }
276}
277
#[cfg(test)]
mod tests {
    use http_client::{HttpClient, Url};

    use crate::ReqwestClient;

    /// Proxy URLs whose schemes reqwest accepts; each must be kept on the client.
    const SUPPORTED_PROXIES: [&str; 6] = [
        "http://localhost:10809",
        "https://localhost:10809",
        "socks4://localhost:10808",
        "socks4a://localhost:10808",
        "socks5://localhost:10808",
        "socks5h://localhost:10808",
    ];

    #[test]
    fn test_proxy_uri() {
        assert_eq!(ReqwestClient::new().proxy(), None);

        for raw_proxy in SUPPORTED_PROXIES {
            let proxy = Url::parse(raw_proxy).unwrap();
            let client =
                ReqwestClient::proxy_and_user_agent(Some(proxy.clone()), "test").unwrap();
            assert_eq!(client.proxy(), Some(&proxy));
        }
    }

    #[test]
    fn test_invalid_proxy_uri() {
        let proxy = Url::parse("socks://127.0.0.1:20170").unwrap();
        let client = ReqwestClient::proxy_and_user_agent(Some(proxy), "test").unwrap();
        assert!(
            client.proxy.is_none(),
            "An invalid proxy URL should add no proxy to the client!"
        )
    }
}