tokio-xmpp client: condense fn connect(), refactor out into stream_features

Astro created

Change summary

tokio-xmpp/src/client/auth.rs     |  8 ---
tokio-xmpp/src/client/bind.rs     | 60 ++++++++++++++-----------------
tokio-xmpp/src/client/mod.rs      | 62 +++++++++++---------------------
tokio-xmpp/src/lib.rs             |  1 
tokio-xmpp/src/stream_features.rs | 45 +++++++++++++++++++++++
tokio-xmpp/src/xmpp_stream.rs     |  5 +-
6 files changed, 100 insertions(+), 81 deletions(-)

Detailed changes

tokio-xmpp/src/client/auth.rs 🔗

@@ -13,8 +13,6 @@ use crate::xmpp_codec::Packet;
 use crate::xmpp_stream::XMPPStream;
 use crate::{AuthError, Error, ProtocolError};
 
-const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
-
 pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
     mut stream: XMPPStream<S>,
     creds: Credentials,
@@ -28,11 +26,7 @@ pub async fn auth<S: AsyncRead + AsyncWrite + Unpin>(
 
     let remote_mechs: HashSet<String> = stream
         .stream_features
-        .get_child("mechanisms", NS_XMPP_SASL)
-        .ok_or(AuthError::NoMechanism)?
-        .children()
-        .filter(|child| child.is("mechanism", NS_XMPP_SASL))
-        .map(|mech_el| mech_el.text())
+        .sasl_mechanisms()?
         .collect();
 
     for local_mech in local_mechs {

tokio-xmpp/src/client/bind.rs 🔗

@@ -10,46 +10,42 @@ use crate::xmpp_codec::Packet;
 use crate::xmpp_stream::XMPPStream;
 use crate::{Error, ProtocolError};
 
-const NS_XMPP_BIND: &str = "urn:ietf:params:xml:ns:xmpp-bind";
 const BIND_REQ_ID: &str = "resource-bind";
 
 pub async fn bind<S: AsyncRead + AsyncWrite + Unpin>(
     mut stream: XMPPStream<S>,
 ) -> Result<XMPPStream<S>, Error> {
-    match stream.stream_features.get_child("bind", NS_XMPP_BIND) {
-        None => {
-            // No resource binding available,
-            // return the (probably // usable) stream immediately
-            return Ok(stream);
-        }
-        Some(_) => {
-            let resource = if let Jid::Full(jid) = stream.jid.clone() {
-                Some(jid.resource)
-            } else {
-                None
-            };
-            let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource));
-            stream.send_stanza(iq).await?;
+    if stream.stream_features.can_bind() {
+        let resource = if let Jid::Full(jid) = stream.jid.clone() {
+            Some(jid.resource)
+        } else {
+            None
+        };
+        let iq = Iq::from_set(BIND_REQ_ID, BindQuery::new(resource));
+        stream.send_stanza(iq).await?;
 
-            loop {
-                match stream.next().await {
-                    Some(Ok(Packet::Stanza(stanza))) => match Iq::try_from(stanza) {
-                        Ok(iq) if iq.id == BIND_REQ_ID => match iq.payload {
-                            IqType::Result(payload) => {
-                                payload
-                                    .and_then(|payload| BindResponse::try_from(payload).ok())
-                                    .map(|bind| stream.jid = bind.into());
-                                return Ok(stream);
-                            }
-                            _ => return Err(ProtocolError::InvalidBindResponse.into()),
-                        },
-                        _ => {}
+        loop {
+            match stream.next().await {
+                Some(Ok(Packet::Stanza(stanza))) => match Iq::try_from(stanza) {
+                    Ok(iq) if iq.id == BIND_REQ_ID => match iq.payload {
+                        IqType::Result(payload) => {
+                            payload
+                                .and_then(|payload| BindResponse::try_from(payload).ok())
+                                .map(|bind| stream.jid = bind.into());
+                            return Ok(stream);
+                        }
+                        _ => return Err(ProtocolError::InvalidBindResponse.into()),
                     },
-                    Some(Ok(_)) => {}
-                    Some(Err(e)) => return Err(e),
-                    None => return Err(Error::Disconnected),
-                }
+                    _ => {}
+                },
+                Some(Ok(_)) => {}
+                Some(Err(e)) => return Err(e),
+                None => return Err(Error::Disconnected),
             }
         }
+    } else {
+        // No resource binding available,
+        // return the (probably // usable) stream immediately
+        return Ok(stream);
     }
 }

tokio-xmpp/src/client/mod.rs 🔗

@@ -5,7 +5,6 @@ use std::mem::replace;
 use std::pin::Pin;
 use std::str::FromStr;
 use std::task::Context;
-use tokio::io::{AsyncRead, AsyncWrite};
 use tokio::net::TcpStream;
 use tokio::task::JoinHandle;
 use tokio::task::LocalSet;
@@ -14,13 +13,18 @@ use xmpp_parsers::{Element, Jid, JidParseError};
 
 use super::event::Event;
 use super::happy_eyeballs::connect;
-use super::starttls::{starttls, NS_XMPP_TLS};
+use super::starttls::starttls;
 use super::xmpp_codec::Packet;
 use super::xmpp_stream;
 use super::{Error, ProtocolError};
 
 mod auth;
+use auth::auth;
 mod bind;
+use bind::bind;
+
+pub const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
+pub const NS_XMPP_BIND: &str = "urn:ietf:params:xml:ns:xmpp-bind";
 
 /// XMPP client connection and state
 ///
@@ -79,56 +83,34 @@ impl Client {
         let password = password;
         let domain = idna::domain_to_ascii(&jid.clone().domain()).map_err(|_| Error::Idna)?;
 
+        // TCP connection
         let tcp_stream = connect(&domain, Some("_xmpp-client._tcp"), 5222).await?;
 
+        // Unencryped XMPPStream
         let xmpp_stream =
-            xmpp_stream::XMPPStream::start(tcp_stream, jid, NS_JABBER_CLIENT.to_owned()).await?;
-        let xmpp_stream = if Self::can_starttls(&xmpp_stream) {
-            Self::starttls(xmpp_stream).await?
+            xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), NS_JABBER_CLIENT.to_owned()).await?;
+
+        let xmpp_stream = if xmpp_stream.stream_features.can_starttls() {
+            // TlsStream
+            let tls_stream = starttls(xmpp_stream).await?;
+            // Encrypted XMPPStream
+            xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), NS_JABBER_CLIENT.to_owned()).await?
         } else {
             return Err(Error::Protocol(ProtocolError::NoTls));
         };
 
-        let xmpp_stream = Self::auth(xmpp_stream, username, password).await?;
-        let xmpp_stream = Self::bind(xmpp_stream).await?;
-        Ok(xmpp_stream)
-    }
-
-    fn can_starttls<S: AsyncRead + AsyncWrite + Unpin>(
-        xmpp_stream: &xmpp_stream::XMPPStream<S>,
-    ) -> bool {
-        xmpp_stream
-            .stream_features
-            .get_child("starttls", NS_XMPP_TLS)
-            .is_some()
-    }
-
-    async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
-        xmpp_stream: xmpp_stream::XMPPStream<S>,
-    ) -> Result<xmpp_stream::XMPPStream<TlsStream<S>>, Error> {
-        let jid = xmpp_stream.jid.clone();
-        let tls_stream = starttls(xmpp_stream).await?;
-        xmpp_stream::XMPPStream::start(tls_stream, jid, NS_JABBER_CLIENT.to_owned()).await
-    }
-
-    async fn auth<S: AsyncRead + AsyncWrite + Unpin + 'static>(
-        xmpp_stream: xmpp_stream::XMPPStream<S>,
-        username: String,
-        password: String,
-    ) -> Result<xmpp_stream::XMPPStream<S>, Error> {
-        let jid = xmpp_stream.jid.clone();
         let creds = Credentials::default()
             .with_username(username)
             .with_password(password)
             .with_channel_binding(ChannelBinding::None);
-        let stream = auth::auth(xmpp_stream, creds).await?;
-        xmpp_stream::XMPPStream::start(stream, jid, NS_JABBER_CLIENT.to_owned()).await
-    }
+        // Authenticated (unspecified) stream
+        let stream = auth(xmpp_stream, creds).await?;
+        // Authenticated XMPPStream
+        let xmpp_stream = xmpp_stream::XMPPStream::start(stream, jid, NS_JABBER_CLIENT.to_owned()).await?;
 
