Add `dns` feature for DNS stuff (not just in starttls)

xmppftw created

Change summary

tokio-xmpp/Cargo.toml                     |  4 
tokio-xmpp/src/client/async_client.rs     |  3 
tokio-xmpp/src/client/simple_client.rs    |  3 
tokio-xmpp/src/connect.rs                 | 85 +++++++++++++++++++++++++
tokio-xmpp/src/error.rs                   | 43 ++++++++++++
tokio-xmpp/src/starttls/error.rs          | 11 ---
tokio-xmpp/src/starttls/happy_eyeballs.rs | 75 ----------------------
tokio-xmpp/src/starttls/mod.rs            | 17 ++--
8 files changed, 145 insertions(+), 96 deletions(-)

Detailed changes

tokio-xmpp/Cargo.toml 🔗

@@ -41,7 +41,7 @@ tokio-xmpp = { path = ".", features = ["insecure-tcp"]}
 
 [features]
 default = ["starttls-rust"]
-starttls = ["hickory-resolver", "idna"]
+starttls = ["dns"]
 tls-rust = ["tokio-rustls", "webpki-roots"]
 tls-native = ["tokio-native-tls", "native-tls"]
 starttls-native = ["starttls", "tls-native"]
@@ -50,6 +50,8 @@ insecure-tcp = []
 syntax-highlighting = ["syntect"]
 # Enable serde support in jid crate
 serde = [ "xmpp-parsers/serde" ]
+# Required by starttls, and used by insecure-tcp by default
+dns = [ "hickory-resolver", "idna" ]
 
 [lints.rust]
 unexpected_cfgs = { level = "warn", check-cfg = ['cfg(xmpprs_doc_build)'] }

tokio-xmpp/src/client/async_client.rs 🔗

@@ -10,9 +10,11 @@ use super::connect::client_login;
 use crate::connect::{AsyncReadAndWrite, ServerConnector};
 use crate::error::{Error, ProtocolError};
 use crate::event::Event;
+#[cfg(feature = "starttls")]
 use crate::starttls::ServerConfig;
 use crate::xmpp_codec::Packet;
 use crate::xmpp_stream::{add_stanza_id, XMPPStream};
+#[cfg(feature = "starttls")]
 use crate::AsyncConfig;
 
 /// XMPP client connection and state
@@ -46,6 +48,7 @@ enum ClientState<S: AsyncReadAndWrite> {
     Connected(XMPPStream<S>),
 }
 
+#[cfg(feature = "starttls")]
 impl Client<ServerConfig> {
     /// Start a new XMPP client using StartTLS transport and autoreconnect
     ///

tokio-xmpp/src/client/simple_client.rs 🔗

@@ -1,12 +1,14 @@
 use futures::{sink::SinkExt, Sink, Stream};
 use minidom::Element;
 use std::pin::Pin;
+#[cfg(feature = "starttls")]
 use std::str::FromStr;
 use std::task::{Context, Poll};
 use tokio_stream::StreamExt;
 use xmpp_parsers::{jid::Jid, ns, stream_features::StreamFeatures};
 
 use crate::connect::ServerConnector;
+#[cfg(feature = "starttls")]
 use crate::starttls::ServerConfig;
 use crate::xmpp_codec::Packet;
 use crate::xmpp_stream::{add_stanza_id, XMPPStream};
@@ -22,6 +24,7 @@ pub struct Client<C: ServerConnector> {
     stream: XMPPStream<C::Stream>,
 }
 
+#[cfg(feature = "starttls")]
 impl Client<ServerConfig> {
     /// Start a new XMPP client and wait for a usable session
     pub async fn new<P: Into<String>>(jid: &str, password: P) -> Result<Self, Error> {

tokio-xmpp/src/connect.rs 🔗

@@ -1,7 +1,17 @@
 //! `ServerConnector` provides streams for XMPP clients
 
+#[cfg(feature = "dns")]
+use futures::{future::select_ok, FutureExt};
+#[cfg(feature = "dns")]
+use hickory_resolver::{
+    config::LookupIpStrategy, name_server::TokioConnectionProvider, IntoName, TokioAsyncResolver,
+};
+#[cfg(feature = "dns")]
+use log::debug;
 use sasl::common::ChannelBinding;
+use std::net::{IpAddr, SocketAddr};
 use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::net::TcpStream;
 use xmpp_parsers::jid::Jid;
 
 use crate::xmpp_stream::XMPPStream;
@@ -32,3 +42,78 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
         Ok(ChannelBinding::None)
     }
 }
+
+/// A simple wrapper to build [`TcpStream`]
+pub struct Tcp;
+
+impl Tcp {
+    /// Connect directly to an IP/Port combo
+    pub async fn connect(ip: IpAddr, port: u16) -> Result<TcpStream, Error> {
+        Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?)
+    }
+
+    /// Connect over TCP, resolving A/AAAA records (happy eyeballs)
+    #[cfg(feature = "dns")]
+    pub async fn resolve(domain: &str, port: u16) -> Result<TcpStream, Error> {
+        let ascii_domain = idna::domain_to_ascii(&domain)?;
+
+        if let Ok(ip) = ascii_domain.parse() {
+            return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
+        }
+
+        let (config, mut options) = hickory_resolver::system_conf::read_system_conf()?;
+        options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
+        let resolver = TokioAsyncResolver::new(config, options, TokioConnectionProvider::default());
+
+        let ips = resolver.lookup_ip(ascii_domain).await?;
+
+        // Happy Eyeballs: connect to all records in parallel, return the
+        // first to succeed
+        select_ok(
+            ips.into_iter()
+                .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
+        )
+        .await
+        .map(|(result, _)| result)
+        .map_err(|_| Error::Disconnected)
+    }
+
+    /// Connect over TCP, resolving SRV records
+    #[cfg(feature = "dns")]
+    pub async fn resolve_with_srv(
+        domain: &str,
+        srv: &str,
+        fallback_port: u16,
+    ) -> Result<TcpStream, Error> {
+        let ascii_domain = idna::domain_to_ascii(&domain)?;
+
+        if let Ok(ip) = ascii_domain.parse() {
+            debug!("Attempting connection to {ip}:{fallback_port}");
+            return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
+        }
+
+        let resolver = TokioAsyncResolver::tokio_from_system_conf()?;
+
+        let srv_domain = format!("{}.{}.", srv, ascii_domain).into_name()?;
+        let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
+
+        match srv_records {
+            Some(lookup) => {
+                // TODO: sort lookup records by priority/weight
+                for srv in lookup.iter() {
+                    debug!("Attempting connection to {srv_domain} {srv}");
+                    match Self::resolve(&srv.target().to_ascii(), srv.port()).await {
+                        Ok(stream) => return Ok(stream),
+                        Err(_) => {}
+                    }
+                }
+                Err(Error::Disconnected)
+            }
+            None => {
+                // SRV lookup error, retry with hostname
+                debug!("Attempting connection to {domain}:{fallback_port}");
+                Self::resolve(domain, fallback_port).await
+            }
+        }
+    }
+}

