1use anyhow::{Context, Result};
2use base64::Engine;
3use httparse::{EMPTY_HEADER, Response};
4use tokio::{
5 io::{AsyncBufReadExt, AsyncWriteExt, BufStream},
6 net::TcpStream,
7};
8#[cfg(any(target_os = "windows", target_os = "macos"))]
9use tokio_native_tls::{TlsConnector, native_tls};
10#[cfg(not(any(target_os = "windows", target_os = "macos")))]
11use tokio_rustls::TlsConnector;
12use url::Url;
13
14use super::AsyncReadWrite;
15
16pub(super) enum HttpProxyType<'t> {
17 HTTP(Option<HttpProxyAuthorization<'t>>),
18 HTTPS(Option<HttpProxyAuthorization<'t>>),
19}
20
21pub(super) struct HttpProxyAuthorization<'t> {
22 username: &'t str,
23 password: &'t str,
24}
25
26pub(super) fn parse_http_proxy<'t>(scheme: &str, proxy: &'t Url) -> HttpProxyType<'t> {
27 let auth = proxy.password().map(|password| HttpProxyAuthorization {
28 username: proxy.username(),
29 password,
30 });
31 if scheme.starts_with("https") {
32 HttpProxyType::HTTPS(auth)
33 } else {
34 HttpProxyType::HTTP(auth)
35 }
36}
37
38pub(crate) async fn connect_http_proxy_stream(
39 stream: TcpStream,
40 http_proxy: HttpProxyType<'_>,
41 rpc_host: (&str, u16),
42 proxy_domain: &str,
43) -> Result<Box<dyn AsyncReadWrite>> {
44 match http_proxy {
45 HttpProxyType::HTTP(auth) => http_connect(stream, rpc_host, auth).await,
46 HttpProxyType::HTTPS(auth) => https_connect(stream, rpc_host, auth, proxy_domain).await,
47 }
48 .context("error connecting to http/https proxy")
49}
50
51async fn http_connect<T>(
52 stream: T,
53 target: (&str, u16),
54 auth: Option<HttpProxyAuthorization<'_>>,
55) -> Result<Box<dyn AsyncReadWrite>>
56where
57 T: AsyncReadWrite,
58{
59 let mut stream = BufStream::new(stream);
60 let request = make_request(target, auth);
61 stream.write_all(request.as_bytes()).await?;
62 stream.flush().await?;
63 check_response(&mut stream).await?;
64 Ok(Box::new(stream))
65}
66
/// Wraps `stream` in TLS via the platform's native TLS implementation
/// (Windows and macOS builds) and then performs the `CONNECT` handshake
/// through the encrypted channel.
///
/// `proxy_domain` is the name handed to the TLS connector for the handshake.
///
/// # Errors
///
/// Fails if the native connector cannot be created, if the TLS handshake
/// fails, or if the subsequent `CONNECT` handshake fails.
#[cfg(any(target_os = "windows", target_os = "macos"))]
async fn https_connect<T>(
    stream: T,
    target: (&str, u16),
    auth: Option<HttpProxyAuthorization<'_>>,
    proxy_domain: &str,
) -> Result<Box<dyn AsyncReadWrite>>
where
    T: AsyncReadWrite,
{
    let native_connector = native_tls::TlsConnector::new()?;
    let connector = TlsConnector::from(native_connector);
    let encrypted = connector.connect(proxy_domain, stream).await?;
    http_connect(encrypted, target, auth).await
}
81
82#[cfg(not(any(target_os = "windows", target_os = "macos")))]
83async fn https_connect<T>(
84 stream: T,
85 target: (&str, u16),
86 auth: Option<HttpProxyAuthorization<'_>>,
87 proxy_domain: &str,
88) -> Result<Box<dyn AsyncReadWrite>>
89where
90 T: AsyncReadWrite,
91{
92 let proxy_domain = rustls_pki_types::ServerName::try_from(proxy_domain)
93 .context("Address resolution failed")?
94 .to_owned();
95 let tls_connector = TlsConnector::from(std::sync::Arc::new(http_client_tls::tls_config()));
96 let stream = tls_connector.connect(proxy_domain, stream).await?;
97 http_connect(stream, target, auth).await
98}
99
100fn make_request(target: (&str, u16), auth: Option<HttpProxyAuthorization<'_>>) -> String {
101 let (host, port) = target;
102 let mut request = format!(
103 "CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\nProxy-Connection: Keep-Alive\r\n"
104 );
105 if let Some(HttpProxyAuthorization { username, password }) = auth {
106 let auth =
107 base64::prelude::BASE64_STANDARD.encode(format!("{username}:{password}").as_bytes());
108 let auth = format!("Proxy-Authorization: Basic {auth}\r\n");
109 request.push_str(&auth);
110 }
111 request.push_str("\r\n");
112 request
113}
114
115async fn check_response<T>(stream: &mut BufStream<T>) -> Result<()>
116where
117 T: AsyncReadWrite,
118{
119 let response = recv_response(stream).await?;
120 let mut dummy_headers = [EMPTY_HEADER; MAX_RESPONSE_HEADERS];
121 let mut parser = Response::new(&mut dummy_headers);
122 parser.parse(response.as_bytes())?;
123
124 match parser.code {
125 Some(code) => {
126 if code == 200 {
127 Ok(())
128 } else {
129 Err(anyhow::anyhow!(
130 "Proxy connection failed with HTTP code: {code}"
131 ))
132 }
133 }
134 None => Err(anyhow::anyhow!(
135 "Proxy connection failed with no HTTP code: {}",
136 parser.reason.unwrap_or("Unknown reason")
137 )),
138 }
139}
140
/// Upper bound, in bytes, on the proxy response accumulated while waiting for
/// the end of the header section; reading is aborted once it is exceeded.
const MAX_RESPONSE_HEADER_LENGTH: usize = 4 * 1024;