-    async fn bind<S: Unpin + AsyncRead + AsyncWrite>(
-        stream: xmpp_stream::XMPPStream<S>,
-    ) -> Result<xmpp_stream::XMPPStream<S>, Error> {
-        bind::bind(stream).await
+        // XMPPStream bound to user session
+        let xmpp_stream = bind(xmpp_stream).await?;
+        Ok(xmpp_stream)
     }
 
     /// Get the client's bound JID (the one reported by the XMPP

tokio-xmpp/src/lib.rs 🔗

@@ -9,6 +9,7 @@ pub use crate::xmpp_codec::Packet;
 mod event;
 mod happy_eyeballs;
 pub mod xmpp_stream;
+pub mod stream_features;
 pub use crate::event::Event;
 mod client;
 pub use crate::client::Client;

tokio-xmpp/src/stream_features.rs 🔗

@@ -0,0 +1,45 @@
+//! Contains wrapper for `<stream:features/>`
+
+use xmpp_parsers::Element;
+use crate::starttls::NS_XMPP_TLS;
+use crate::client::{NS_XMPP_SASL, NS_XMPP_BIND};
+use crate::error::AuthError;
+
+/// Wraps `<stream:features/>`, usually the very first nonza of an
+/// XMPPStream.
+///
+/// TODO: should this rather go into xmpp-parsers, kept in a decoded
+/// struct?
+pub struct StreamFeatures(pub Element);
+
+impl StreamFeatures {
+    /// Wrap the nonza
+    pub fn new(element: Element) -> Self {
+        StreamFeatures(element)
+    }
+
+    /// Can initiate TLS session with this server?
+    pub fn can_starttls(&self) -> bool {
+        self.0
+            .get_child("starttls", NS_XMPP_TLS)
+            .is_some()
+    }
+
+    /// Iterate over SASL mechanisms
+    pub fn sasl_mechanisms<'a>(&'a self) -> Result<impl Iterator<Item=String> + 'a, AuthError> {
+        Ok(self.0
+           .get_child("mechanisms", NS_XMPP_SASL)
+           .ok_or(AuthError::NoMechanism)?
+           .children()
+           .filter(|child| child.is("mechanism", NS_XMPP_SASL))
+           .map(|mech_el| mech_el.text())
+        )
+    }
+
+    /// Does server support user resource binding?
+    pub fn can_bind(&self) -> bool {
+        self.0
+            .get_child("bind", NS_XMPP_BIND)
+            .is_some()
+    }
+}

tokio-xmpp/src/xmpp_stream.rs 🔗

@@ -11,6 +11,7 @@ use tokio_util::codec::Framed;
 use xmpp_parsers::{Element, Jid};
 
 use crate::stream_start;
+use crate::stream_features::StreamFeatures;
 use crate::xmpp_codec::{Packet, XMPPCodec};
 use crate::Error;
 
@@ -27,7 +28,7 @@ pub struct XMPPStream<S: AsyncRead + AsyncWrite + Unpin> {
     /// Codec instance
     pub stream: Mutex<Framed<S, XMPPCodec>>,
     /// `<stream:features/>` for XMPP version 1.0
-    pub stream_features: Element,
+    pub stream_features: StreamFeatures,
     /// Root namespace
     ///
     /// This is different for either c2s, s2s, or component
@@ -49,7 +50,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> XMPPStream<S> {
         XMPPStream {
             jid,
             stream: Mutex::new(stream),
-            stream_features,
+            stream_features: StreamFeatures::new(stream_features),
             ns,
             id,
         }