Use ServerConfig enum for tokio-xmpp client config

Paul Fariello created

And expose connect_to_host from happy_eyeballs to let clients explicitly
choose to use SRV or not. (Rename connect to connect_with_srv)

Change summary

tokio-xmpp/src/client/async_client.rs  | 33 +++++++++++++++++----------
tokio-xmpp/src/client/simple_client.rs |  4 +-
tokio-xmpp/src/component/mod.rs        |  4 +-
tokio-xmpp/src/happy_eyeballs.rs       | 29 +++++++++++++++---------
4 files changed, 43 insertions(+), 27 deletions(-)

Detailed changes

tokio-xmpp/src/client/async_client.rs 🔗

@@ -1,5 +1,4 @@
 use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream};
-use idna;
 use sasl::common::{ChannelBinding, Credentials};
 use std::mem::replace;
 use std::pin::Pin;
@@ -14,7 +13,7 @@ use xmpp_parsers::{ns, Element, Jid, JidParseError};
 use super::auth::auth;
 use super::bind::bind;
 use crate::event::Event;
-use crate::happy_eyeballs::connect;
+use crate::happy_eyeballs::{connect_to_host, connect_with_srv};
 use crate::starttls::starttls;
 use crate::xmpp_codec::Packet;
 use crate::xmpp_stream;
@@ -33,12 +32,22 @@ pub struct Client {
     // TODO: tls_required=true
 }
 
+/// XMPP server connection configuration
+#[derive(Clone)]
+pub enum ServerConfig {
+    UseSrv,
+    #[allow(unused)]
+    Manual {
+        host: String,
+        port: u16,
+    },
+}
+
 /// XMMPP client configuration
 pub struct Config {
     jid: Jid,
     password: String,
-    server: String,
-    port: u16,
+    server: ServerConfig,
 }
 
 type XMPPStream = xmpp_stream::XMPPStream<TlsStream<TcpStream>>;
