Add support for both rustls and tlsnative

Paul Fariello created

Change summary

tokio-xmpp/Cargo.toml                  | 13 +++++--
tokio-xmpp/src/client/async_client.rs  |  5 ++
tokio-xmpp/src/client/simple_client.rs |  5 ++
tokio-xmpp/src/error.rs                |  3 +
tokio-xmpp/src/starttls.rs             | 46 +++++++++++++++++++++++----
5 files changed, 58 insertions(+), 14 deletions(-)

Detailed changes

tokio-xmpp/Cargo.toml 🔗

@@ -16,19 +16,24 @@ bytes = "1"
 futures = "0.3"
 idna = "0.2"
 log = "0.4"
-native-tls = "0.2"
+native-tls = { version = "0.2", optional = true }
 sasl = "0.5"
 tokio = { version = "1", features = ["net", "rt", "rt-multi-thread", "macros"] }
-tokio-util = { version = "0.6", features = ["codec"] }
+tokio-native-tls = { version = "0.3", optional = true }
+tokio-rustls = { version = "0.22", optional = true }
 tokio-stream = { version = "0.1", features = [] }
-tokio-tls = { package = "tokio-native-tls", version = "0.3" }
-trust-dns-resolver = "0.20"
+tokio-util = { version = "0.6", features = ["codec"] }
 trust-dns-proto = "0.20"
+trust-dns-resolver = "0.20"
 xml5ever = "0.16"
 xmpp-parsers = "0.18"
+webpki = { version = "0.21", optional = true }
 
 [build-dependencies]
 rustc_version = "0.3"
 
 [features]
+default = ["tls-native"]
+tls-rust = ["tokio-rustls", "webpki"]
+tls-native = ["tokio-native-tls", "native-tls"]
 serde = ["xmpp-parsers/serde"]

tokio-xmpp/src/client/async_client.rs 🔗

@@ -7,7 +7,10 @@ use std::task::Context;
 use tokio::net::TcpStream;
 use tokio::task::JoinHandle;
 use tokio::task::LocalSet;
-use tokio_tls::TlsStream;
+#[cfg(feature = "tls-native")]
+use tokio_native_tls::TlsStream;
+#[cfg(feature = "tls-rust")]
+use tokio_rustls::client::TlsStream;
 use xmpp_parsers::{ns, Element, Jid, JidParseError};
 
 use super::auth::auth;

tokio-xmpp/src/client/simple_client.rs 🔗

@@ -5,8 +5,11 @@ use std::pin::Pin;
 use std::str::FromStr;
 use std::task::{Context, Poll};
 use tokio::net::TcpStream;
+#[cfg(feature = "tls-native")]
+use tokio_native_tls::TlsStream;
+#[cfg(feature = "tls-rust")]
+use tokio_rustls::client::TlsStream;
 use tokio_stream::StreamExt;
-use tokio_tls::TlsStream;
 use xmpp_parsers::{ns, Element, Jid};
 
 use super::auth::auth;

tokio-xmpp/src/error.rs 🔗

@@ -1,3 +1,4 @@
+#[cfg(feature = "tls-native")]
 use native_tls::Error as TlsError;
 use sasl::client::MechanismError as SaslMechanismError;
 use std::borrow::Cow;
@@ -5,6 +6,8 @@ use std::error::Error as StdError;
 use std::fmt;
 use std::io::Error as IoError;
 use std::str::Utf8Error;
+#[cfg(feature = "tls-rust")]
+use tokio_rustls::rustls::TLSError as TlsError;
 use trust_dns_proto::error::ProtoError;
 use trust_dns_resolver::error::ResolveError;
 

tokio-xmpp/src/starttls.rs 🔗

@@ -1,13 +1,49 @@
 use futures::{sink::SinkExt, stream::StreamExt};
+#[cfg(feature = "tls-rust")]
+use idna;
+#[cfg(feature = "tls-native")]
 use native_tls::TlsConnector as NativeTlsConnector;
+#[cfg(feature = "tls-rust")]
+use std::sync::Arc;
 use tokio::io::{AsyncRead, AsyncWrite};
-use tokio_tls::{TlsConnector, TlsStream};
+#[cfg(feature = "tls-native")]
+use tokio_native_tls::{TlsConnector, TlsStream};
+#[cfg(feature = "tls-rust")]
+use tokio_rustls::{client::TlsStream, rustls::ClientConfig, TlsConnector};
+#[cfg(feature = "tls-rust")]
+use webpki::DNSNameRef;
 use xmpp_parsers::{ns, Element};
 
 use crate::xmpp_codec::Packet;
 use crate::xmpp_stream::XMPPStream;
 use crate::{Error, ProtocolError};
 
+#[cfg(feature = "tls-native")]
+async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
+    xmpp_stream: XMPPStream<S>,
+) -> Result<TlsStream<S>, Error> {
+    let domain = &xmpp_stream.jid.clone().domain();
+    let stream = xmpp_stream.into_inner();
+    let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
+        .connect(&domain, stream)
+        .await?;
+    Ok(tls_stream)
+}
+
+#[cfg(feature = "tls-rust")]
+async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
+    xmpp_stream: XMPPStream<S>,
+) -> Result<TlsStream<S>, Error> {
+    let domain = &xmpp_stream.jid.clone().domain();
+    let ascii_domain = idna::domain_to_ascii(domain).map_err(|_| Error::Idna)?;
+    let domain = DNSNameRef::try_from_ascii_str(&ascii_domain).unwrap();
+    let stream = xmpp_stream.into_inner();
+    let tls_stream = TlsConnector::from(Arc::new(ClientConfig::new()))
+        .connect(domain, stream)
+        .await?;
+    Ok(tls_stream)
+}
+
 /// Performs `<starttls/>` on an XMPPStream and returns a binary
 /// TlsStream.
 pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
@@ -28,11 +64,5 @@ pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
         }
     }
 
-    let domain = xmpp_stream.jid.clone().domain();
-    let stream = xmpp_stream.into_inner();
-    let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
-        .connect(&domain, stream)
-        .await?;
-
-    Ok(tls_stream)
+    get_tls_stream(xmpp_stream).await
 }