tokio-xmpp/src/error.rs 🔗

@@ -1,3 +1,7 @@
+#[cfg(feature = "dns")]
+use hickory_resolver::{
+    error::ResolveError as DnsResolveError, proto::error::ProtoError as DnsProtoError,
+};
 use sasl::client::MechanismError as SaslMechanismError;
 use std::error::Error as StdError;
 use std::fmt;
@@ -28,8 +32,18 @@ pub enum Error {
     Fmt(fmt::Error),
     /// Utf8 error
     Utf8(Utf8Error),
-    /// Error resolving DNS and/or establishing a connection, returned by a ServerConnector impl
+    /// Error specific to ServerConnector impl
     Connection(Box<dyn ServerConnectorError>),
+    /// DNS protocol error
+    #[cfg(feature = "dns")]
+    Dns(DnsProtoError),
+    /// DNS resolution error
+    #[cfg(feature = "dns")]
+    Resolve(DnsResolveError),
+    /// DNS label conversion error, no details available from module
+    /// `idna`
+    #[cfg(feature = "dns")]
+    Idna,
 }
 
 impl fmt::Display for Error {
@@ -44,6 +58,12 @@ impl fmt::Display for Error {
             Error::InvalidState => write!(fmt, "invalid state"),
             Error::Fmt(e) => write!(fmt, "Fmt error: {}", e),
             Error::Utf8(e) => write!(fmt, "Utf8 error: {}", e),
+            #[cfg(feature = "dns")]
+            Error::Dns(e) => write!(fmt, "{:?}", e),
+            #[cfg(feature = "dns")]
+            Error::Resolve(e) => write!(fmt, "{:?}", e),
+            #[cfg(feature = "dns")]
+            Error::Idna => write!(fmt, "IDNA error"),
         }
     }
 }
@@ -92,6 +112,27 @@ impl From<Utf8Error> for Error {
     }
 }
 
