socks.rs

//! SOCKS proxy support: opens the RPC connection either directly or tunnelled through a SOCKS4/SOCKS5 proxy.
 2use anyhow::{Result, anyhow};
 3use futures::io::{AsyncRead, AsyncWrite};
 4use http_client::Uri;
 5use tokio_socks::{
 6    io::Compat,
 7    tcp::{Socks4Stream, Socks5Stream},
 8};
 9
10pub(crate) async fn connect_socks_proxy_stream(
11    proxy: Option<&Uri>,
12    rpc_host: (&str, u16),
13) -> Result<Box<dyn AsyncReadWrite>> {
14    let stream = match parse_socks_proxy(proxy) {
15        Some((socks_proxy, SocksVersion::V4)) => {
16            let stream = Socks4Stream::connect_with_socket(
17                Compat::new(smol::net::TcpStream::connect(socks_proxy).await?),
18                rpc_host,
19            )
20            .await
21            .map_err(|err| anyhow!("error connecting to socks {}", err))?;
22            Box::new(stream) as Box<dyn AsyncReadWrite>
23        }
24        Some((socks_proxy, SocksVersion::V5)) => Box::new(
25            Socks5Stream::connect_with_socket(
26                Compat::new(smol::net::TcpStream::connect(socks_proxy).await?),
27                rpc_host,
28            )
29            .await
30            .map_err(|err| anyhow!("error connecting to socks {}", err))?,
31        ) as Box<dyn AsyncReadWrite>,
32        None => Box::new(smol::net::TcpStream::connect(rpc_host).await?) as Box<dyn AsyncReadWrite>,
33    };
34    Ok(stream)
35}
36
37fn parse_socks_proxy(proxy: Option<&Uri>) -> Option<((String, u16), SocksVersion)> {
38    let proxy_uri = proxy?;
39    let scheme = proxy_uri.scheme_str()?;
40    let socks_version = if scheme.starts_with("socks4") {
41        // socks4
42        SocksVersion::V4
43    } else if scheme.starts_with("socks") {
44        // socks, socks5
45        SocksVersion::V5
46    } else {
47        return None;
48    };
49    if let (Some(host), Some(port)) = (proxy_uri.host(), proxy_uri.port_u16()) {
50        Some(((host.to_string(), port), socks_version))
51    } else {
52        None
53    }
54}
55
56// private helper structs and traits
57
/// SOCKS protocol version used for the handshake with the proxy.
///
/// Chosen by `parse_socks_proxy` from the proxy URI's scheme.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SocksVersion {
    /// SOCKS4 handshake, performed with `tokio_socks::tcp::Socks4Stream`.
    V4,
    /// SOCKS5 handshake, performed with `tokio_socks::tcp::Socks5Stream`.
    V5,
}
62
63pub(crate) trait AsyncReadWrite: AsyncRead + AsyncWrite + Unpin + Send + 'static {}
64impl<T: AsyncRead + AsyncWrite + Unpin + Send + 'static> AsyncReadWrite for T {}