@@ -60,8 +69,7 @@ impl Client {
         let config = Config {
             jid: jid.clone(),
             password: password.into(),
-            server: jid.clone().domain(),
-            port: 5222,
+            server: ServerConfig::UseSrv,
         };
         let client = Self::new_with_config(config);
         Ok(client)
@@ -72,7 +80,6 @@ impl Client {
         let local = LocalSet::new();
         let connect = local.spawn_local(Self::connect(
             config.server.clone(),
-            config.port,
             config.jid.clone(),
             config.password.clone(),
         ));
@@ -92,17 +99,20 @@ impl Client {
     }
 
     async fn connect(
-        server: String,
-        port: u16,
+        server: ServerConfig,
         jid: Jid,
         password: String,
     ) -> Result<XMPPStream, Error> {
         let username = jid.clone().node().unwrap();
         let password = password;
-        let domain = idna::domain_to_ascii(&server).map_err(|_| Error::Idna)?;
 
         // TCP connection
-        let tcp_stream = connect(&domain, Some("_xmpp-client._tcp"), port).await?;
+        let tcp_stream = match server {
+            ServerConfig::UseSrv => {
+                connect_with_srv(&jid.clone().domain(), Some("_xmpp-client._tcp"), 5222).await?
+            }
+            ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), port).await?,
+        };
 
         // Unencryped XMPPStream
         let xmpp_stream =
@@ -186,7 +196,6 @@ impl Stream for Client {
                 let mut local = LocalSet::new();
                 let connect = local.spawn_local(Self::connect(
                     self.config.server.clone(),
-                    self.config.port,
                     self.config.jid.clone(),
                     self.config.password.clone(),
                 ));

tokio-xmpp/src/client/simple_client.rs 🔗

@@ -11,7 +11,7 @@ use xmpp_parsers::{ns, Element, Jid};
 
 use super::auth::auth;
 use super::bind::bind;
-use crate::happy_eyeballs::connect;
+use crate::happy_eyeballs::connect_with_srv;
 use crate::starttls::starttls;
 use crate::xmpp_codec::Packet;
 use crate::xmpp_stream;
@@ -47,7 +47,7 @@ impl Client {
         let domain = idna::domain_to_ascii(&jid.clone().domain()).map_err(|_| Error::Idna)?;
 
         // TCP connection
-        let tcp_stream = connect(&domain, Some("_xmpp-client._tcp"), 5222).await?;
+        let tcp_stream = connect_with_srv(&domain, Some("_xmpp-client._tcp"), 5222).await?;
 
         // Unencryped XMPPStream
         let xmpp_stream =

tokio-xmpp/src/component/mod.rs 🔗

@@ -8,7 +8,7 @@ use std::task::Context;
 use tokio::net::TcpStream;
 use xmpp_parsers::{ns, Element, Jid};
 
-use super::happy_eyeballs::connect;
+use super::happy_eyeballs::connect_to_host;
 use super::xmpp_codec::Packet;
 use super::xmpp_stream;
 use super::Error;
@@ -43,7 +43,7 @@ impl Component {
         port: u16,
     ) -> Result<XMPPStream, Error> {
         let password = password;
-        let tcp_stream = connect(server, None, port).await?;
+        let tcp_stream = connect_to_host(server, port).await?;
         let mut xmpp_stream =
             xmpp_stream::XMPPStream::start(tcp_stream, jid, ns::COMPONENT_ACCEPT.to_owned())
                 .await?;

tokio-xmpp/src/happy_eyeballs.rs 🔗

@@ -1,15 +1,20 @@
 use crate::{ConnecterError, Error};
+use idna;
 use std::net::SocketAddr;
 use tokio::net::TcpStream;
 use trust_dns_resolver::{IntoName, TokioAsyncResolver};
 
-async fn connect_to_host(
-    resolver: &TokioAsyncResolver,
-    host: &str,
-    port: u16,
-) -> Result<TcpStream, Error> {
+pub async fn connect_to_host(domain: &str, port: u16) -> Result<TcpStream, Error> {
+    let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| Error::Idna)?;
+
+    if let Ok(ip) = ascii_domain.parse() {
+        return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
+    }
+
+    let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(ConnecterError::Resolve)?;
+
     let ips = resolver
-        .lookup_ip(host)
+        .lookup_ip(ascii_domain)
         .await
         .map_err(ConnecterError::Resolve)?;
     for ip in ips.iter() {
@@ -21,12 +26,14 @@ async fn connect_to_host(
     Err(Error::Disconnected)
 }
 
-pub async fn connect(
+pub async fn connect_with_srv(
     domain: &str,
     srv: Option<&str>,
     fallback_port: u16,
 ) -> Result<TcpStream, Error> {
-    if let Ok(ip) = domain.parse() {
+    let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| Error::Idna)?;
+
+    if let Ok(ip) = ascii_domain.parse() {
         return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
     }
 
@@ -34,7 +41,7 @@ pub async fn connect(
 
     let srv_records = match srv {
         Some(srv) => {
-            let srv_domain = format!("{}.{}.", srv, domain)
+            let srv_domain = format!("{}.{}.", srv, ascii_domain)
                 .into_name()
                 .map_err(ConnecterError::Dns)?;
             resolver.srv_lookup(srv_domain).await.ok()
@@ -46,7 +53,7 @@ pub async fn connect(
         Some(lookup) => {
             // TODO: sort lookup records by priority/weight
             for srv in lookup.iter() {
-                match connect_to_host(&resolver, &srv.target().to_ascii(), srv.port()).await {
+                match connect_to_host(&srv.target().to_ascii(), srv.port()).await {
                     Ok(stream) => return Ok(stream),
                     Err(_) => {}
                 }
@@ -55,7 +62,7 @@ pub async fn connect(
         }
         None => {
             // SRV lookup error, retry with hostname
-            connect_to_host(&resolver, domain, fallback_port).await
+            connect_to_host(domain, fallback_port).await
         }
     }
 }