+#[cfg(feature = "dns")]
+impl From<idna::Errors> for Error {
+    fn from(_e: idna::Errors) -> Self {
+        Error::Idna
+    }
+}
+
+#[cfg(feature = "dns")]
+impl From<DnsResolveError> for Error {
+    fn from(e: DnsResolveError) -> Error {
+        Error::Resolve(e)
+    }
+}
+
+#[cfg(feature = "dns")]
+impl From<DnsProtoError> for Error {
+    fn from(e: DnsProtoError) -> Error {
+        Error::Dns(e)
+    }
+}
+
 /// XMPP protocol-level error
 #[derive(Debug)]
 pub enum ProtocolError {

tokio-xmpp/src/starttls/error.rs 🔗

@@ -1,6 +1,5 @@
 //! StartTLS ServerConnector Error
 
-use hickory_resolver::{error::ResolveError, proto::error::ProtoError};
 #[cfg(feature = "tls-native")]
 use native_tls::Error as TlsError;
 use std::error::Error as StdError;
@@ -15,13 +14,6 @@ use super::ServerConnectorError;
 /// StartTLS ServerConnector Error
 #[derive(Debug)]
 pub enum Error {
-    /// DNS protocol error
-    Dns(ProtoError),
-    /// DNS resolution error
-    Resolve(ResolveError),
-    /// DNS label conversion error, no details available from module
-    /// `idna`
-    Idna,
     /// TLS error
     Tls(TlsError),
     #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
@@ -34,9 +26,6 @@ impl ServerConnectorError for Error {}
 impl fmt::Display for Error {
     fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
         match self {
-            Self::Dns(e) => write!(fmt, "{:?}", e),
-            Self::Resolve(e) => write!(fmt, "{:?}", e),
-            Self::Idna => write!(fmt, "IDNA error"),
             Self::Tls(e) => write!(fmt, "TLS error: {}", e),
             #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
             Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),

tokio-xmpp/src/starttls/happy_eyeballs.rs 🔗

@@ -1,75 +0,0 @@
-use super::error::Error as StartTlsError;
-use crate::Error;
-use futures::{future::select_ok, FutureExt};
-use hickory_resolver::{
-    config::LookupIpStrategy, name_server::TokioConnectionProvider, IntoName, TokioAsyncResolver,
-};
-use log::debug;
-use std::net::SocketAddr;
-use tokio::net::TcpStream;
-
-pub async fn connect_to_host(domain: &str, port: u16) -> Result<TcpStream, Error> {
-    let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| StartTlsError::Idna)?;
-
-    if let Ok(ip) = ascii_domain.parse() {
-        return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
-    }
-
-    let (config, mut options) =
-        hickory_resolver::system_conf::read_system_conf().map_err(StartTlsError::Resolve)?;
-    options.ip_strategy = LookupIpStrategy::Ipv4AndIpv6;
-    let resolver = TokioAsyncResolver::new(config, options, TokioConnectionProvider::default());
-
-    let ips = resolver
-        .lookup_ip(ascii_domain)
-        .await
-        .map_err(StartTlsError::Resolve)?;
-    // Happy Eyeballs: connect to all records in parallel, return the
-    // first to succeed
-    select_ok(
-        ips.into_iter()
-            .map(|ip| TcpStream::connect(SocketAddr::new(ip, port)).boxed()),
-    )
-    .await
-    .map(|(result, _)| result)
-    .map_err(|_| crate::Error::Disconnected)
-}
-
-pub async fn connect_with_srv(
-    domain: &str,
-    srv: &str,
-    fallback_port: u16,
-) -> Result<TcpStream, Error> {
-    let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| StartTlsError::Idna)?;
-
-    if let Ok(ip) = ascii_domain.parse() {
-        debug!("Attempting connection to {ip}:{fallback_port}");
-        return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
-    }
-
-    let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(StartTlsError::Resolve)?;
-
-    let srv_domain = format!("{}.{}.", srv, ascii_domain)
-        .into_name()
-        .map_err(StartTlsError::Dns)?;
-    let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
-
-    match srv_records {
-        Some(lookup) => {
-            // TODO: sort lookup records by priority/weight
-            for srv in lookup.iter() {
-                debug!("Attempting connection to {srv_domain} {srv}");
-                match connect_to_host(&srv.target().to_ascii(), srv.port()).await {
-                    Ok(stream) => return Ok(stream),
-                    Err(_) => {}
-                }
-            }
-            Err(crate::Error::Disconnected.into())
-        }
-        None => {
-            // SRV lookup error, retry with hostname
-            debug!("Attempting connection to {domain}:{fallback_port}");
-            connect_to_host(domain, fallback_port).await
-        }
-    }
-}

tokio-xmpp/src/starttls/mod.rs 🔗

@@ -27,16 +27,17 @@ use tokio::{
 };
 use xmpp_parsers::{jid::Jid, ns};
 
-use crate::error::ProtocolError;
-use crate::Error;
-use crate::{connect::ServerConnector, xmpp_codec::Packet, AsyncClient, SimpleClient};
-use crate::{connect::ServerConnectorError, xmpp_stream::XMPPStream};
+use crate::{
+    connect::{ServerConnector, ServerConnectorError, Tcp},
+    error::{Error, ProtocolError},
+    xmpp_codec::Packet,
+    xmpp_stream::XMPPStream,
+    AsyncClient, SimpleClient,
+};
 
 use self::error::Error as StartTlsError;
-use self::happy_eyeballs::{connect_to_host, connect_with_srv};
 
 pub mod error;
-mod happy_eyeballs;
 
 /// AsyncClient that connects over StartTls
 pub type StartTlsAsyncClient = AsyncClient<ServerConfig>;
@@ -64,9 +65,9 @@ impl ServerConnector for ServerConfig {
         // TCP connection
         let tcp_stream = match self {
             ServerConfig::UseSrv => {
-                connect_with_srv(jid.domain().as_str(), "_xmpp-client._tcp", 5222).await?
+                Tcp::resolve_with_srv(jid.domain().as_str(), "_xmpp-client._tcp", 5222).await?
             }
-            ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?,
+            ServerConfig::Manual { host, port } => Tcp::resolve(host.as_str(), *port).await?,
         };
 
         // Unencryped XMPPStream