DNS/TLS deps are now optional, component now also uses ServerConnector

moparisthebest created

Change summary

tokio-xmpp/Cargo.toml                     |  19 +-
tokio-xmpp/src/client/async_client.rs     | 100 --------------
tokio-xmpp/src/client/connect.rs          |  31 ---
tokio-xmpp/src/client/mod.rs              |   3 
tokio-xmpp/src/client/simple_client.rs    |  19 --
tokio-xmpp/src/component/connect.rs       |  18 ++
tokio-xmpp/src/component/mod.rs           |  40 +----
tokio-xmpp/src/connect.rs                 |  35 +++++
tokio-xmpp/src/error.rs                   |  63 +--------
tokio-xmpp/src/lib.rs                     |  24 ++-
tokio-xmpp/src/starttls.rs                |  85 ------------
tokio-xmpp/src/starttls/client.rs         |  35 +++++
tokio-xmpp/src/starttls/error.rs          | 105 +++++++++++++++
tokio-xmpp/src/starttls/happy_eyeballs.rs |  22 +-
tokio-xmpp/src/starttls/mod.rs            | 168 +++++++++++++++++++++++++
xmpp/Cargo.toml                           |   6 
xmpp/src/lib.rs                           |   4 
17 files changed, 440 insertions(+), 337 deletions(-)

Detailed changes

tokio-xmpp/Cargo.toml πŸ”—

@@ -14,17 +14,12 @@ edition = "2021"
 [dependencies]
 bytes = "1"
 futures = "0.3"
-idna = "0.4"
 log = "0.4"
-native-tls = { version = "0.2", optional = true }
 tokio = { version = "1", features = ["net", "rt", "rt-multi-thread", "macros"] }
-tokio-native-tls = { version = "0.3", optional = true }
-tokio-rustls = { version = "0.24", optional = true }
 tokio-stream = { version = "0.1", features = [] }
 tokio-util = { version = "0.7", features = ["codec"] }
-hickory-resolver = "0.24"
-rxml = "0.9.1"
 webpki-roots = { version = "0.25", optional = true }
+rxml = "0.9.1"
 rand = "^0.8"
 syntect = { version = "5", optional = true }
 # same repository dependencies
@@ -32,11 +27,21 @@ minidom = { version = "0.15", path = "../minidom" }
 sasl = { version = "0.5", path = "../sasl" }
 xmpp-parsers = { version = "0.20", path = "../parsers" }
 
+# these are only needed for starttls ServerConnector support
+hickory-resolver = { version = "0.24", optional = true}
+idna = { version = "0.4", optional = true}
+native-tls = { version = "0.2", optional = true }
+tokio-native-tls = { version = "0.3", optional = true }
+tokio-rustls = { version = "0.24", optional = true }
+
 [dev-dependencies]
 env_logger = { version = "0.10", default-features = false, features = ["auto-color", "humantime"] }
 
 [features]
-default = ["tls-native"]
+default = ["starttls-rust"]
+starttls = ["hickory-resolver", "idna"]
 tls-rust = ["tokio-rustls", "webpki-roots"]
 tls-native = ["tokio-native-tls", "native-tls"]
+starttls-native = ["starttls", "tls-native"]
+starttls-rust = ["starttls", "tls-rust"]
 syntax-highlighting = ["syntect"]

tokio-xmpp/src/client/async_client.rs πŸ”—

@@ -1,23 +1,16 @@
 use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream};
-use sasl::common::ChannelBinding;
 use std::mem::replace;
 use std::pin::Pin;
 use std::task::Context;
-use tokio::net::TcpStream;
 use tokio::task::JoinHandle;
 use xmpp_parsers::{ns, Element, Jid};
 
-use super::connect::{AsyncReadAndWrite, ServerConnector};
+use super::connect::client_login;
+use crate::connect::{AsyncReadAndWrite, ServerConnector};
 use crate::event::Event;
-use crate::happy_eyeballs::{connect_to_host, connect_with_srv};
-use crate::starttls::starttls;
 use crate::xmpp_codec::Packet;
-use crate::xmpp_stream::{self, add_stanza_id, XMPPStream};
-use crate::{client_login, Error, ProtocolError};
-#[cfg(feature = "tls-native")]
-use tokio_native_tls::TlsStream;
-#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
-use tokio_rustls::client::TlsStream;
+use crate::xmpp_stream::{add_stanza_id, XMPPStream};
+use crate::{Error, ProtocolError};
 
 /// XMPP client connection and state
 ///
@@ -43,76 +36,6 @@ pub struct Config<C> {
     pub server: C,
 }
 
