dns.rs

  1use core::{fmt, net::SocketAddr};
  2#[cfg(feature = "dns")]
  3use futures::{future::select_ok, FutureExt};
  4#[cfg(feature = "dns")]
  5use hickory_resolver::{config::LookupIpStrategy, IntoName, TokioResolver};
  6#[cfg(feature = "dns")]
  7use log::debug;
  8use tokio::net::TcpStream;
  9
 10use crate::Error;
 11
 12/// XMPP server connection configuration
 13#[derive(Clone, Debug)]
 14pub enum DnsConfig {
 15    /// Use SRV record to find server host
 16    #[cfg(feature = "dns")]
 17    UseSrv {
 18        /// Hostname to resolve
 19        host: String,
 20        /// TXT field eg. _xmpp-client._tcp
 21        srv: String,
 22        /// When SRV resolution fails what port to use
 23        fallback_port: u16,
 24    },
 25
 26    /// Manually define server host and port
 27    #[allow(unused)]
 28    #[cfg(feature = "dns")]
 29    NoSrv {
 30        /// Server host name
 31        host: String,
 32        /// Server port
 33        port: u16,
 34    },
 35
 36    /// Manually define IP: port (TODO: socket)
 37    #[allow(unused)]
 38    Addr {
 39        /// IP:port
 40        addr: String,
 41    },
 42}
 43
 44impl fmt::Display for DnsConfig {
 45    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
 46        match self {
 47            #[cfg(feature = "dns")]
 48            Self::UseSrv { host, .. } => write!(f, "{}", host),
 49            #[cfg(feature = "dns")]
 50            Self::NoSrv { host, port } => write!(f, "{}:{}", host, port),
 51            Self::Addr { addr } => write!(f, "{}", addr),
 52        }
 53    }
 54}
 55
 56impl DnsConfig {
 57    /// Constructor for DnsConfig::UseSrv variant
 58    #[cfg(feature = "dns")]
 59    pub fn srv(host: &str, srv: &str, fallback_port: u16) -> Self {
 60        Self::UseSrv {
 61            host: host.to_string(),
 62            srv: srv.to_string(),
 63            fallback_port,
 64        }
 65    }
 66
 67    /// Constructor for the default SRV resolution strategy for clients (StartTLS)
 68    #[cfg(feature = "dns")]
 69    pub fn srv_default_client(host: &str) -> Self {
 70        Self::UseSrv {
 71            host: host.to_string(),
 72            srv: "_xmpp-client._tcp".to_string(),
 73            fallback_port: 5222,
 74        }
 75    }
 76
 77    /// Constructor for direct TLS connections using RFC 7590 _xmpps-client._tcp
 78    #[cfg(feature = "dns")]
 79    pub fn srv_xmpps(host: &str) -> Self {
 80        Self::UseSrv {
 81            host: host.to_string(),
 82            srv: "_xmpps-client._tcp".to_string(),
 83            fallback_port: 5223,
 84        }
 85    }
 86
 87    /// Constructor for DnsConfig::NoSrv variant
 88    #[cfg(feature = "dns")]
 89    pub fn no_srv(host: &str, port: u16) -> Self {
 90        Self::NoSrv {
 91            host: host.to_string(),
 92            port,
 93        }
 94    }
 95
 96    /// Constructor for DnsConfig::Addr variant
 97    pub fn addr(addr: &str) -> Self {
 98        Self::Addr {
 99            addr: addr.to_string(),
100        }
101    }
102
103    /// Try resolve the DnsConfig to a TcpStream
104    pub async fn resolve(&self) -> Result<TcpStream, Error> {
105        match self {
106            #[cfg(feature = "dns")]
107            Self::UseSrv {
108                host,
109                srv,
110                fallback_port,
111            } => Self::resolve_srv(host, srv, *fallback_port).await,
112            #[cfg(feature = "dns")]
113            Self::NoSrv { host, port } => Self::resolve_no_srv(host, *port).await,
114            Self::Addr { addr } => {
115                // TODO: Unix domain socket
116                let addr: SocketAddr = addr.parse()?;
117                return Ok(TcpStream::connect(&SocketAddr::new(addr.ip(), addr.port())).await?);
118            }
119        }
120    }
121
122    #[cfg(feature = "dns")]
123    async fn resolve_srv(host: &str, srv: &str, fallback_port: u16) -> Result<TcpStream, Error> {
124        let ascii_domain = idna::domain_to_ascii(host)?;
125
126        if let Ok(ip) = ascii_domain.parse() {
127            debug!("Attempting connection to {ip}:{fallback_port}");
128            return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
129        }
130
131        let (_config, options) = hickory_resolver::system_conf::read_system_conf()?;
132        let resolver = TokioResolver::builder_tokio()?
133            .with_options(options)
134            .build();
135
136        let srv_domain = format!("{}.{}.", srv, ascii_domain).into_name()?;
137        let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
138
139        match srv_records {
140            Some(lookup) => {
141                // TODO: sort lookup records by priority/weight
142                for srv in lookup.iter() {
143                    debug!("Attempting connection to {srv_domain} {srv}");
144                    if let Ok(stream) =
145                        Self::resolve_no_srv(&srv.target().to_ascii(), srv.port()).await
146                    {
147                        return Ok(stream);
148                    }
149                }
150                Err(Error::Disconnected)
151            }
152            None => {
153                // SRV lookup error, retry with hostname
154                debug!("Attempting connection to {host}:{fallback_port}");
155                Self::resolve_no_srv(host, fallback_port).await
156            }
157        }
158    }
159
160    #[cfg(feature = "dns")]
161    async fn resolve_no_srv(host: &str, port: u16) -> Result<TcpStream, Error> {
162        let ascii_domain = idna::domain_to_ascii(host)?;
163
164        if let Ok(ip) = ascii_domain.parse() {
165            return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
166        }
167
168        let (_config, mut options) = hickory_resolver::system_conf::read_system_conf()?;
169        options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
170        let resolver = TokioResolver::builder_tokio()?
171            .with_options(options)
172            .build();
173
174        let ips = resolver.lookup_ip(ascii_domain).await?;
175
176        // Happy Eyeballs: connect to all records in parallel, return the
177        // first to succeed
178        select_ok(
179            ips.into_iter()
180                .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
181        )
182        .await
183        .map(|(result, _)| result)
184        .map_err(|_| Error::Disconnected)
185    }
186}