1use std::collections::HashMap;
2use std::io::Read;
3use std::sync::Arc;
4use std::time::Duration;
5use std::{pin::Pin, task::Poll};
6
7use anyhow::Error;
8use futures::channel::mpsc;
9use futures::future::BoxFuture;
10use futures::{AsyncRead, SinkExt, StreamExt};
11use http_client::{http, AsyncBody, HttpClient, RedirectPolicy, Uri};
12use smol::future::FutureExt;
13use util::ResultExt;
14
/// `HttpClient` implementation backed by blocking `ureq` agents, with the
/// blocking work spawned on a GPUI background executor.
pub struct UreqClient {
    // Note in ureq 2.x the options are stored on the Agent.
    // In ureq 3.x we'll be able to set these on the request.
    // In practice it's probably "fine" to have many clients, the number of distinct options
    // is low; and most requests to the same connection will have the same options so the
    // connection pool will work.
    /// Cached agents keyed by (read/write timeout, redirect policy).
    clients: Arc<parking_lot::Mutex<HashMap<(Duration, RedirectPolicy), ureq::Agent>>>,
    /// Proxy URL exactly as configured; this is what `HttpClient::proxy` reports.
    proxy_url: Option<Uri>,
    /// Proxy handed to every agent; `None` when no proxy was configured or when
    /// `ureq::Proxy::new` rejected `proxy_url`.
    proxy: Option<ureq::Proxy>,
    /// User-Agent value configured on every agent.
    user_agent: String,
    /// Executor on which requests are sent and response bodies are read.
    background_executor: gpui::BackgroundExecutor,
}
27
28impl UreqClient {
29 pub fn new(
30 proxy_url: Option<Uri>,
31 user_agent: String,
32 background_executor: gpui::BackgroundExecutor,
33 ) -> Self {
34 Self {
35 clients: Arc::default(),
36 proxy_url: proxy_url.clone(),
37 proxy: proxy_url.and_then(|url| ureq::Proxy::new(url.to_string()).log_err()),
38 user_agent,
39 background_executor,
40 }
41 }
42
43 fn agent_for(&self, redirect_policy: RedirectPolicy, timeout: Duration) -> ureq::Agent {
44 let mut clients = self.clients.lock();
45 // in case our assumption of distinct options is wrong, we'll sporadically clean it out.
46 if clients.len() > 50 {
47 clients.clear()
48 }
49
50 clients
51 .entry((timeout, redirect_policy.clone()))
52 .or_insert_with(|| {
53 let mut builder = ureq::AgentBuilder::new()
54 .timeout_connect(Duration::from_secs(5))
55 .timeout_read(timeout)
56 .timeout_write(timeout)
57 .user_agent(&self.user_agent)
58 .tls_config(http_client::TLS_CONFIG.clone())
59 .redirects(match redirect_policy {
60 RedirectPolicy::NoFollow => 0,
61 RedirectPolicy::FollowLimit(limit) => limit,
62 RedirectPolicy::FollowAll => 100,
63 });
64 if let Some(proxy) = &self.proxy {
65 builder = builder.proxy(proxy.clone());
66 }
67 builder.build()
68 })
69 .clone()
70 }
71}
72impl HttpClient for UreqClient {
73 fn proxy(&self) -> Option<&Uri> {
74 self.proxy_url.as_ref()
75 }
76
77 fn send(
78 &self,
79 request: http::Request<AsyncBody>,
80 ) -> BoxFuture<'static, Result<http::Response<AsyncBody>, Error>> {
81 let agent = self.agent_for(
82 request
83 .extensions()
84 .get::<RedirectPolicy>()
85 .cloned()
86 .unwrap_or_default(),
87 request
88 .extensions()
89 .get::<http_client::ReadTimeout>()
90 .cloned()
91 .unwrap_or_default()
92 .0,
93 );
94 let mut req = agent.request(&request.method().as_ref(), &request.uri().to_string());
95 for (name, value) in request.headers().into_iter() {
96 req = req.set(name.as_str(), value.to_str().unwrap());
97 }
98 let body = request.into_body();
99 let executor = self.background_executor.clone();
100
101 self.background_executor
102 .spawn(async move {
103 let response = req.send(body)?;
104
105 let mut builder = http::Response::builder()
106 .status(response.status())
107 .version(http::Version::HTTP_11);
108 for name in response.headers_names() {
109 if let Some(value) = response.header(&name) {
110 builder = builder.header(name, value);
111 }
112 }
113
114 let body = AsyncBody::from_reader(UreqResponseReader::new(executor, response));
115 let http_response = builder.body(body)?;
116
117 Ok(http_response)
118 })
119 .boxed()
120 }
121}
122
/// Exposes a blocking `ureq` response body as an `AsyncRead`.
///
/// A background task sends body chunks (or a read error) over `receiver`.
/// `buffer` holds the chunk currently being handed out, and `idx` is how many
/// of its bytes have already been copied to callers.
struct UreqResponseReader {
    receiver: mpsc::Receiver<std::io::Result<Vec<u8>>>,
    buffer: Vec<u8>,
    idx: usize,
    // Handle to the task that feeds `receiver`. It is never awaited; it is held
    // so the task's lifetime is tied to this reader.
    _task: gpui::Task<()>,
}
129
130impl UreqResponseReader {
131 fn new(background_executor: gpui::BackgroundExecutor, response: ureq::Response) -> Self {
132 let (mut sender, receiver) = mpsc::channel(1);
133 let mut reader = response.into_reader();
134 let task = background_executor.spawn(async move {
135 let mut buffer = vec![0; 8192];
136 loop {
137 let n = match reader.read(&mut buffer) {
138 Ok(0) => break,
139 Ok(n) => n,
140 Err(e) => {
141 let _ = sender.send(Err(e)).await;
142 break;
143 }
144 };
145 let _ = sender.send(Ok(buffer[..n].to_vec())).await;
146 }
147 });
148
149 UreqResponseReader {
150 _task: task,
151 receiver,
152 buffer: Vec::new(),
153 idx: 0,
154 }
155 }
156}
157
158impl AsyncRead for UreqResponseReader {
159 fn poll_read(
160 mut self: Pin<&mut Self>,
161 cx: &mut std::task::Context<'_>,
162 buf: &mut [u8],
163 ) -> Poll<std::io::Result<usize>> {
164 if self.buffer.is_empty() {
165 match self.receiver.poll_next_unpin(cx) {
166 Poll::Ready(Some(Ok(data))) => self.buffer = data,
167 Poll::Ready(Some(Err(e))) => {
168 return Poll::Ready(Err(e));
169 }
170 Poll::Ready(None) => {
171 return Poll::Ready(Ok(0));
172 }
173 Poll::Pending => {
174 return Poll::Pending;
175 }
176 }
177 }
178 let n = std::cmp::min(buf.len(), self.buffer.len() - self.idx);
179 buf[..n].copy_from_slice(&self.buffer[self.idx..self.idx + n]);
180 self.idx += n;
181 if self.idx == self.buffer.len() {
182 self.buffer.clear();
183 self.idx = 0;
184 }
185 Poll::Ready(Ok(n))
186 }
187}