ureq_client.rs

  1use std::collections::HashMap;
  2use std::io::Read;
  3use std::sync::Arc;
  4use std::time::Duration;
  5use std::{pin::Pin, task::Poll};
  6
  7use anyhow::Error;
  8use futures::channel::mpsc;
  9use futures::future::BoxFuture;
 10use futures::{AsyncRead, SinkExt, StreamExt};
 11use http_client::{http, AsyncBody, HttpClient, RedirectPolicy, Uri};
 12use smol::future::FutureExt;
 13use util::ResultExt;
 14
 15pub struct UreqClient {
 16    // Note in ureq 2.x the options are stored on the Agent.
 17    // In ureq 3.x we'll be able to set these on the request.
 18    // In practice it's probably "fine" to have many clients, the number of distinct options
 19    // is low; and most requests to the same connection will have the same options so the
 20    // connection pool will work.
 21    clients: Arc<parking_lot::Mutex<HashMap<(Duration, RedirectPolicy), ureq::Agent>>>,
 22    proxy_url: Option<Uri>,
 23    proxy: Option<ureq::Proxy>,
 24    user_agent: String,
 25    background_executor: gpui::BackgroundExecutor,
 26}
 27
 28impl UreqClient {
 29    pub fn new(
 30        proxy_url: Option<Uri>,
 31        user_agent: String,
 32        background_executor: gpui::BackgroundExecutor,
 33    ) -> Self {
 34        Self {
 35            clients: Arc::default(),
 36            proxy_url: proxy_url.clone(),
 37            proxy: proxy_url.and_then(|url| ureq::Proxy::new(url.to_string()).log_err()),
 38            user_agent,
 39            background_executor,
 40        }
 41    }
 42
 43    fn agent_for(&self, redirect_policy: RedirectPolicy, timeout: Duration) -> ureq::Agent {
 44        let mut clients = self.clients.lock();
 45        // in case our assumption of distinct options is wrong, we'll sporadically clean it out.
 46        if clients.len() > 50 {
 47            clients.clear()
 48        }
 49
 50        clients
 51            .entry((timeout, redirect_policy.clone()))
 52            .or_insert_with(|| {
 53                let mut builder = ureq::AgentBuilder::new()
 54                    .timeout_connect(Duration::from_secs(5))
 55                    .timeout_read(timeout)
 56                    .timeout_write(timeout)
 57                    .user_agent(&self.user_agent)
 58                    .tls_config(http_client::TLS_CONFIG.clone())
 59                    .redirects(match redirect_policy {
 60                        RedirectPolicy::NoFollow => 0,
 61                        RedirectPolicy::FollowLimit(limit) => limit,
 62                        RedirectPolicy::FollowAll => 100,
 63                    });
 64                if let Some(proxy) = &self.proxy {
 65                    builder = builder.proxy(proxy.clone());
 66                }
 67                builder.build()
 68            })
 69            .clone()
 70    }
 71}
 72impl HttpClient for UreqClient {
 73    fn proxy(&self) -> Option<&Uri> {
 74        self.proxy_url.as_ref()
 75    }
 76
 77    fn send(
 78        &self,
 79        request: http::Request<AsyncBody>,
 80    ) -> BoxFuture<'static, Result<http::Response<AsyncBody>, Error>> {
 81        let agent = self.agent_for(
 82            request
 83                .extensions()
 84                .get::<RedirectPolicy>()
 85                .cloned()
 86                .unwrap_or_default(),
 87            request
 88                .extensions()
 89                .get::<http_client::ReadTimeout>()
 90                .cloned()
 91                .unwrap_or_default()
 92                .0,
 93        );
 94        let mut req = agent.request(&request.method().as_ref(), &request.uri().to_string());
 95        for (name, value) in request.headers().into_iter() {
 96            req = req.set(name.as_str(), value.to_str().unwrap());
 97        }
 98        let body = request.into_body();
 99        let executor = self.background_executor.clone();
100
101        self.background_executor
102            .spawn(async move {
103                let response = req.send(body)?;
104
105                let mut builder = http::Response::builder()
106                    .status(response.status())
107                    .version(http::Version::HTTP_11);
108                for name in response.headers_names() {
109                    if let Some(value) = response.header(&name) {
110                        builder = builder.header(name, value);
111                    }
112                }
113
114                let body = AsyncBody::from_reader(UreqResponseReader::new(executor, response));
115                let http_response = builder.body(body)?;
116
117                Ok(http_response)
118            })
119            .boxed()
120    }
121}
122
123struct UreqResponseReader {
124    receiver: mpsc::Receiver<std::io::Result<Vec<u8>>>,
125    buffer: Vec<u8>,
126    idx: usize,
127    _task: gpui::Task<()>,
128}
129
130impl UreqResponseReader {
131    fn new(background_executor: gpui::BackgroundExecutor, response: ureq::Response) -> Self {
132        let (mut sender, receiver) = mpsc::channel(1);
133        let mut reader = response.into_reader();
134        let task = background_executor.spawn(async move {
135            let mut buffer = vec![0; 8192];
136            loop {
137                let n = match reader.read(&mut buffer) {
138                    Ok(0) => break,
139                    Ok(n) => n,
140                    Err(e) => {
141                        let _ = sender.send(Err(e)).await;
142                        break;
143                    }
144                };
145                let _ = sender.send(Ok(buffer[..n].to_vec())).await;
146            }
147        });
148
149        UreqResponseReader {
150            _task: task,
151            receiver,
152            buffer: Vec::new(),
153            idx: 0,
154        }
155    }
156}
157
158impl AsyncRead for UreqResponseReader {
159    fn poll_read(
160        mut self: Pin<&mut Self>,
161        cx: &mut std::task::Context<'_>,
162        buf: &mut [u8],
163    ) -> Poll<std::io::Result<usize>> {
164        if self.buffer.is_empty() {
165            match self.receiver.poll_next_unpin(cx) {
166                Poll::Ready(Some(Ok(data))) => self.buffer = data,
167                Poll::Ready(Some(Err(e))) => {
168                    return Poll::Ready(Err(e));
169                }
170                Poll::Ready(None) => {
171                    return Poll::Ready(Ok(0));
172                }
173                Poll::Pending => {
174                    return Poll::Pending;
175                }
176            }
177        }
178        let n = std::cmp::min(buf.len(), self.buffer.len() - self.idx);
179        buf[..n].copy_from_slice(&self.buffer[self.idx..self.idx + n]);
180        self.idx += n;
181        if self.idx == self.buffer.len() {
182            self.buffer.clear();
183            self.idx = 0;
184        }
185        Poll::Ready(Ok(n))
186    }
187}