-/// XMPP server connection configuration
-#[derive(Clone, Debug)]
-pub enum ServerConfig {
-    /// Use SRV record to find server host
-    UseSrv,
-    #[allow(unused)]
-    /// Manually define server host and port
-    Manual {
-        /// Server host name
-        host: String,
-        /// Server port
-        port: u16,
-    },
-}
-
-impl ServerConnector for ServerConfig {
-    type Stream = TlsStream<TcpStream>;
-    async fn connect(&self, jid: &Jid) -> Result<XMPPStream<Self::Stream>, Error> {
-        // TCP connection
-        let tcp_stream = match self {
-            ServerConfig::UseSrv => {
-                connect_with_srv(jid.domain_str(), "_xmpp-client._tcp", 5222).await?
-            }
-            ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?,
-        };
-
-        // Unencryped XMPPStream
-        let xmpp_stream =
-            xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
-                .await?;
-
-        if xmpp_stream.stream_features.can_starttls() {
-            // TlsStream
-            let tls_stream = starttls(xmpp_stream).await?;
-            // Encrypted XMPPStream
-            xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
-                .await
-        } else {
-            return Err(Error::Protocol(ProtocolError::NoTls));
-        }
-    }
-
-    fn channel_binding(
-        #[allow(unused_variables)] stream: &Self::Stream,
-    ) -> Result<sasl::common::ChannelBinding, Error> {
-        #[cfg(feature = "tls-native")]
-        {
-            log::warn!("tls-native doesn’t support channel binding, please use tls-rust if you want this feature!");
-            Ok(ChannelBinding::None)
-        }
-        #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
-        {
-            let (_, connection) = stream.get_ref();
-            Ok(match connection.protocol_version() {
-                // TODO: Add support for TLSΒ 1.2 and earlier.
-                Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
-                    let data = vec![0u8; 32];
-                    let data = connection.export_keying_material(
-                        data,
-                        b"EXPORTER-Channel-Binding",
-                        None,
-                    )?;
-                    ChannelBinding::TlsExporter(data)
-                }
-                _ => ChannelBinding::None,
-            })
-        }
-    }
-}
-
 enum ClientState<S: AsyncReadAndWrite> {
     Invalid,
     Disconnected,
@@ -120,21 +43,6 @@ enum ClientState<S: AsyncReadAndWrite> {
     Connected(XMPPStream<S>),
 }
 
-impl Client<ServerConfig> {
-    /// Start a new XMPP client
-    ///
-    /// Start polling the returned instance so that it will connect
-    /// and yield events.
-    pub fn new<J: Into<Jid>, P: Into<String>>(jid: J, password: P) -> Self {
-        let config = Config {
-            jid: jid.into(),
-            password: password.into(),
-            server: ServerConfig::UseSrv,
-        };
-        Self::new_with_config(config)
-    }
-}
-
 impl<C: ServerConnector> Client<C> {
     /// Start a new client given that the JID is already parsed.
     pub fn new_with_config(config: Config<C>) -> Self {

tokio-xmpp/src/client/connect.rs πŸ”—

@@ -1,32 +1,11 @@
-use sasl::common::{ChannelBinding, Credentials};
-use tokio::io::{AsyncRead, AsyncWrite};
+use sasl::common::Credentials;
 use xmpp_parsers::{ns, Jid};
 
-use super::{auth::auth, bind::bind};
+use crate::client::auth::auth;
+use crate::client::bind::bind;
+use crate::connect::ServerConnector;
 use crate::{xmpp_stream::XMPPStream, Error};
 
-/// trait returned wrapped in XMPPStream by ServerConnector
-pub trait AsyncReadAndWrite: AsyncRead + AsyncWrite + Unpin + Send {}
-impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncReadAndWrite for T {}
-
-/// Trait called to connect to an XMPP server, perhaps called multiple times
-pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
-    /// The type of Stream this ServerConnector produces
-    type Stream: AsyncReadAndWrite;
-    /// This must return the connection ready to login, ie if starttls is involved, after TLS has been started, and then after the <stream headers are exchanged
-    fn connect(
-        &self,
-        jid: &Jid,
-    ) -> impl std::future::Future<Output = Result<XMPPStream<Self::Stream>, Error>> + Send;
-
-    /// Return channel binding data if available
-    /// do not fail if channel binding is simply unavailable, just return Ok(None)
-    /// this should only be called after the TLS handshake is finished
-    fn channel_binding(_stream: &Self::Stream) -> Result<ChannelBinding, Error> {
-        Ok(ChannelBinding::None)
-    }
-}
-
 /// Log into an XMPP server as a client with a jid+pass
 /// does channel binding if supported
 pub async fn client_login<C: ServerConnector>(
@@ -37,7 +16,7 @@ pub async fn client_login<C: ServerConnector>(
     let username = jid.node_str().unwrap();
     let password = password;
 
-    let xmpp_stream = server.connect(&jid).await?;
+    let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT).await?;
 
     let channel_binding = C::channel_binding(xmpp_stream.stream.get_ref())?;
 

tokio-xmpp/src/client/simple_client.rs πŸ”—

@@ -1,13 +1,15 @@
 use futures::{sink::SinkExt, Sink, Stream};
 use std::pin::Pin;
-use std::str::FromStr;
 use std::task::{Context, Poll};
 use tokio_stream::StreamExt;
 use xmpp_parsers::{ns, Element, Jid};
 
+use crate::connect::ServerConnector;
 use crate::xmpp_codec::Packet;
 use crate::xmpp_stream::{add_stanza_id, XMPPStream};
-use crate::{client_login, AsyncServerConfig, Error, ServerConnector};
+use crate::Error;
+
+use super::connect::client_login;
 
 /// A simple XMPP client connection
 ///
@@ -17,19 +19,6 @@ pub struct Client<C: ServerConnector> {
     stream: XMPPStream<C::Stream>,
 }
 
-impl Client<AsyncServerConfig> {
-    /// Start a new XMPP client and wait for a usable session
-    pub async fn new<P: Into<String>>(jid: &str, password: P) -> Result<Self, Error> {
-        let jid = Jid::from_str(jid)?;
-        Self::new_with_jid(jid, password.into()).await
-    }
-
-    /// Start a new client given that the JID is already parsed.
-    pub async fn new_with_jid(jid: Jid, password: String) -> Result<Self, Error> {
-        Self::new_with_jid_connector(AsyncServerConfig::UseSrv, jid, password).await
-    }
-}
-
 impl<C: ServerConnector> Client<C> {
     /// Start a new client given that the JID is already parsed.
     pub async fn new_with_jid_connector(

tokio-xmpp/src/component/connect.rs πŸ”—

@@ -0,0 +1,18 @@
+use xmpp_parsers::{ns, Jid};
+
+use crate::connect::ServerConnector;
+use crate::{xmpp_stream::XMPPStream, Error};
+
+use super::auth::auth;
+
+/// Log into an XMPP server as a client with a jid+pass
+pub async fn component_login<C: ServerConnector>(
+    connector: C,
+    jid: Jid,
+    password: String,
+) -> Result<XMPPStream<C::Stream>, Error> {
+    let password = password;
+    let mut xmpp_stream = connector.connect(&jid, ns::COMPONENT).await?;
+    auth(&mut xmpp_stream, password).await?;
+    Ok(xmpp_stream)
+}

tokio-xmpp/src/component/mod.rs πŸ”—

@@ -5,53 +5,39 @@ use futures::{sink::SinkExt, task::Poll, Sink, Stream};
 use std::pin::Pin;
 use std::str::FromStr;
 use std::task::Context;
-use tokio::net::TcpStream;
 use xmpp_parsers::{ns, Element, Jid};
 
-use super::happy_eyeballs::connect_to_host;
+use self::connect::component_login;
+
 use super::xmpp_codec::Packet;
-use super::xmpp_stream;
 use super::Error;
+use crate::connect::ServerConnector;
 use crate::xmpp_stream::add_stanza_id;
+use crate::xmpp_stream::XMPPStream;
 
 mod auth;
 
+pub(crate) mod connect;
+
 /// Component connection to an XMPP server
 ///
 /// This simplifies the `XMPPStream` to a `Stream`/`Sink` of `Element`
 /// (stanzas). Connection handling however is up to the user.
-pub struct Component {
+pub struct Component<C: ServerConnector> {
     /// The component's Jabber-Id
     pub jid: Jid,
-    stream: XMPPStream,
+    stream: XMPPStream<C::Stream>,
 }
 
-type XMPPStream = xmpp_stream::XMPPStream<TcpStream>;
-
-impl Component {
+impl<C: ServerConnector> Component<C> {
     /// Start a new XMPP component
-    pub async fn new(jid: &str, password: &str, server: &str, port: u16) -> Result<Self, Error> {
+    pub async fn new(jid: &str, password: &str, connector: C) -> Result<Self, Error> {
         let jid = Jid::from_str(jid)?;
         let password = password.to_owned();
-        let stream = Self::connect(jid.clone(), password, server, port).await?;
+        let stream = component_login(connector, jid.clone(), password).await?;
         Ok(Component { jid, stream })
     }
 
-    async fn connect(
-        jid: Jid,
-        password: String,
-        server: &str,
-        port: u16,
-    ) -> Result<XMPPStream, Error> {
-        let password = password;
-        let tcp_stream = connect_to_host(server, port).await?;
-        let mut xmpp_stream =
-            xmpp_stream::XMPPStream::start(tcp_stream, jid, ns::COMPONENT_ACCEPT.to_owned())
-                .await?;
-        auth::auth(&mut xmpp_stream, password).await?;
-        Ok(xmpp_stream)
-    }
-
     /// Send stanza
     pub async fn send_stanza(&mut self, stanza: Element) -> Result<(), Error> {
         self.send(add_stanza_id(stanza, ns::COMPONENT_ACCEPT)).await
@@ -63,7 +49,7 @@ impl Component {
     }
 }
 
-impl Stream for Component {
+impl<C: ServerConnector> Stream for Component<C> {
     type Item = Element;
 
     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
@@ -86,7 +72,7 @@ impl Stream for Component {
     }
 }
 
-impl Sink<Element> for Component {
+impl<C: ServerConnector> Sink<Element> for Component<C> {
     type Error = Error;
 
     fn start_send(mut self: Pin<&mut Self>, item: Element) -> Result<(), Self::Error> {

tokio-xmpp/src/connect.rs πŸ”—

@@ -0,0 +1,35 @@
+//! `ServerConnector` provides streams for XMPP clients
+
+use sasl::common::ChannelBinding;
+use tokio::io::{AsyncRead, AsyncWrite};
+use xmpp_parsers::Jid;
+
+use crate::xmpp_stream::XMPPStream;
+
+/// trait returned wrapped in XMPPStream by ServerConnector
+pub trait AsyncReadAndWrite: AsyncRead + AsyncWrite + Unpin + Send {}
+impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncReadAndWrite for T {}
+
+/// Trait that must be extended by the implementation of ServerConnector
+pub trait ServerConnectorError: std::error::Error + Send {}
+
+/// Trait called to connect to an XMPP server, perhaps called multiple times
+pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
+    /// The type of Stream this ServerConnector produces
+    type Stream: AsyncReadAndWrite;
+    /// Error type to return
+    type Error: ServerConnectorError;
+    /// This must return the connection ready to login, ie if starttls is involved, after TLS has been started, and then after the <stream headers are exchanged
+    fn connect(
+        &self,
+        jid: &Jid,
+        ns: &str,
+    ) -> impl std::future::Future<Output = Result<XMPPStream<Self::Stream>, Self::Error>> + Send;
+
+    /// Return channel binding data if available
+    /// do not fail if channel binding is simply unavailable, just return Ok(None)
+    /// this should only be called after the TLS handshake is finished
+    fn channel_binding(_stream: &Self::Stream) -> Result<ChannelBinding, Self::Error> {
+        Ok(ChannelBinding::None)
+    }
+}

tokio-xmpp/src/error.rs πŸ”—

@@ -1,41 +1,26 @@
-use hickory_resolver::{error::ResolveError, proto::error::ProtoError};
-#[cfg(feature = "tls-native")]
-use native_tls::Error as TlsError;
 use sasl::client::MechanismError as SaslMechanismError;
 use std::borrow::Cow;
 use std::error::Error as StdError;
 use std::fmt;
 use std::io::Error as IoError;
 use std::str::Utf8Error;
-#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
-use tokio_rustls::rustls::client::InvalidDnsNameError;
-#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
-use tokio_rustls::rustls::Error as TlsError;
 
 use xmpp_parsers::sasl::DefinedCondition as SaslDefinedCondition;
 use xmpp_parsers::{Error as ParsersError, JidParseError};
 
+use crate::connect::ServerConnectorError;
+
 /// Top-level error type
 #[derive(Debug)]
 pub enum Error {
     /// I/O error
     Io(IoError),
-    /// Error resolving DNS and establishing a connection
-    Connection(ConnecterError),
-    /// DNS label conversion error, no details available from module
-    /// `idna`
-    Idna,
     /// Error parsing Jabber-Id
     JidParse(JidParseError),
     /// Protocol-level error
     Protocol(ProtocolError),
     /// Authentication error
     Auth(AuthError),
-    /// TLS error
-    Tls(TlsError),
-    #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
-    /// DNS name parsing error
-    DnsNameError(InvalidDnsNameError),
     /// Connection closed
     Disconnected,
     /// Shoud never happen
@@ -44,6 +29,8 @@ pub enum Error {
     Fmt(fmt::Error),
     /// Utf8 error
     Utf8(Utf8Error),
+    /// Error resolving DNS and/or establishing a connection, returned by a ServerConnector impl
+    Connection(Box<dyn ServerConnectorError>),
 }
 
 impl fmt::Display for Error {
@@ -51,13 +38,9 @@ impl fmt::Display for Error {
         match self {
             Error::Io(e) => write!(fmt, "IO error: {}", e),
             Error::Connection(e) => write!(fmt, "connection error: {}", e),
-            Error::Idna => write!(fmt, "IDNA error"),
             Error::JidParse(e) => write!(fmt, "jid parse error: {}", e),
             Error::Protocol(e) => write!(fmt, "protocol error: {}", e),
             Error::Auth(e) => write!(fmt, "authentication error: {}", e),
-            Error::Tls(e) => write!(fmt, "TLS error: {}", e),
-            #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
-            Error::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
             Error::Disconnected => write!(fmt, "disconnected"),
             Error::InvalidState => write!(fmt, "invalid state"),
             Error::Fmt(e) => write!(fmt, "Fmt error: {}", e),
@@ -74,9 +57,9 @@ impl From<IoError> for Error {
     }
 }
 
-impl From<ConnecterError> for Error {
-    fn from(e: ConnecterError) -> Self {
-        Error::Connection(e)
+impl<T: ServerConnectorError + 'static> From<T> for Error {
+    fn from(e: T) -> Self {
+        Error::Connection(Box::new(e))
     }
 }
 
@@ -98,12 +81,6 @@ impl From<AuthError> for Error {
     }
 }
 
-impl From<TlsError> for Error {
-    fn from(e: TlsError) -> Self {
-        Error::Tls(e)
-    }
-}
-
 impl From<fmt::Error> for Error {
     fn from(e: fmt::Error) -> Self {
         Error::Fmt(e)
@@ -116,13 +93,6 @@ impl From<Utf8Error> for Error {
     }
 }
 
-#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
-impl From<InvalidDnsNameError> for Error {
-    fn from(e: InvalidDnsNameError) -> Self {
-        Error::DnsNameError(e)
-    }
-}
-
 /// XML parse error wrapper type
 #[derive(Debug)]
 pub struct ParseError(pub Cow<'static, str>);
@@ -227,22 +197,3 @@ impl fmt::Display for AuthError {
         }
     }
 }
-
-/// Error establishing connection
-#[derive(Debug)]
-pub enum ConnecterError {
-    /// All attempts failed, no error available
-    AllFailed,
-    /// DNS protocol error
-    Dns(ProtoError),
-    /// DNS resolution error
-    Resolve(ResolveError),
-}
-
-impl StdError for ConnecterError {}
-
-impl std::fmt::Display for ConnecterError {
-    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
-        write!(fmt, "{:?}", self)
-    }
-}

tokio-xmpp/src/lib.rs πŸ”—

@@ -5,31 +5,35 @@
 #[cfg(all(feature = "tls-native", feature = "tls-rust"))]
 compile_error!("Both tls-native and tls-rust features can't be enabled at the same time.");
 
-#[cfg(all(not(feature = "tls-native"), not(feature = "tls-rust")))]
-compile_error!("One of tls-native and tls-rust features must be enabled.");
+#[cfg(all(
+    feature = "starttls",
+    not(feature = "tls-native"),
+    not(feature = "tls-rust")
+))]
+compile_error!(
+    "when starttls feature enabled one of tls-native and tls-rust features must be enabled."
+);
 
-mod starttls;
+#[cfg(feature = "starttls")]
+pub mod starttls;
 mod stream_start;
 mod xmpp_codec;
 pub use crate::xmpp_codec::Packet;
 mod event;
 pub use event::Event;
 mod client;
-mod happy_eyeballs;
+pub mod connect;
 pub mod stream_features;
 pub mod xmpp_stream;
+
 pub use client::{
-    async_client::{
-        Client as AsyncClient, Config as AsyncConfig, ServerConfig as AsyncServerConfig,
-    },
-    connect::{client_login, AsyncReadAndWrite, ServerConnector},
+    async_client::{Client as AsyncClient, Config as AsyncConfig},
     simple_client::Client as SimpleClient,
 };
 mod component;
 pub use crate::component::Component;
 mod error;
-pub use crate::error::{AuthError, ConnecterError, Error, ParseError, ProtocolError};
-pub use starttls::starttls;
+pub use crate::error::{AuthError, Error, ParseError, ProtocolError};
 
 // Re-exports
 pub use minidom::Element;

tokio-xmpp/src/starttls.rs πŸ”—

@@ -1,85 +0,0 @@
-use futures::{sink::SinkExt, stream::StreamExt};
-
-#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
-use {
-    std::sync::Arc,
-    tokio_rustls::{
-        client::TlsStream,
-        rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName},
-        TlsConnector,
-    },
-    webpki_roots,
-};
-
-#[cfg(feature = "tls-native")]
-use {
-    native_tls::TlsConnector as NativeTlsConnector,
-    tokio_native_tls::{TlsConnector, TlsStream},
-};
-
-use tokio::io::{AsyncRead, AsyncWrite};
-use xmpp_parsers::{ns, Element};
-
-use crate::xmpp_codec::Packet;
-use crate::xmpp_stream::XMPPStream;
-use crate::{Error, ProtocolError};
-
-#[cfg(feature = "tls-native")]
-async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
-    xmpp_stream: XMPPStream<S>,
-) -> Result<TlsStream<S>, Error> {
-    let domain = xmpp_stream.jid.domain_str().to_owned();
-    let stream = xmpp_stream.into_inner();
-    let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
-        .connect(&domain, stream)
-        .await?;
-    Ok(tls_stream)
-}
-
-#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
-async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
-    xmpp_stream: XMPPStream<S>,
-) -> Result<TlsStream<S>, Error> {
-    let domain = xmpp_stream.jid.domain_str().to_owned();
-    let domain = ServerName::try_from(domain.as_str())?;
-    let stream = xmpp_stream.into_inner();
-    let mut root_store = RootCertStore::empty();
-    root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
-        OwnedTrustAnchor::from_subject_spki_name_constraints(
-            ta.subject,
-            ta.spki,
-            ta.name_constraints,
-        )
-    }));
-    let config = ClientConfig::builder()
-        .with_safe_defaults()
-        .with_root_certificates(root_store)
-        .with_no_client_auth();
-    let tls_stream = TlsConnector::from(Arc::new(config))
-        .connect(domain, stream)
-        .await?;
-    Ok(tls_stream)
-}
-
-/// Performs `<starttls/>` on an XMPPStream and returns a binary
-/// TlsStream.
-pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
-    mut xmpp_stream: XMPPStream<S>,
-) -> Result<TlsStream<S>, Error> {
-    let nonza = Element::builder("starttls", ns::TLS).build();
-    let packet = Packet::Stanza(nonza);
-    xmpp_stream.send(packet).await?;
-
-    loop {
-        match xmpp_stream.next().await {
-            Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break,
-            Some(Ok(Packet::Text(_))) => {}
-            Some(Err(e)) => return Err(e.into()),
-            _ => {
-                return Err(ProtocolError::NoTls.into());
-            }
-        }
-    }
-
-    get_tls_stream(xmpp_stream).await
-}

tokio-xmpp/src/starttls/client.rs πŸ”—

@@ -0,0 +1,35 @@
+use std::str::FromStr;
+
+use xmpp_parsers::Jid;
+
+use crate::{AsyncClient, AsyncConfig, Error, SimpleClient};
+
+use super::ServerConfig;
+
+impl AsyncClient<ServerConfig> {
+    /// Start a new XMPP client
+    ///
+    /// Start polling the returned instance so that it will connect
+    /// and yield events.
+    pub fn new<J: Into<Jid>, P: Into<String>>(jid: J, password: P) -> Self {
+        let config = AsyncConfig {
+            jid: jid.into(),
+            password: password.into(),
+            server: ServerConfig::UseSrv,
+        };
+        Self::new_with_config(config)
+    }
+}
+
+impl SimpleClient<ServerConfig> {
+    /// Start a new XMPP client and wait for a usable session
+    pub async fn new<P: Into<String>>(jid: &str, password: P) -> Result<Self, Error> {
+        let jid = Jid::from_str(jid)?;
+        Self::new_with_jid(jid, password.into()).await
+    }
+
+    /// Start a new client given that the JID is already parsed.
+    pub async fn new_with_jid(jid: Jid, password: String) -> Result<Self, Error> {
+        Self::new_with_jid_connector(ServerConfig::UseSrv, jid, password).await
+    }
+}

tokio-xmpp/src/starttls/error.rs πŸ”—

@@ -0,0 +1,105 @@
+use hickory_resolver::{error::ResolveError, proto::error::ProtoError};
+#[cfg(feature = "tls-native")]
+use native_tls::Error as TlsError;
+use std::borrow::Cow;
+use std::error::Error as StdError;
+use std::fmt;
+#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
+use tokio_rustls::rustls::client::InvalidDnsNameError;
+#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
+use tokio_rustls::rustls::Error as TlsError;
+
+/// Top-level error type
+#[derive(Debug)]
+pub enum Error {
+    /// Error resolving DNS and establishing a connection
+    Connection(ConnectorError),
+    /// DNS label conversion error, no details available from module
+    /// `idna`
+    Idna,
+    /// TLS error
+    Tls(TlsError),
+    #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
+    /// DNS name parsing error
+    DnsNameError(InvalidDnsNameError),
+    /// tokio-xmpp error
+    TokioXMPP(crate::error::Error),
+}
+
+impl fmt::Display for Error {
+    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Error::Connection(e) => write!(fmt, "connection error: {}", e),
+            Error::Idna => write!(fmt, "IDNA error"),
+            Error::Tls(e) => write!(fmt, "TLS error: {}", e),
+            #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
+            Error::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
+            Error::TokioXMPP(e) => write!(fmt, "TokioXMPP error: {}", e),
+        }
+    }
+}
+
+impl StdError for Error {}
+
+impl From<crate::error::Error> for Error {
+    fn from(e: crate::error::Error) -> Self {
+        Error::TokioXMPP(e)
+    }
+}
+
+impl From<ConnectorError> for Error {
+    fn from(e: ConnectorError) -> Self {
+        Error::Connection(e)
+    }
+}
+
+impl From<TlsError> for Error {
+    fn from(e: TlsError) -> Self {
+        Error::Tls(e)
+    }
+}
+
+#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
+impl From<InvalidDnsNameError> for Error {
+    fn from(e: InvalidDnsNameError) -> Self {
+        Error::DnsNameError(e)
+    }
+}
+
+/// XML parse error wrapper type
+#[derive(Debug)]
+pub struct ParseError(pub Cow<'static, str>);
+
+impl StdError for ParseError {
+    fn description(&self) -> &str {
+        self.0.as_ref()
+    }
+    fn cause(&self) -> Option<&dyn StdError> {
+        None
+    }
+}
+
+impl fmt::Display for ParseError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+/// Error establishing connection
+#[derive(Debug)]
+pub enum ConnectorError {
+    /// All attempts failed, no error available
+    AllFailed,
+    /// DNS protocol error
+    Dns(ProtoError),
+    /// DNS resolution error
+    Resolve(ResolveError),
+}
+
+impl StdError for ConnectorError {}
+
+impl std::fmt::Display for ConnectorError {
+    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+        write!(fmt, "{:?}", self)
+    }
+}

tokio-xmpp/src/happy_eyeballs.rs β†’ tokio-xmpp/src/starttls/happy_eyeballs.rs πŸ”—

@@ -1,4 +1,4 @@
-use crate::{ConnecterError, Error};
+use super::error::{ConnectorError, Error};
 use hickory_resolver::{IntoName, TokioAsyncResolver};
 use idna;
 use log::debug;
@@ -9,22 +9,24 @@ pub async fn connect_to_host(domain: &str, port: u16) -> Result<TcpStream, Error
     let ascii_domain = idna::domain_to_ascii(&domain).map_err(|_| Error::Idna)?;
 
     if let Ok(ip) = ascii_domain.parse() {
-        return Ok(TcpStream::connect(&SocketAddr::new(ip, port)).await?);
+        return Ok(TcpStream::connect(&SocketAddr::new(ip, port))
+            .await
+            .map_err(|e| Error::from(crate::Error::Io(e)))?);
     }
 
-    let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(ConnecterError::Resolve)?;
+    let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(ConnectorError::Resolve)?;
 
     let ips = resolver
         .lookup_ip(ascii_domain)
         .await
-        .map_err(ConnecterError::Resolve)?;
+        .map_err(ConnectorError::Resolve)?;
     for ip in ips.iter() {
         match TcpStream::connect(&SocketAddr::new(ip, port)).await {
             Ok(stream) => return Ok(stream),
             Err(_) => {}
         }
     }
-    Err(Error::Disconnected)
+    Err(crate::Error::Disconnected.into())
 }
 
 pub async fn connect_with_srv(
@@ -36,14 +38,16 @@ pub async fn connect_with_srv(
 
     if let Ok(ip) = ascii_domain.parse() {
         debug!("Attempting connection to {ip}:{fallback_port}");
-        return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port)).await?);
+        return Ok(TcpStream::connect(&SocketAddr::new(ip, fallback_port))
+            .await
+            .map_err(|e| Error::from(crate::Error::Io(e)))?);
     }
 
