http_proxy.rs

  1use anyhow::{Context, Result};
  2use base64::Engine;
  3use httparse::{EMPTY_HEADER, Response};
  4use tokio::{
  5    io::{AsyncBufReadExt, AsyncWriteExt, BufStream},
  6    net::TcpStream,
  7};
  8#[cfg(any(target_os = "windows", target_os = "macos"))]
  9use tokio_native_tls::{TlsConnector, native_tls};
 10#[cfg(not(any(target_os = "windows", target_os = "macos")))]
 11use tokio_rustls::TlsConnector;
 12use url::Url;
 13
 14use super::AsyncReadWrite;
 15
 16pub(super) enum HttpProxyType<'t> {
 17    HTTP(Option<HttpProxyAuthorization<'t>>),
 18    HTTPS(Option<HttpProxyAuthorization<'t>>),
 19}
 20
 21pub(super) struct HttpProxyAuthorization<'t> {
 22    username: &'t str,
 23    password: &'t str,
 24}
 25
 26pub(super) fn parse_http_proxy<'t>(scheme: &str, proxy: &'t Url) -> HttpProxyType<'t> {
 27    let auth = proxy.password().map(|password| HttpProxyAuthorization {
 28        username: proxy.username(),
 29        password,
 30    });
 31    if scheme.starts_with("https") {
 32        HttpProxyType::HTTPS(auth)
 33    } else {
 34        HttpProxyType::HTTP(auth)
 35    }
 36}
 37
 38pub(crate) async fn connect_http_proxy_stream(
 39    stream: TcpStream,
 40    http_proxy: HttpProxyType<'_>,
 41    rpc_host: (&str, u16),
 42    proxy_domain: &str,
 43) -> Result<Box<dyn AsyncReadWrite>> {
 44    match http_proxy {
 45        HttpProxyType::HTTP(auth) => http_connect(stream, rpc_host, auth).await,
 46        HttpProxyType::HTTPS(auth) => https_connect(stream, rpc_host, auth, proxy_domain).await,
 47    }
 48    .context("error connecting to http/https proxy")
 49}
 50
 51async fn http_connect<T>(
 52    stream: T,
 53    target: (&str, u16),
 54    auth: Option<HttpProxyAuthorization<'_>>,
 55) -> Result<Box<dyn AsyncReadWrite>>
 56where
 57    T: AsyncReadWrite,
 58{
 59    let mut stream = BufStream::new(stream);
 60    let request = make_request(target, auth);
 61    stream.write_all(request.as_bytes()).await?;
 62    stream.flush().await?;
 63    check_response(&mut stream).await?;
 64    Ok(Box::new(stream))
 65}
 66
 67#[cfg(any(target_os = "windows", target_os = "macos"))]
 68async fn https_connect<T>(
 69    stream: T,
 70    target: (&str, u16),
 71    auth: Option<HttpProxyAuthorization<'_>>,
 72    proxy_domain: &str,
 73) -> Result<Box<dyn AsyncReadWrite>>
 74where
 75    T: AsyncReadWrite,
 76{
 77    let tls_connector = TlsConnector::from(native_tls::TlsConnector::new()?);
 78    let stream = tls_connector.connect(proxy_domain, stream).await?;
 79    http_connect(stream, target, auth).await
 80}
 81
 82#[cfg(not(any(target_os = "windows", target_os = "macos")))]
 83async fn https_connect<T>(
 84    stream: T,
 85    target: (&str, u16),
 86    auth: Option<HttpProxyAuthorization<'_>>,
 87    proxy_domain: &str,
 88) -> Result<Box<dyn AsyncReadWrite>>
 89where
 90    T: AsyncReadWrite,
 91{
 92    let proxy_domain = rustls_pki_types::ServerName::try_from(proxy_domain)
 93        .context("Address resolution failed")?
 94        .to_owned();
 95    let tls_connector = TlsConnector::from(std::sync::Arc::new(http_client_tls::tls_config()));
 96    let stream = tls_connector.connect(proxy_domain, stream).await?;
 97    http_connect(stream, target, auth).await
 98}
 99
100fn make_request(target: (&str, u16), auth: Option<HttpProxyAuthorization<'_>>) -> String {
101    let (host, port) = target;
102    let mut request = format!(
103        "CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\nProxy-Connection: Keep-Alive\r\n"
104    );
105    if let Some(HttpProxyAuthorization { username, password }) = auth {
106        let auth =
107            base64::prelude::BASE64_STANDARD.encode(format!("{username}:{password}").as_bytes());
108        let auth = format!("Proxy-Authorization: Basic {auth}\r\n");
109        request.push_str(&auth);
110    }
111    request.push_str("\r\n");
112    request
113}
114
115async fn check_response<T>(stream: &mut BufStream<T>) -> Result<()>
116where
117    T: AsyncReadWrite,
118{
119    let response = recv_response(stream).await?;
120    let mut dummy_headers = [EMPTY_HEADER; MAX_RESPONSE_HEADERS];
121    let mut parser = Response::new(&mut dummy_headers);
122    parser.parse(response.as_bytes())?;
123
124    match parser.code {
125        Some(code) => {
126            if code == 200 {
127                Ok(())
128            } else {
129                Err(anyhow::anyhow!(
130                    "Proxy connection failed with HTTP code: {code}"
131                ))
132            }
133        }
134        None => Err(anyhow::anyhow!(
135            "Proxy connection failed with no HTTP code: {}",
136            parser.reason.unwrap_or("Unknown reason")
137        )),
138    }
139}
140
/// Largest proxy response header, in bytes, that `recv_response` will accept.
const MAX_RESPONSE_HEADER_LENGTH: usize = 4 * 1024;
/// Number of header slots given to the `httparse` response parser in `check_response`.
const MAX_RESPONSE_HEADERS: usize = 16;
143
144async fn recv_response<T>(stream: &mut BufStream<T>) -> Result<String>
145where
146    T: AsyncReadWrite,
147{
148    let mut response = String::new();
149    loop {
150        if stream.read_line(&mut response).await? == 0 {
151            return Err(anyhow::anyhow!("End of stream"));
152        }
153
154        if MAX_RESPONSE_HEADER_LENGTH < response.len() {
155            return Err(anyhow::anyhow!("Maximum response header length exceeded"));
156        }
157
158        if response.ends_with("\r\n\r\n") {
159            return Ok(response);
160        }
161    }
162}
163
#[cfg(test)]
mod tests {
    use url::Url;

    use super::{HttpProxyAuthorization, HttpProxyType, parse_http_proxy};

    /// A proxy URL without userinfo yields a plain HTTP proxy with no credentials.
    #[test]
    fn test_parse_http_proxy() {
        let url = Url::parse("http://proxy.example.com:1080").unwrap();
        let parsed = parse_http_proxy(url.scheme(), &url);
        assert!(matches!(parsed, HttpProxyType::HTTP(None)));
    }

    /// A username and password embedded in the URL are carried through as credentials.
    #[test]
    fn test_parse_http_proxy_with_auth() {
        let url = Url::parse("http://username:password@proxy.example.com:1080").unwrap();
        let HttpProxyType::HTTP(Some(HttpProxyAuthorization { username, password })) =
            parse_http_proxy(url.scheme(), &url)
        else {
            panic!("expected an HTTP proxy carrying credentials");
        };
        assert_eq!(username, "username");
        assert_eq!(password, "password");
    }
}