/// Number of header slots handed to the response parser.
const MAX_RESPONSE_HEADERS: usize = 16;
143
144async fn recv_response<T>(stream: &mut BufStream<T>) -> Result<String>
145where
146 T: AsyncReadWrite,
147{
148 let mut response = String::new();
149 loop {
150 if stream.read_line(&mut response).await? == 0 {
151 return Err(anyhow::anyhow!("End of stream"));
152 }
153
154 if MAX_RESPONSE_HEADER_LENGTH < response.len() {
155 return Err(anyhow::anyhow!("Maximum response header length exceeded"));
156 }
157
158 if response.ends_with("\r\n\r\n") {
159 return Ok(response);
160 }
161 }
162}
163
#[cfg(test)]
mod tests {
    use url::Url;

    use super::{HttpProxyAuthorization, HttpProxyType, parse_http_proxy};

    /// A plain `http` proxy URL without userinfo yields `HTTP` with no credentials.
    #[test]
    fn test_parse_http_proxy() {
        let proxy = Url::parse("http://proxy.example.com:1080").unwrap();
        let scheme = proxy.scheme();

        let version = parse_http_proxy(scheme, &proxy);
        assert!(matches!(version, HttpProxyType::HTTP(None)));
    }

    /// Username and password in an `http` proxy URL are carried over verbatim.
    #[test]
    fn test_parse_http_proxy_with_auth() {
        let proxy = Url::parse("http://username:password@proxy.example.com:1080").unwrap();
        let scheme = proxy.scheme();

        let version = parse_http_proxy(scheme, &proxy);
        assert!(matches!(
            version,
            HttpProxyType::HTTP(Some(HttpProxyAuthorization {
                username: "username",
                password: "password"
            }))
        ));
    }

    /// An `https` proxy URL without userinfo yields `HTTPS` with no credentials.
    #[test]
    fn test_parse_https_proxy() {
        let proxy = Url::parse("https://proxy.example.com:1080").unwrap();
        let scheme = proxy.scheme();

        let version = parse_http_proxy(scheme, &proxy);
        assert!(matches!(version, HttpProxyType::HTTPS(None)));
    }

    /// Username and password in an `https` proxy URL are carried over verbatim.
    #[test]
    fn test_parse_https_proxy_with_auth() {
        let proxy = Url::parse("https://username:password@proxy.example.com:1080").unwrap();
        let scheme = proxy.scheme();

        let version = parse_http_proxy(scheme, &proxy);
        assert!(matches!(
            version,
            HttpProxyType::HTTPS(Some(HttpProxyAuthorization {
                username: "username",
                password: "password"
            }))
        ));
    }
}