-    let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(ConnecterError::Resolve)?;
+    let resolver = TokioAsyncResolver::tokio_from_system_conf().map_err(ConnectorError::Resolve)?;
 
     let srv_domain = format!("{}.{}.", srv, ascii_domain)
         .into_name()
-        .map_err(ConnecterError::Dns)?;
+        .map_err(ConnectorError::Dns)?;
     let srv_records = resolver.srv_lookup(srv_domain.clone()).await.ok();
 
     match srv_records {
@@ -56,7 +60,7 @@ pub async fn connect_with_srv(
                     Err(_) => {}
                 }
             }
-            Err(Error::Disconnected)
+            Err(crate::Error::Disconnected.into())
         }
         None => {
             // SRV lookup error, retry with hostname

tokio-xmpp/src/starttls/mod.rs πŸ”—

@@ -0,0 +1,168 @@
+//! `starttls::ServerConfig` provides a `ServerConnector` for starttls connections
+
+use futures::{sink::SinkExt, stream::StreamExt};
+
+#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
+use {
+    std::sync::Arc,
+    tokio_rustls::{
+        client::TlsStream,
+        rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName},
+        TlsConnector,
+    },
+    webpki_roots,
+};
+
+#[cfg(feature = "tls-native")]
+use {
+    native_tls::TlsConnector as NativeTlsConnector,
+    tokio_native_tls::{TlsConnector, TlsStream},
+};
+
+use sasl::common::ChannelBinding;
+use tokio::{
+    io::{AsyncRead, AsyncWrite},
+    net::TcpStream,
+};
+use xmpp_parsers::{ns, Element, Jid};
+
+use crate::{connect::ServerConnector, xmpp_codec::Packet};
+use crate::{connect::ServerConnectorError, xmpp_stream::XMPPStream};
+
+use self::error::Error;
+use self::happy_eyeballs::{connect_to_host, connect_with_srv};
+
+mod client;
+mod error;
+mod happy_eyeballs;
+
+/// StartTLS XMPP server connection configuration
+#[derive(Clone, Debug)]
+pub enum ServerConfig {
+    /// Use SRV record to find server host
+    UseSrv,
+    #[allow(unused)]
+    /// Manually define server host and port
+    Manual {
+        /// Server host name
+        host: String,
+        /// Server port
+        port: u16,
+    },
+}
+
+impl ServerConnectorError for Error {}
+
+impl ServerConnector for ServerConfig {
+    type Stream = TlsStream<TcpStream>;
+    type Error = Error;
+    async fn connect(&self, jid: &Jid, ns: &str) -> Result<XMPPStream<Self::Stream>, Error> {
+        // TCP connection
+        let tcp_stream = match self {
+            ServerConfig::UseSrv => {
+                connect_with_srv(jid.domain_str(), "_xmpp-client._tcp", 5222).await?
+            }
+            ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?,
+        };
+
+        // Unencryped XMPPStream
+        let xmpp_stream = XMPPStream::start(tcp_stream, jid.clone(), ns.to_owned()).await?;
+
+        if xmpp_stream.stream_features.can_starttls() {
+            // TlsStream
+            let tls_stream = starttls(xmpp_stream).await?;
+            // Encrypted XMPPStream
+            Ok(XMPPStream::start(tls_stream, jid.clone(), ns.to_owned()).await?)
+        } else {
+            return Err(crate::Error::Protocol(crate::ProtocolError::NoTls).into());
+        }
+    }
+
+    fn channel_binding(
+        #[allow(unused_variables)] stream: &Self::Stream,
+    ) -> Result<sasl::common::ChannelBinding, Error> {
+        #[cfg(feature = "tls-native")]
+        {
+            log::warn!("tls-native doesn’t support channel binding, please use tls-rust if you want this feature!");
+            Ok(ChannelBinding::None)
+        }
+        #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
+        {
+            let (_, connection) = stream.get_ref();
+            Ok(match connection.protocol_version() {
+                // TODO: Add support for TLS 1.2 and earlier.
+                Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
+                    let data = vec![0u8; 32];
+                    let data = connection.export_keying_material(
+                        data,
+                        b"EXPORTER-Channel-Binding",
+                        None,
+                    )?;
+                    ChannelBinding::TlsExporter(data)
+                }
+                _ => ChannelBinding::None,
+            })
+        }
+    }
+}
+
+#[cfg(feature = "tls-native")]
+async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
+    xmpp_stream: XMPPStream<S>,
+) -> Result<TlsStream<S>, Error> {
+    let domain = xmpp_stream.jid.domain_str().to_owned();
+    let stream = xmpp_stream.into_inner();
+    let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
+        .connect(&domain, stream)
+        .await?;
+    Ok(tls_stream)
+}
+
+#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
+async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
+    xmpp_stream: XMPPStream<S>,
+) -> Result<TlsStream<S>, Error> {
+    let domain = xmpp_stream.jid.domain_str().to_owned();
+    let domain = ServerName::try_from(domain.as_str())?;
+    let stream = xmpp_stream.into_inner();
+    let mut root_store = RootCertStore::empty();
+    root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
+        OwnedTrustAnchor::from_subject_spki_name_constraints(
+            ta.subject,
+            ta.spki,
+            ta.name_constraints,
+        )
+    }));
+    let config = ClientConfig::builder()
+        .with_safe_defaults()
+        .with_root_certificates(root_store)
+        .with_no_client_auth();
+    let tls_stream = TlsConnector::from(Arc::new(config))
+        .connect(domain, stream)
+        .await
+        .map_err(|e| Error::from(crate::Error::Io(e)))?;
+    Ok(tls_stream)
+}
+
+/// Performs `<starttls/>` on an XMPPStream and returns a binary
+/// TlsStream.
+pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
+    mut xmpp_stream: XMPPStream<S>,
+) -> Result<TlsStream<S>, Error> {
+    let nonza = Element::builder("starttls", ns::TLS).build();
+    let packet = Packet::Stanza(nonza);
+    xmpp_stream.send(packet).await?;
+
+    loop {
+        match xmpp_stream.next().await {
+            Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break,
+            Some(Ok(Packet::Text(_))) => {}
+            Some(Err(e)) => return Err(e.into()),
+            _ => {
+                return Err(crate::Error::Protocol(crate::ProtocolError::NoTls).into());
+            }
+        }
+    }
+
+    get_tls_stream(xmpp_stream).await
+}

