http_client.rs

  1mod async_body;
  2#[cfg(not(target_family = "wasm"))]
  3pub mod github;
  4#[cfg(not(target_family = "wasm"))]
  5pub mod github_download;
  6
  7pub use anyhow::{Result, anyhow};
  8pub use async_body::{AsyncBody, Inner, Json};
  9use derive_more::Deref;
 10use http::HeaderValue;
 11pub use http::{self, Method, Request, Response, StatusCode, Uri, request::Builder};
 12
 13use futures::future::BoxFuture;
 14use parking_lot::Mutex;
 15use serde::Serialize;
 16use std::sync::Arc;
 17#[cfg(feature = "test-support")]
 18use std::{any::type_name, fmt};
 19pub use url::{Host, Url};
 20
 21#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)]
 22pub enum RedirectPolicy {
 23    #[default]
 24    NoFollow,
 25    FollowLimit(u32),
 26    FollowAll,
 27}
 28pub struct FollowRedirects(pub bool);
 29
 30pub trait HttpRequestExt {
 31    /// Conditionally modify self with the given closure.
 32    fn when(self, condition: bool, then: impl FnOnce(Self) -> Self) -> Self
 33    where
 34        Self: Sized,
 35    {
 36        if condition { then(self) } else { self }
 37    }
 38
 39    /// Conditionally unwrap and modify self with the given closure, if the given option is Some.
 40    fn when_some<T>(self, option: Option<T>, then: impl FnOnce(Self, T) -> Self) -> Self
 41    where
 42        Self: Sized,
 43    {
 44        match option {
 45            Some(value) => then(self, value),
 46            None => self,
 47        }
 48    }
 49
 50    /// Whether or not to follow redirects
 51    fn follow_redirects(self, follow: RedirectPolicy) -> Self;
 52}
 53
 54impl HttpRequestExt for http::request::Builder {
 55    fn follow_redirects(self, follow: RedirectPolicy) -> Self {
 56        self.extension(follow)
 57    }
 58}
 59
 60pub trait HttpClient: 'static + Send + Sync {
 61    fn user_agent(&self) -> Option<&HeaderValue>;
 62
 63    fn proxy(&self) -> Option<&Url>;
 64
 65    fn send(
 66        &self,
 67        req: http::Request<AsyncBody>,
 68    ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>>;
 69
 70    fn get(
 71        &self,
 72        uri: &str,
 73        body: AsyncBody,
 74        follow_redirects: bool,
 75    ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> {
 76        let request = Builder::new()
 77            .uri(uri)
 78            .follow_redirects(if follow_redirects {
 79                RedirectPolicy::FollowAll
 80            } else {
 81                RedirectPolicy::NoFollow
 82            })
 83            .body(body);
 84
 85        match request {
 86            Ok(request) => self.send(request),
 87            Err(e) => Box::pin(async move { Err(e.into()) }),
 88        }
 89    }
 90
 91    fn post_json(
 92        &self,
 93        uri: &str,
 94        body: AsyncBody,
 95    ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> {
 96        let request = Builder::new()
 97            .uri(uri)
 98            .method(Method::POST)
 99            .header("Content-Type", "application/json")
100            .body(body);
101
102        match request {
103            Ok(request) => self.send(request),
104            Err(e) => Box::pin(async move { Err(e.into()) }),
105        }
106    }
107
108    #[cfg(feature = "test-support")]
109    fn as_fake(&self) -> &FakeHttpClient {
110        panic!("called as_fake on {}", type_name::<Self>())
111    }
112}
113
114/// An [`HttpClient`] that may have a proxy.
115#[derive(Deref)]
116pub struct HttpClientWithProxy {
117    #[deref]
118    client: Arc<dyn HttpClient>,
119    proxy: Option<Url>,
120}
121
122impl HttpClientWithProxy {
123    /// Returns a new [`HttpClientWithProxy`] with the given proxy URL.
124    pub fn new(client: Arc<dyn HttpClient>, proxy_url: Option<String>) -> Self {
125        let proxy_url = proxy_url
126            .and_then(|proxy| proxy.parse().ok())
127            .or_else(read_proxy_from_env);
128
129        Self::new_url(client, proxy_url)
130    }
131    pub fn new_url(client: Arc<dyn HttpClient>, proxy_url: Option<Url>) -> Self {
132        Self {
133            client,
134            proxy: proxy_url,
135        }
136    }
137}
138
139impl HttpClient for HttpClientWithProxy {
140    fn send(
141        &self,
142        req: Request<AsyncBody>,
143    ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> {
144        self.client.send(req)
145    }
146
147    fn user_agent(&self) -> Option<&HeaderValue> {
148        self.client.user_agent()
149    }
150
151    fn proxy(&self) -> Option<&Url> {
152        self.proxy.as_ref()
153    }
154
155    #[cfg(feature = "test-support")]
156    fn as_fake(&self) -> &FakeHttpClient {
157        self.client.as_fake()
158    }
159}
160
161/// An [`HttpClient`] that has a base URL.
162#[derive(Deref)]
163pub struct HttpClientWithUrl {
164    base_url: Mutex<String>,
165    #[deref]
166    client: HttpClientWithProxy,
167}
168
169impl HttpClientWithUrl {
170    /// Returns a new [`HttpClientWithUrl`] with the given base URL.
171    pub fn new(
172        client: Arc<dyn HttpClient>,
173        base_url: impl Into<String>,
174        proxy_url: Option<String>,
175    ) -> Self {
176        let client = HttpClientWithProxy::new(client, proxy_url);
177
178        Self {
179            base_url: Mutex::new(base_url.into()),
180            client,
181        }
182    }
183
184    pub fn new_url(
185        client: Arc<dyn HttpClient>,
186        base_url: impl Into<String>,
187        proxy_url: Option<Url>,
188    ) -> Self {
189        let client = HttpClientWithProxy::new_url(client, proxy_url);
190
191        Self {
192            base_url: Mutex::new(base_url.into()),
193            client,
194        }
195    }
196
197    /// Returns the base URL.
198    pub fn base_url(&self) -> String {
199        self.base_url.lock().clone()
200    }
201
202    /// Sets the base URL.
203    pub fn set_base_url(&self, base_url: impl Into<String>) {
204        let base_url = base_url.into();
205        *self.base_url.lock() = base_url;
206    }
207
208    /// Builds a URL using the given path.
209    pub fn build_url(&self, path: &str) -> String {
210        format!("{}{}", self.base_url(), path)
211    }
212
213    /// Builds a Zed API URL using the given path.
214    pub fn build_zed_api_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> {
215        let base_url = self.base_url();
216        let base_api_url = match base_url.as_ref() {
217            "https://zed.dev" => "https://api.zed.dev",
218            "https://staging.zed.dev" => "https://api-staging.zed.dev",
219            "http://localhost:3000" => "http://localhost:8080",
220            other => other,
221        };
222
223        Ok(Url::parse_with_params(
224            &format!("{}{}", base_api_url, path),
225            query,
226        )?)
227    }
228
229    /// Builds a Zed Cloud URL using the given path.
230    pub fn build_zed_cloud_url(&self, path: &str) -> Result<Url> {
231        let base_url = self.base_url();
232        let base_api_url = match base_url.as_ref() {
233            "https://zed.dev" => "https://cloud.zed.dev",
234            "https://staging.zed.dev" => "https://cloud.zed.dev",
235            "http://localhost:3000" => "http://localhost:8787",
236            other => other,
237        };
238
239        Ok(Url::parse(&format!("{}{}", base_api_url, path))?)
240    }
241
242    /// Builds a Zed Cloud URL using the given path and query params.
243    pub fn build_zed_cloud_url_with_query(&self, path: &str, query: impl Serialize) -> Result<Url> {
244        let base_url = self.base_url();
245        let base_api_url = match base_url.as_ref() {
246            "https://zed.dev" => "https://cloud.zed.dev",
247            "https://staging.zed.dev" => "https://cloud.zed.dev",
248            "http://localhost:3000" => "http://localhost:8787",
249            other => other,
250        };
251        let query = serde_urlencoded::to_string(&query)?;
252        Ok(Url::parse(&format!("{}{}?{}", base_api_url, path, query))?)
253    }
254
255    /// Builds a Zed LLM URL using the given path.
256    pub fn build_zed_llm_url(&self, path: &str, query: &[(&str, &str)]) -> Result<Url> {
257        let base_url = self.base_url();
258        let base_api_url = match base_url.as_ref() {
259            "https://zed.dev" => "https://cloud.zed.dev",
260            "https://staging.zed.dev" => "https://llm-staging.zed.dev",
261            "http://localhost:3000" => "http://localhost:8787",
262            other => other,
263        };
264
265        Ok(Url::parse_with_params(
266            &format!("{}{}", base_api_url, path),
267            query,
268        )?)
269    }
270}
271
272impl HttpClient for HttpClientWithUrl {
273    fn send(
274        &self,
275        req: Request<AsyncBody>,
276    ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> {
277        self.client.send(req)
278    }
279
280    fn user_agent(&self) -> Option<&HeaderValue> {
281        self.client.user_agent()
282    }
283
284    fn proxy(&self) -> Option<&Url> {
285        self.client.proxy.as_ref()
286    }
287
288    #[cfg(feature = "test-support")]
289    fn as_fake(&self) -> &FakeHttpClient {
290        self.client.as_fake()
291    }
292}
293
/// Escapes `&`, `<`, `>`, `"` and `'` so `input` can be embedded in HTML text
/// or a quoted attribute value. All other characters are copied unchanged.
fn html_escape(input: &str) -> String {
    input
        .chars()
        .fold(String::with_capacity(input.len()), |mut escaped, ch| {
            let entity = match ch {
                '&' => "&amp;",
                '<' => "&lt;",
                '>' => "&gt;",
                '"' => "&quot;",
                '\'' => "&#x27;",
                other => {
                    escaped.push(other);
                    return escaped;
                }
            };
            escaped.push_str(entity);
            escaped
        })
}
308
309/// Generate a styled HTML page for OAuth callback responses.
310///
311/// Returns a complete HTML document (no HTTP headers) with a centered card
312/// layout styled to match Zed's dark theme. The `title` is rendered as a
313/// heading and `message` as body text below it.
314///
315/// When `is_error` is true, a red X icon is shown instead of the green
316/// checkmark.
317pub fn oauth_callback_page(title: &str, message: &str, is_error: bool) -> String {
318    let title = html_escape(title);
319    let message = html_escape(message);
320    let (icon_bg, icon_svg) = if is_error {
321        (
322            "#f38ba8",
323            r#"<svg viewBox="0 0 24 24"><line x1="18" y1="6" x2="6" y2="18"/><line x1="6" y1="6" x2="18" y2="18"/></svg>"#,
324        )
325    } else {
326        (
327            "#a6e3a1",
328            r#"<svg viewBox="0 0 24 24"><polyline points="20 6 9 17 4 12"/></svg>"#,
329        )
330    };
331    format!(
332        r#"<!DOCTYPE html>
333<html lang="en">
334<head>
335<meta charset="utf-8">
336<meta name="viewport" content="width=device-width, initial-scale=1">
337<title>{title} — Zed</title>
338<style>
339  * {{ margin: 0; padding: 0; box-sizing: border-box; }}
340  body {{
341    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
342    background: #1e1e2e;
343    color: #cdd6f4;
344    display: flex;
345    align-items: center;
346    justify-content: center;
347    min-height: 100vh;
348    padding: 1rem;
349  }}
350  .card {{
351    background: #313244;
352    border-radius: 12px;
353    padding: 2.5rem;
354    max-width: 420px;
355    width: 100%;
356    text-align: center;
357    box-shadow: 0 4px 24px rgba(0, 0, 0, 0.3);
358  }}
359  .icon {{
360    width: 48px;
361    height: 48px;
362    margin: 0 auto 1.5rem;
363    background: {icon_bg};
364    border-radius: 50%;
365    display: flex;
366    align-items: center;
367    justify-content: center;
368  }}
369  .icon svg {{
370    width: 24px;
371    height: 24px;
372    stroke: #1e1e2e;
373    stroke-width: 3;
374    fill: none;
375  }}
376  h1 {{
377    font-size: 1.25rem;
378    font-weight: 600;
379    margin-bottom: 0.75rem;
380    color: #cdd6f4;
381  }}
382  p {{
383    font-size: 0.925rem;
384    line-height: 1.5;
385    color: #a6adc8;
386  }}
387  .brand {{
388    margin-top: 1.5rem;
389    font-size: 0.8rem;
390    color: #585b70;
391    letter-spacing: 0.05em;
392  }}
393</style>
394</head>
395<body>
396<div class="card">
397  <div class="icon">
398    {icon_svg}
399  </div>
400  <h1>{title}</h1>
401  <p>{message}</p>
402  <div class="brand">Zed</div>
403</div>
404</body>
405</html>"#,
406        title = title,
407        message = message,
408        icon_bg = icon_bg,
409        icon_svg = icon_svg,
410    )
411}
412
413pub fn read_proxy_from_env() -> Option<Url> {
414    const ENV_VARS: &[&str] = &[
415        "ALL_PROXY",
416        "all_proxy",
417        "HTTPS_PROXY",
418        "https_proxy",
419        "HTTP_PROXY",
420        "http_proxy",
421    ];
422
423    ENV_VARS
424        .iter()
425        .find_map(|var| std::env::var(var).ok())
426        .and_then(|env| env.parse().ok())
427}
428
/// Reads the proxy bypass list from `NO_PROXY`, falling back to `no_proxy`.
///
/// A variable that is set but empty or whitespace-only is treated as unset,
/// so an empty `NO_PROXY` does not hide a populated `no_proxy`. The value is
/// returned exactly as found in the environment.
pub fn read_no_proxy_from_env() -> Option<String> {
    const ENV_VARS: &[&str] = &["NO_PROXY", "no_proxy"];

    ENV_VARS
        .iter()
        .filter_map(|var| std::env::var(var).ok())
        .find(|value| !value.trim().is_empty())
}
434
/// An [`HttpClient`] that refuses to perform any request.
///
/// Every `send` resolves to a [`std::io::ErrorKind::PermissionDenied`] error.
#[derive(Debug, Default, Clone, Copy)]
pub struct BlockedHttpClient;

impl BlockedHttpClient {
    /// Creates a new [`BlockedHttpClient`]. Equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self
    }
}
442
443impl HttpClient for BlockedHttpClient {
444    fn send(
445        &self,
446        _req: Request<AsyncBody>,
447    ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> {
448        Box::pin(async {
449            Err(std::io::Error::new(
450                std::io::ErrorKind::PermissionDenied,
451                "BlockedHttpClient disallowed request",
452            )
453            .into())
454        })
455    }
456
457    fn user_agent(&self) -> Option<&HeaderValue> {
458        None
459    }
460
461    fn proxy(&self) -> Option<&Url> {
462        None
463    }
464
465    #[cfg(feature = "test-support")]
466    fn as_fake(&self) -> &FakeHttpClient {
467        panic!("called as_fake on {}", type_name::<Self>())
468    }
469}
470
/// Shared request handler that [`FakeHttpClient`] invokes for each request.
#[cfg(feature = "test-support")]
type FakeHttpHandler = Arc<
    dyn Fn(Request<AsyncBody>) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>>
        + 'static
        + Send
        + Sync,
>;
478
/// A test double for [`HttpClient`] whose responses come from a replaceable handler.
#[cfg(feature = "test-support")]
pub struct FakeHttpClient {
    /// The current handler. It is always `Some` outside of `replace_handler`.
    handler: Mutex<Option<FakeHttpHandler>>,
    /// Value returned by [`HttpClient::user_agent`].
    user_agent: HeaderValue,
}
484
#[cfg(feature = "test-support")]
impl FakeHttpClient {
    /// Builds a fake client that answers every request with `handler`.
    ///
    /// The fake is wrapped in an [`HttpClientWithUrl`] with base URL
    /// `http://test.example` and no proxy.
    pub fn create<Fut, F>(handler: F) -> Arc<HttpClientWithUrl>
    where
        Fut: futures::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send + 'static,
        F: Fn(Request<AsyncBody>) -> Fut + Send + Sync + 'static,
    {
        let boxed_handler: FakeHttpHandler = Arc::new(move |req| Box::pin(handler(req)));
        let fake = Self {
            handler: Mutex::new(Some(boxed_handler)),
            user_agent: HeaderValue::from_static(type_name::<Self>()),
        };

        Arc::new(HttpClientWithUrl {
            base_url: Mutex::new("http://test.example".into()),
            client: HttpClientWithProxy {
                client: Arc::new(fake),
                proxy: None,
            },
        })
    }

    /// A fake client that answers every request with an empty `404`.
    pub fn with_404_response() -> Arc<HttpClientWithUrl> {
        log::warn!("Using fake HTTP client with 404 response");
        Self::with_fixed_status(404)
    }

    /// A fake client that answers every request with an empty `200`.
    pub fn with_200_response() -> Arc<HttpClientWithUrl> {
        log::warn!("Using fake HTTP client with 200 response");
        Self::with_fixed_status(200)
    }

    /// Shared builder for the fixed-status constructors: an empty body with `status`.
    fn with_fixed_status(status: u16) -> Arc<HttpClientWithUrl> {
        Self::create(move |_| async move {
            Ok(Response::builder()
                .status(status)
                .body(Default::default())
                .unwrap())
        })
    }

    /// Wraps the current handler. `new_handler` receives the previous handler,
    /// so it can delegate, alongside each request.
    pub fn replace_handler<Fut, F>(&self, new_handler: F)
    where
        Fut: futures::Future<Output = anyhow::Result<Response<AsyncBody>>> + Send + 'static,
        F: Fn(FakeHttpHandler, Request<AsyncBody>) -> Fut + Send + Sync + 'static,
    {
        let mut slot = self.handler.lock();
        let previous = slot.take().unwrap();
        let chained: FakeHttpHandler =
            Arc::new(move |req| Box::pin(new_handler(previous.clone(), req)));
        *slot = Some(chained);
    }
}
536
/// Prints just the type name; the handler closure has no useful debug form.
#[cfg(feature = "test-support")]
impl fmt::Debug for FakeHttpClient {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("FakeHttpClient")
    }
}
543
#[cfg(feature = "test-support")]
impl HttpClient for FakeHttpClient {
    /// Dispatches `req` to the current handler.
    ///
    /// # Panics
    ///
    /// Panics if no handler is installed. This cannot happen through the public API.
    fn send(
        &self,
        req: Request<AsyncBody>,
    ) -> BoxFuture<'static, anyhow::Result<Response<AsyncBody>>> {
        // Clone the handler out and drop the guard before invoking it. The mutex
        // is not reentrant, so a handler that synchronously calls back into this
        // client (e.g. `send` or `replace_handler`) would otherwise deadlock.
        let handler = self
            .handler
            .lock()
            .as_ref()
            .map(Arc::clone)
            .expect("FakeHttpClient handler is always installed");
        handler(req)
    }

    fn user_agent(&self) -> Option<&HeaderValue> {
        Some(&self.user_agent)
    }

    fn proxy(&self) -> Option<&Url> {
        None
    }

    fn as_fake(&self) -> &FakeHttpClient {
        self
    }
}
565
566// ---------------------------------------------------------------------------
567// Shared OAuth callback server (non-wasm only)
568// ---------------------------------------------------------------------------
569
570#[cfg(not(target_family = "wasm"))]
571mod oauth_callback_server {
572    use super::*;
573    use anyhow::Context as _;
574    use std::str::FromStr;
575    use std::time::Duration;
576
577    /// Parsed OAuth callback parameters from the authorization server redirect.
578    pub struct OAuthCallbackParams {
579        pub code: String,
580        pub state: String,
581    }
582
583    impl OAuthCallbackParams {
584        /// Parse the query string from a callback URL like
585        /// `http://127.0.0.1:<port>/callback?code=...&state=...`.
586        pub fn parse_query(query: &str) -> Result<Self> {
587            let mut code: Option<String> = None;
588            let mut state: Option<String> = None;
589            let mut error: Option<String> = None;
590            let mut error_description: Option<String> = None;
591
592            for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
593                match key.as_ref() {
594                    "code" => {
595                        if !value.is_empty() {
596                            code = Some(value.into_owned());
597                        }
598                    }
599                    "state" => {
600                        if !value.is_empty() {
601                            state = Some(value.into_owned());
602                        }
603                    }
604                    "error" => {
605                        if !value.is_empty() {
606                            error = Some(value.into_owned());
607                        }
608                    }
609                    "error_description" => {
610                        if !value.is_empty() {
611                            error_description = Some(value.into_owned());
612                        }
613                    }
614                    _ => {}
615                }
616            }
617
618            if let Some(error_code) = error {
619                anyhow::bail!(
620                    "OAuth authorization failed: {} ({})",
621                    error_code,
622                    error_description.as_deref().unwrap_or("no description")
623                );
624            }
625
626            let code = code.ok_or_else(|| anyhow!("missing 'code' parameter in OAuth callback"))?;
627            let state =
628                state.ok_or_else(|| anyhow!("missing 'state' parameter in OAuth callback"))?;
629
630            Ok(Self { code, state })
631        }
632    }
633
634    /// How long to wait for the browser to complete the OAuth flow before giving
635    /// up and releasing the loopback port.
636    const OAUTH_CALLBACK_TIMEOUT: Duration = Duration::from_secs(2 * 60);
637
638    /// Start a loopback HTTP server to receive the OAuth authorization callback.
639    ///
640    /// Binds to an ephemeral loopback port. Returns `(redirect_uri, callback_future)`.
641    /// The caller should use the redirect URI in the authorization request, open
642    /// the browser, then await the future to receive the callback.
643    pub fn start_oauth_callback_server() -> Result<(
644        String,
645        futures::channel::oneshot::Receiver<Result<OAuthCallbackParams>>,
646    )> {
647        let server = tiny_http::Server::http("127.0.0.1:0").map_err(|e| {
648            anyhow!(e).context("Failed to bind loopback listener for OAuth callback")
649        })?;
650        let port = server
651            .server_addr()
652            .to_ip()
653            .ok_or_else(|| anyhow!("server not bound to a TCP address"))?
654            .port();
655
656        let redirect_uri = format!("http://127.0.0.1:{}/callback", port);
657
658        let (tx, rx) = futures::channel::oneshot::channel();
659
660        std::thread::spawn(move || {
661            let deadline = std::time::Instant::now() + OAUTH_CALLBACK_TIMEOUT;
662
663            loop {
664                if tx.is_canceled() {
665                    return;
666                }
667                let remaining = deadline.saturating_duration_since(std::time::Instant::now());
668                if remaining.is_zero() {
669                    return;
670                }
671
672                let timeout = remaining.min(Duration::from_millis(500));
673                let Some(request) = (match server.recv_timeout(timeout) {
674                    Ok(req) => req,
675                    Err(_) => {
676                        let _ = tx.send(Err(anyhow!("OAuth callback server I/O error")));
677                        return;
678                    }
679                }) else {
680                    continue;
681                };
682
683                let result = handle_oauth_callback_request(&request);
684
685                let (status_code, body) = match &result {
686                    Ok(_) => (
687                        200,
688                        oauth_callback_page(
689                            "Authorization Successful",
690                            "You can close this tab and return to Zed.",
691                            false,
692                        ),
693                    ),
694                    Err(err) => {
695                        log::error!("OAuth callback error: {}", err);
696                        (
697                            400,
698                            oauth_callback_page(
699                                "Authorization Failed",
700                                "Something went wrong. Please try again from Zed.",
701                                true,
702                            ),
703                        )
704                    }
705                };
706
707                let response = tiny_http::Response::from_string(body)
708                    .with_status_code(status_code)
709                    .with_header(
710                        tiny_http::Header::from_str("Content-Type: text/html")
711                            .expect("failed to construct response header"),
712                    )
713                    .with_header(
714                        tiny_http::Header::from_str("Keep-Alive: timeout=0,max=0")
715                            .expect("failed to construct response header"),
716                    );
717                if let Err(err) = request.respond(response) {
718                    log::error!("Failed to send OAuth callback response: {}", err);
719                }
720
721                let _ = tx.send(result);
722                return;
723            }
724        });
725
726        Ok((redirect_uri, rx))
727    }
728
729    fn handle_oauth_callback_request(request: &tiny_http::Request) -> Result<OAuthCallbackParams> {
730        let url = Url::parse(&format!("http://localhost{}", request.url()))
731            .context("malformed callback request URL")?;
732
733        if url.path() != "/callback" {
734            anyhow::bail!("unexpected path in OAuth callback: {}", url.path());
735        }
736
737        let query = url
738            .query()
739            .ok_or_else(|| anyhow!("OAuth callback has no query string"))?;
740        OAuthCallbackParams::parse_query(query)
741    }
742}
743
744#[cfg(not(target_family = "wasm"))]
745pub use oauth_callback_server::{OAuthCallbackParams, start_oauth_callback_server};