ServerConnector and AsyncClient support channel binding, SimpleClient uses ServerConnector

moparisthebest created

Change summary

tokio-xmpp/src/client/async_client.rs  |  78 ++++++++-----------
tokio-xmpp/src/client/connect.rs       |  56 ++++++++++++++
tokio-xmpp/src/client/mod.rs           |   1 
tokio-xmpp/src/client/simple_client.rs | 110 +++++----------------------
tokio-xmpp/src/lib.rs                  |   4 
5 files changed, 112 insertions(+), 137 deletions(-)

Detailed changes

tokio-xmpp/src/client/async_client.rs 🔗

@@ -1,5 +1,5 @@
 use futures::{sink::SinkExt, task::Poll, Future, Sink, Stream};
-use sasl::common::{ChannelBinding, Credentials};
+use sasl::common::ChannelBinding;
 use std::mem::replace;
 use std::pin::Pin;
 use std::task::Context;
@@ -7,14 +7,13 @@ use tokio::net::TcpStream;
 use tokio::task::JoinHandle;
 use xmpp_parsers::{ns, Element, Jid};
 
-use super::auth::auth;
-use super::bind::bind;
+use super::connect::{AsyncReadAndWrite, ServerConnector};
 use crate::event::Event;
 use crate::happy_eyeballs::{connect_to_host, connect_with_srv};
 use crate::starttls::starttls;
 use crate::xmpp_codec::Packet;
 use crate::xmpp_stream::{self, add_stanza_id, XMPPStream};
-use crate::{Error, ProtocolError};
+use crate::{client_login, Error, ProtocolError};
 #[cfg(feature = "tls-native")]
 use tokio_native_tls::TlsStream;
 #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
@@ -44,17 +43,6 @@ pub struct Config<C> {
     pub server: C,
 }
 
-/// Trait called to connect to an XMPP server, perhaps called multiple times
-pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
-    /// The type of Stream this ServerConnector produces
-    type Stream: AsyncReadAndWrite;
-    /// This must return the connection ready to login, ie if starttls is involved, after TLS has been started, and then after the <stream headers are exchanged
-    fn connect(
-        &self,
-        jid: &Jid,
-    ) -> impl std::future::Future<Output = Result<XMPPStream<Self::Stream>, Error>> + Send;
-}
-
 /// XMPP server connection configuration
 #[derive(Clone, Debug)]
 pub enum ServerConfig {
@@ -96,11 +84,34 @@ impl ServerConnector for ServerConfig {
             return Err(Error::Protocol(ProtocolError::NoTls));
         }
     }
-}
 