xmpp/Cargo.toml πŸ”—

@@ -31,7 +31,7 @@ name = "hello_bot"
 required-features = ["avatars"]
 
 [features]
-default = ["avatars", "tls-native"]
-tls-native = ["tokio-xmpp/tls-native"]
-tls-rust = ["tokio-xmpp/tls-rust"]
+default = ["avatars", "starttls-rust"]
+starttls-native = ["tokio-xmpp/starttls", "tokio-xmpp/tls-native"]
+starttls-rust = ["tokio-xmpp/starttls", "tokio-xmpp/tls-rust"]
 avatars = []

xmpp/src/lib.rs πŸ”—

@@ -7,7 +7,7 @@
 #![deny(bare_trait_objects)]
 
 pub use tokio_xmpp::parsers;
-use tokio_xmpp::{AsyncClient, AsyncServerConfig};
+use tokio_xmpp::AsyncClient;
 pub use tokio_xmpp::{BareJid, Element, FullJid, Jid};
 #[macro_use]
 extern crate log;
@@ -32,7 +32,7 @@ pub use builder::{ClientBuilder, ClientType};
 pub use event::Event;
 pub use feature::ClientFeature;
 
-type TokioXmppClient = AsyncClient<AsyncServerConfig>;
+type TokioXmppClient = AsyncClient<tokio_xmpp::starttls::ServerConfig>;
 
 pub type Error = tokio_xmpp::Error;
 pub type Id = Option<String>;