reqwest_client.rs

  1use std::error::Error;
  2use std::sync::{LazyLock, OnceLock};
  3use std::{borrow::Cow, mem, pin::Pin, task::Poll, time::Duration};
  4
  5use gpui_util::defer;
  6
  7use anyhow::anyhow;
  8use bytes::{BufMut, Bytes, BytesMut};
  9use futures::{AsyncRead, FutureExt as _, TryStreamExt as _};
 10use http_client::{RedirectPolicy, Url, http};
 11use regex::Regex;
 12use reqwest::{
 13    header::{HeaderMap, HeaderValue},
 14    redirect,
 15};
 16
 17const DEFAULT_CAPACITY: usize = 4096;
 18static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
 19static REDACT_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"key=[^&]+").unwrap());
 20
 21pub struct ReqwestClient {
 22    client: reqwest::Client,
 23    proxy: Option<Url>,
 24    user_agent: Option<HeaderValue>,
 25    handle: tokio::runtime::Handle,
 26}
 27
 28impl ReqwestClient {
 29    fn builder() -> reqwest::ClientBuilder {
 30        reqwest::Client::builder()
 31            .use_rustls_tls()
 32            .connect_timeout(Duration::from_secs(10))
 33    }
 34
 35    pub fn new() -> Self {
 36        Self::builder()
 37            .build()
 38            .expect("Failed to initialize HTTP client")
 39            .into()
 40    }
 41
 42    pub fn user_agent(agent: &str) -> anyhow::Result<Self> {
 43        let mut map = HeaderMap::new();
 44        map.insert(http::header::USER_AGENT, HeaderValue::from_str(agent)?);
 45        let client = Self::builder().default_headers(map).build()?;
 46        Ok(client.into())
 47    }
 48
 49    pub fn proxy_and_user_agent(proxy: Option<Url>, user_agent: &str) -> anyhow::Result<Self> {
 50        let user_agent = HeaderValue::from_str(user_agent)?;
 51
 52        let mut map = HeaderMap::new();
 53        map.insert(http::header::USER_AGENT, user_agent.clone());
 54        let mut client = Self::builder().default_headers(map);
 55        let client_has_proxy;
 56
 57        if let Some(proxy) = proxy.as_ref().and_then(|proxy_url| {
 58            reqwest::Proxy::all(proxy_url.clone())
 59                .inspect_err(|e| {
 60                    log::error!(
 61                        "Failed to parse proxy URL '{}': {}",
 62                        proxy_url,
 63                        e.source().unwrap_or(&e as &_)
 64                    )
 65                })
 66                .ok()
 67        }) {
 68            // Respect NO_PROXY env var
 69            client = client.proxy(proxy.no_proxy(reqwest::NoProxy::from_env()));
 70            client_has_proxy = true;
 71        } else {
 72            client_has_proxy = false;
 73        };
 74
 75        let client = client
 76            .use_preconfigured_tls(http_client_tls::tls_config())
 77            .build()?;
 78        let mut client: ReqwestClient = client.into();
 79        client.proxy = client_has_proxy.then_some(proxy).flatten();
 80        client.user_agent = Some(user_agent);
 81        Ok(client)
 82    }
 83}
 84
 85pub fn runtime() -> &'static tokio::runtime::Runtime {
 86    RUNTIME.get_or_init(|| {
 87        tokio::runtime::Builder::new_multi_thread()
 88            // Since we now have two executors, let's try to keep our footprint small
 89            .worker_threads(1)
 90            .enable_all()
 91            .build()
 92            .expect("Failed to initialize HTTP client")
 93    })
 94}
 95
 96impl From<reqwest::Client> for ReqwestClient {
 97    fn from(client: reqwest::Client) -> Self {
 98        let handle = tokio::runtime::Handle::try_current().unwrap_or_else(|_| {
 99            log::debug!("no tokio runtime found, creating one for Reqwest...");
100            runtime().handle().clone()
101        });
102        Self {
103            client,
104            handle,
105            proxy: None,
106            user_agent: None,
107        }
108    }
109}
110
111// This struct is essentially a re-implementation of
112// https://docs.rs/tokio-util/0.7.12/tokio_util/io/struct.ReaderStream.html
113// except outside of Tokio's aegis
114struct StreamReader {
115    reader: Option<Pin<Box<dyn futures::AsyncRead + Send + Sync>>>,
116    buf: BytesMut,
117    capacity: usize,
118}
119
120impl StreamReader {
121    fn new(reader: Pin<Box<dyn futures::AsyncRead + Send + Sync>>) -> Self {
122        Self {
123            reader: Some(reader),
124            buf: BytesMut::new(),
125            capacity: DEFAULT_CAPACITY,
126        }
127    }
128}
129
130impl futures::Stream for StreamReader {
131    type Item = std::io::Result<Bytes>;
132
133    fn poll_next(
134        mut self: Pin<&mut Self>,
135        cx: &mut std::task::Context<'_>,
136    ) -> Poll<Option<Self::Item>> {
137        let mut this = self.as_mut();
138
139        let mut reader = match this.reader.take() {
140            Some(r) => r,
141            None => return Poll::Ready(None),
142        };
143
144        if this.buf.capacity() == 0 {
145            let capacity = this.capacity;
146            this.buf.reserve(capacity);
147        }
148
149        match poll_read_buf(&mut reader, cx, &mut this.buf) {
150            Poll::Pending => Poll::Pending,
151            Poll::Ready(Err(err)) => {
152                self.reader = None;
153
154                Poll::Ready(Some(Err(err)))
155            }
156            Poll::Ready(Ok(0)) => {
157                self.reader = None;
158                Poll::Ready(None)
159            }
160            Poll::Ready(Ok(_)) => {
161                let chunk = this.buf.split();
162                self.reader = Some(reader);
163                Poll::Ready(Some(Ok(chunk.freeze())))
164            }
165        }
166    }
167}
168
169/// Implementation from <https://docs.rs/tokio-util/0.7.12/src/tokio_util/util/poll_buf.rs.html>
170/// Specialized for this use case
171pub fn poll_read_buf(
172    io: &mut Pin<Box<dyn futures::AsyncRead + Send + Sync>>,
173    cx: &mut std::task::Context<'_>,
174    buf: &mut BytesMut,
175) -> Poll<std::io::Result<usize>> {
176    if !buf.has_remaining_mut() {
177        return Poll::Ready(Ok(0));
178    }
179
180    let n = {
181        let dst = buf.chunk_mut();
182
183        // Safety: `chunk_mut()` returns a `&mut UninitSlice`, and `UninitSlice` is a
184        // transparent wrapper around `[MaybeUninit<u8>]`.
185        let dst = unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit<u8>]) };
186        let mut buf = tokio::io::ReadBuf::uninit(dst);
187        let ptr = buf.filled().as_ptr();
188        let unfilled_portion = buf.initialize_unfilled();
189        // SAFETY: Pin projection
190        let io_pin = unsafe { Pin::new_unchecked(io) };
191        std::task::ready!(io_pin.poll_read(cx, unfilled_portion)?);
192
193        // Ensure the pointer does not change from under us
194        assert_eq!(ptr, buf.filled().as_ptr());
195        buf.filled().len()
196    };
197
198    // Safety: This is guaranteed to be the number of initialized (and read)
199    // bytes due to the invariants provided by `ReadBuf::filled`.
200    unsafe {
201        buf.advance_mut(n);
202    }
203
204    Poll::Ready(Ok(n))
205}
206
207fn redact_error(mut error: reqwest::Error) -> reqwest::Error {
208    if let Some(url) = error.url_mut()
209        && let Some(query) = url.query()
210        && let Cow::Owned(redacted) = REDACT_REGEX.replace_all(query, "key=REDACTED")
211    {
212        url.set_query(Some(redacted.as_str()));
213    }
214    error
215}
216
217impl http_client::HttpClient for ReqwestClient {
218    fn proxy(&self) -> Option<&Url> {
219        self.proxy.as_ref()
220    }
221
222    fn user_agent(&self) -> Option<&HeaderValue> {
223        self.user_agent.as_ref()
224    }
225
226    fn send(
227        &self,
228        req: http::Request<http_client::AsyncBody>,
229    ) -> futures::future::BoxFuture<
230        'static,
231        anyhow::Result<http_client::Response<http_client::AsyncBody>>,
232    > {
233        let (parts, body) = req.into_parts();
234
235        let mut request = self.client.request(parts.method, parts.uri.to_string());
236        request = request.headers(parts.headers);
237        if let Some(redirect_policy) = parts.extensions.get::<RedirectPolicy>() {
238            request = request.redirect_policy(match redirect_policy {
239                RedirectPolicy::NoFollow => redirect::Policy::none(),
240                RedirectPolicy::FollowLimit(limit) => redirect::Policy::limited(*limit as usize),
241                RedirectPolicy::FollowAll => redirect::Policy::limited(100),
242            });
243        }
244        let request = request.body(match body.0 {
245            http_client::Inner::Empty => reqwest::Body::default(),
246            http_client::Inner::Bytes(cursor) => cursor.into_inner().into(),
247            http_client::Inner::AsyncReader(stream) => {
248                reqwest::Body::wrap_stream(StreamReader::new(stream))
249            }
250        });
251
252        let handle = self.handle.clone();
253        async move {
254            let join_handle = handle.spawn(async { request.send().await });
255            let abort_handle = join_handle.abort_handle();
256            let _abort_on_drop = defer(move || abort_handle.abort());
257
258            let mut response = join_handle.await?.map_err(redact_error)?;
259
260            let headers = mem::take(response.headers_mut());
261            let mut builder = http::Response::builder()
262                .status(response.status().as_u16())
263                .version(response.version());
264            *builder.headers_mut().unwrap() = headers;
265
266            let bytes = response
267                .bytes_stream()
268                .map_err(futures::io::Error::other)
269                .into_async_read();
270            let body = http_client::AsyncBody::from_reader(bytes);
271
272            builder.body(body).map_err(|e| anyhow!(e))
273        }
274        .boxed()
275    }
276}
277
#[cfg(test)]
mod tests {
    use http_client::{HttpClient, Url};

    use crate::ReqwestClient;

    /// A client without a proxy reports none. Each listed proxy scheme is accepted and
    /// reported back unchanged.
    #[test]
    fn test_proxy_uri() {
        assert_eq!(ReqwestClient::new().proxy(), None);

        let accepted_proxies = [
            "http://localhost:10809",
            "https://localhost:10809",
            "socks4://localhost:10808",
            "socks4a://localhost:10808",
            "socks5://localhost:10808",
            "socks5h://localhost:10808",
        ];
        for raw in accepted_proxies {
            let proxy = Url::parse(raw).unwrap();
            let client =
                ReqwestClient::proxy_and_user_agent(Some(proxy.clone()), "test").unwrap();
            assert_eq!(client.proxy(), Some(&proxy), "proxy {raw} was not recorded");
        }
    }

    /// A scheme reqwest rejects must be dropped rather than reported as the proxy.
    #[test]
    fn test_invalid_proxy_uri() {
        let rejected = Url::parse("socks://127.0.0.1:20170").unwrap();
        let client = ReqwestClient::proxy_and_user_agent(Some(rejected), "test").unwrap();
        assert!(
            client.proxy.is_none(),
            "An invalid proxy URL should add no proxy to the client!"
        )
    }
}