-/// trait used by XMPPStream type
-pub trait AsyncReadAndWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send {}
-impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send> AsyncReadAndWrite for T {}
+    fn channel_binding(
+        #[allow(unused_variables)] stream: &Self::Stream,
+    ) -> Result<sasl::common::ChannelBinding, Error> {
+        #[cfg(feature = "tls-native")]
+        {
+            log::warn!("tls-native doesn’t support channel binding, please use tls-rust if you want this feature!");
+            Ok(ChannelBinding::None)
+        }
+        #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
+        {
+            let (_, connection) = stream.get_ref();
+            Ok(match connection.protocol_version() {
+                // TODO: Add support for TLS 1.2 and earlier.
+                Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
+                    let data = vec![0u8; 32];
+                    let data = connection.export_keying_material(
+                        data,
+                        b"EXPORTER-Channel-Binding",
+                        None,
+                    )?;
+                    ChannelBinding::TlsExporter(data)
+                }
+                _ => ChannelBinding::None,
+            })
+        }
+    }
+}
 
 enum ClientState<S: AsyncReadAndWrite> {
     Invalid,
@@ -127,7 +138,7 @@ impl Client<ServerConfig> {
 impl<C: ServerConnector> Client<C> {
     /// Start a new client given that the JID is already parsed.
     pub fn new_with_config(config: Config<C>) -> Self {
-        let connect = tokio::spawn(Self::connect(
+        let connect = tokio::spawn(client_login(
             config.server.clone(),
             config.jid.clone(),
             config.password.clone(),
@@ -147,31 +158,6 @@ impl<C: ServerConnector> Client<C> {
         self
     }
 
-    async fn connect(
-        server: C,
-        jid: Jid,
-        password: String,
-    ) -> Result<XMPPStream<C::Stream>, Error> {
-        let username = jid.node_str().unwrap();
-        let password = password;
-
-        let xmpp_stream = server.connect(&jid).await?;
-
-        let creds = Credentials::default()
-            .with_username(username)
-            .with_password(password)
-            .with_channel_binding(ChannelBinding::None);
-        // Authenticated (unspecified) stream
-        let stream = auth(xmpp_stream, creds).await?;
-        // Authenticated XMPPStream
-        let xmpp_stream =
-            xmpp_stream::XMPPStream::start(stream, jid, ns::JABBER_CLIENT.to_owned()).await?;
-
-        // XMPPStream bound to user session
-        let xmpp_stream = bind(xmpp_stream).await?;
-        Ok(xmpp_stream)
-    }
-
     /// Get the client's bound JID (the one reported by the XMPP
     /// server).
     pub fn bound_jid(&self) -> Option<&Jid> {
@@ -222,7 +208,7 @@ impl<C: ServerConnector> Stream for Client<C> {
             ClientState::Invalid => panic!("Invalid client state"),
             ClientState::Disconnected if self.reconnect => {
                 // TODO: add timeout
-                let connect = tokio::spawn(Self::connect(
+                let connect = tokio::spawn(client_login(
                     self.config.server.clone(),
                     self.config.jid.clone(),
                     self.config.password.clone(),

tokio-xmpp/src/client/connect.rs 🔗

@@ -0,0 +1,56 @@
+use sasl::common::{ChannelBinding, Credentials};
+use tokio::io::{AsyncRead, AsyncWrite};
+use xmpp_parsers::{ns, Jid};
+
+use super::{auth::auth, bind::bind};
+use crate::{xmpp_stream::XMPPStream, Error};
+
+/// trait returned wrapped in XMPPStream by ServerConnector
+pub trait AsyncReadAndWrite: AsyncRead + AsyncWrite + Unpin + Send {}
+impl<T: AsyncRead + AsyncWrite + Unpin + Send> AsyncReadAndWrite for T {}
+
+/// Trait called to connect to an XMPP server, perhaps called multiple times
+pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
+    /// The type of Stream this ServerConnector produces
+    type Stream: AsyncReadAndWrite;
+    /// This must return the connection ready to login, ie if starttls is involved, after TLS has been started, and then after the <stream headers are exchanged
+    fn connect(
+        &self,
+        jid: &Jid,
+    ) -> impl std::future::Future<Output = Result<XMPPStream<Self::Stream>, Error>> + Send;
+
+    /// Return channel binding data if available
+    /// do not fail if channel binding is simply unavailable, just return Ok(None)
+    /// this should only be called after the TLS handshake is finished
+    fn channel_binding(_stream: &Self::Stream) -> Result<ChannelBinding, Error> {
+        Ok(ChannelBinding::None)
+    }
+}
+
+/// Log into an XMPP server as a client with a jid+pass
+/// does channel binding if supported
+pub async fn client_login<C: ServerConnector>(
+    server: C,
+    jid: Jid,
+    password: String,
+) -> Result<XMPPStream<C::Stream>, Error> {
+    let username = jid.node_str().unwrap();
+    let password = password;
+
+    let xmpp_stream = server.connect(&jid).await?;
+
+    let channel_binding = C::channel_binding(xmpp_stream.stream.get_ref())?;
+
+    let creds = Credentials::default()
+        .with_username(username)
+        .with_password(password)
+        .with_channel_binding(channel_binding);
+    // Authenticated (unspecified) stream
+    let stream = auth(xmpp_stream, creds).await?;
+    // Authenticated XMPPStream
+    let xmpp_stream = XMPPStream::start(stream, jid, ns::JABBER_CLIENT.to_owned()).await?;
+
+    // XMPPStream bound to user session
+    let xmpp_stream = bind(xmpp_stream).await?;
+    Ok(xmpp_stream)
+}

tokio-xmpp/src/client/simple_client.rs 🔗

@@ -1,119 +1,51 @@
 use futures::{sink::SinkExt, Sink, Stream};
-use idna;
-#[cfg(feature = "tls-native")]
-use log::warn;
-use sasl::common::{ChannelBinding, Credentials};
 use std::pin::Pin;
 use std::str::FromStr;
 use std::task::{Context, Poll};
-use tokio::net::TcpStream;
-#[cfg(feature = "tls-native")]
-use tokio_native_tls::TlsStream;
-#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
-use tokio_rustls::{client::TlsStream, rustls::ProtocolVersion};
 use tokio_stream::StreamExt;
 use xmpp_parsers::{ns, Element, Jid};
 
-use super::auth::auth;
-use super::bind::bind;
-use crate::happy_eyeballs::connect_with_srv;
-use crate::starttls::starttls;
 use crate::xmpp_codec::Packet;
-use crate::xmpp_stream::{self, add_stanza_id};
-use crate::{Error, ProtocolError};
+use crate::xmpp_stream::{add_stanza_id, XMPPStream};
+use crate::{client_login, AsyncServerConfig, Error, ServerConnector};
 
 /// A simple XMPP client connection
 ///
 /// This implements the `futures` crate's [`Stream`](#impl-Stream) and
 /// [`Sink`](#impl-Sink<Packet>) traits.
-pub struct Client {
-    stream: XMPPStream,
+pub struct Client<C: ServerConnector> {
+    stream: XMPPStream<C::Stream>,
 }
 
-type XMPPStream = xmpp_stream::XMPPStream<TlsStream<TcpStream>>;
-
-impl Client {
+impl Client<AsyncServerConfig> {
     /// Start a new XMPP client and wait for a usable session
     pub async fn new<P: Into<String>>(jid: &str, password: P) -> Result<Self, Error> {
         let jid = Jid::from_str(jid)?;
-        let client = Self::new_with_jid(jid, password.into()).await?;
-        Ok(client)
+        Self::new_with_jid(jid, password.into()).await
     }
 
     /// Start a new client given that the JID is already parsed.
     pub async fn new_with_jid(jid: Jid, password: String) -> Result<Self, Error> {
-        let stream = Self::connect(jid, password).await?;
+        Self::new_with_jid_connector(AsyncServerConfig::UseSrv, jid, password).await
+    }
+}
+
+impl<C: ServerConnector> Client<C> {
+    /// Start a new client given that the JID is already parsed.
+    pub async fn new_with_jid_connector(
+        connector: C,
+        jid: Jid,
+        password: String,
+    ) -> Result<Self, Error> {
+        let stream = client_login(connector, jid, password).await?;
         Ok(Client { stream })
     }
 
     /// Get direct access to inner XMPP Stream
-    pub fn into_inner(self) -> XMPPStream {
+    pub fn into_inner(self) -> XMPPStream<C::Stream> {
         self.stream
     }
 
-    async fn connect(jid: Jid, password: String) -> Result<XMPPStream, Error> {
-        let username = jid.node_str().unwrap();
-        let password = password;
-        let domain = idna::domain_to_ascii(jid.domain_str()).map_err(|_| Error::Idna)?;
-
-        // TCP connection
-        let tcp_stream = connect_with_srv(&domain, "_xmpp-client._tcp", 5222).await?;
-
-        // Unencryped XMPPStream
-        let xmpp_stream =
-            xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
-                .await?;
-
-        let channel_binding;
-        let xmpp_stream = if xmpp_stream.stream_features.can_starttls() {
-            // TlsStream
-            let tls_stream = starttls(xmpp_stream).await?;
-            #[cfg(feature = "tls-native")]
-            {
-                warn!("tls-native doesn’t support channel binding, please use tls-rust if you want this feature!");
-                channel_binding = ChannelBinding::None;
-            }
-            #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
-            {
-                let (_, connection) = tls_stream.get_ref();
-                match connection.protocol_version() {
-                    // TODO: Add support for TLS 1.2 and earlier.
-                    Some(ProtocolVersion::TLSv1_3) => {
-                        let data = vec![0u8; 32];
-                        let data = connection.export_keying_material(
-                            data,
-                            b"EXPORTER-Channel-Binding",
-                            None,
-                        )?;
-                        channel_binding = ChannelBinding::TlsExporter(data);
-                    }
-                    _ => {
-                        channel_binding = ChannelBinding::None;
-                    }
-                }
-            }
-            // Encrypted XMPPStream
-            xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
-                .await?
-        } else {
-            return Err(Error::Protocol(ProtocolError::NoTls));
-        };
-
-        let creds = Credentials::default()
-            .with_username(username)
-            .with_password(password)
-            .with_channel_binding(channel_binding);
-        // Authenticated (unspecified) stream
-        let stream = auth(xmpp_stream, creds).await?;
-        // Authenticated XMPPStream
-        let xmpp_stream =
-            xmpp_stream::XMPPStream::start(stream, jid, ns::JABBER_CLIENT.to_owned()).await?;
-
-        // XMPPStream bound to user session
-        let xmpp_stream = bind(xmpp_stream).await?;
-        Ok(xmpp_stream)
-    }
-
     /// Get the client's bound JID (the one reported by the XMPP
     /// server).
     pub fn bound_jid(&self) -> &Jid {
@@ -150,7 +82,7 @@ impl Client {
 ///
 /// In an `async fn` you may want to use this with `use
 /// futures::stream::StreamExt;`
-impl Stream for Client {
+impl<C: ServerConnector> Stream for Client<C> {
     type Item = Result<Element, Error>;
 
     /// Low-level read on the XMPP stream
@@ -177,7 +109,7 @@ impl Stream for Client {
 /// Outgoing XMPP packets
 ///
 /// See `send_stanza()` for an `async fn`
-impl Sink<Packet> for Client {
+impl<C: ServerConnector> Sink<Packet> for Client<C> {
     type Error = Error;
 
     fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {

tokio-xmpp/src/lib.rs 🔗

@@ -20,9 +20,9 @@ pub mod stream_features;
 pub mod xmpp_stream;
 pub use client::{
     async_client::{
-        AsyncReadAndWrite, Client as AsyncClient, Config as AsyncConfig,
-        ServerConfig as AsyncServerConfig, ServerConnector as AsyncServerConnector,
+        Client as AsyncClient, Config as AsyncConfig, ServerConfig as AsyncServerConfig,
     },
+    connect::{client_login, AsyncReadAndWrite, ServerConnector},
     simple_client::Client as SimpleClient,
 };
 mod component;