Add AsyncServerConnector to AsyncClient to be able to support any stream

moparisthebest created

Unfortunately API breaking unless we do some export mangling

Change summary

tokio-xmpp/src/client/async_client.rs | 120 +++++++++++++++++-----------
tokio-xmpp/src/lib.rs                 |   7 +
xmpp/Cargo.toml                       |   6 
xmpp/src/agent.rs                     |   3 
xmpp/src/builder.rs                   |   4 
xmpp/src/lib.rs                       |   3 
6 files changed, 86 insertions(+), 57 deletions(-)

Detailed changes

tokio-xmpp/src/client/async_client.rs 🔗

@@ -5,10 +5,6 @@ use std::pin::Pin;
 use std::task::Context;
 use tokio::net::TcpStream;
 use tokio::task::JoinHandle;
-#[cfg(feature = "tls-native")]
-use tokio_native_tls::TlsStream;
-#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
-use tokio_rustls::client::TlsStream;
 use xmpp_parsers::{ns, Element, Jid};
 
 use super::auth::auth;
@@ -17,8 +13,12 @@ use crate::event::Event;
 use crate::happy_eyeballs::{connect_to_host, connect_with_srv};
 use crate::starttls::starttls;
 use crate::xmpp_codec::Packet;
-use crate::xmpp_stream::{self, add_stanza_id};
+use crate::xmpp_stream::{self, add_stanza_id, XMPPStream};
 use crate::{Error, ProtocolError};
+#[cfg(feature = "tls-native")]
+use tokio_native_tls::TlsStream;
+#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
+use tokio_rustls::client::TlsStream;
 
 /// XMPP client connection and state
 ///
@@ -26,13 +26,35 @@ use crate::{Error, ProtocolError};
 ///
 /// This implements the `futures` crate's [`Stream`](#impl-Stream) and
 /// [`Sink`](#impl-Sink<Packet>) traits.
-pub struct Client {
-    config: Config,
-    state: ClientState,
+pub struct Client<C: ServerConnector> {
+    config: Config<C>,
+    state: ClientState<C::Stream>,
     reconnect: bool,
     // TODO: tls_required=true
 }
 
+/// XMPP client configuration
+#[derive(Clone, Debug)]
+pub struct Config<C> {
+    /// jid of the account
+    pub jid: Jid,
+    /// password of the account
+    pub password: String,
+    /// server configuration for the account
+    pub server: C,
+}
+
+/// Trait called to connect to an XMPP server, perhaps called multiple times
+pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
+    /// The type of Stream this ServerConnector produces
+    type Stream: AsyncReadAndWrite;
+    /// This must return the connection ready to login, ie if starttls is involved, after TLS has been started, and then after the <stream headers are exchanged
+    fn connect(
+        &self,
+        jid: &Jid,
+    ) -> impl std::future::Future<Output = Result<XMPPStream<Self::Stream>, Error>> + Send;
+}
+
 /// XMPP server connection configuration
 #[derive(Clone, Debug)]
 pub enum ServerConfig {
@@ -48,27 +70,46 @@ pub enum ServerConfig {
     },
 }
 
-/// XMPP client configuration
-#[derive(Clone, Debug)]
-pub struct Config {
-    /// jid of the account
-    pub jid: Jid,
-    /// password of the account
-    pub password: String,
-    /// server configuration for the account
-    pub server: ServerConfig,
+impl ServerConnector for ServerConfig {
+    type Stream = TlsStream<TcpStream>;
+    async fn connect(&self, jid: &Jid) -> Result<XMPPStream<Self::Stream>, Error> {
+        // TCP connection
+        let tcp_stream = match self {
+            ServerConfig::UseSrv => {
+                connect_with_srv(jid.domain_str(), "_xmpp-client._tcp", 5222).await?
+            }
+            ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), *port).await?,
+        };
+
+        // Unencryped XMPPStream
+        let xmpp_stream =
+            xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
+                .await?;
+
+        if xmpp_stream.stream_features.can_starttls() {
+            // TlsStream
+            let tls_stream = starttls(xmpp_stream).await?;
+            // Encrypted XMPPStream
+            xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
+                .await
+        } else {
+            return Err(Error::Protocol(ProtocolError::NoTls));
+        }
+    }
 }
 
-type XMPPStream = xmpp_stream::XMPPStream<TlsStream<TcpStream>>;
+/// trait used by XMPPStream type
+pub trait AsyncReadAndWrite: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send {}
+impl<T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send> AsyncReadAndWrite for T {}
 
-enum ClientState {
+enum ClientState<S: AsyncReadAndWrite> {
     Invalid,
     Disconnected,
-    Connecting(JoinHandle<Result<XMPPStream, Error>>),
-    Connected(XMPPStream),
+    Connecting(JoinHandle<Result<XMPPStream<S>, Error>>),
+    Connected(XMPPStream<S>),
 }
 
-impl Client {
+impl Client<ServerConfig> {
     /// Start a new XMPP client
     ///
     /// Start polling the returned instance so that it will connect
@@ -81,9 +122,11 @@ impl Client {
         };
         Self::new_with_config(config)
     }
+}
 
+impl<C: ServerConnector> Client<C> {
     /// Start a new client given that the JID is already parsed.
-    pub fn new_with_config(config: Config) -> Self {
+    pub fn new_with_config(config: Config<C>) -> Self {
         let connect = tokio::spawn(Self::connect(
             config.server.clone(),
             config.jid.clone(),
@@ -105,35 +148,14 @@ impl Client {
     }
 
     async fn connect(
-        server: ServerConfig,
+        server: C,
         jid: Jid,
         password: String,
-    ) -> Result<XMPPStream, Error> {
+    ) -> Result<XMPPStream<C::Stream>, Error> {
         let username = jid.node_str().unwrap();
         let password = password;
 
-        // TCP connection
-        let tcp_stream = match server {
-            ServerConfig::UseSrv => {
-                connect_with_srv(jid.domain_str(), "_xmpp-client._tcp", 5222).await?
-            }
-            ServerConfig::Manual { host, port } => connect_to_host(host.as_str(), port).await?,
-        };
-
-        // Unencryped XMPPStream
-        let xmpp_stream =
-            xmpp_stream::XMPPStream::start(tcp_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
-                .await?;
-
-        let xmpp_stream = if xmpp_stream.stream_features.can_starttls() {
-            // TlsStream
-            let tls_stream = starttls(xmpp_stream).await?;
-            // Encrypted XMPPStream
-            xmpp_stream::XMPPStream::start(tls_stream, jid.clone(), ns::JABBER_CLIENT.to_owned())
-                .await?
-        } else {
-            return Err(Error::Protocol(ProtocolError::NoTls));
-        };
+        let xmpp_stream = server.connect(&jid).await?;
 
         let creds = Credentials::default()
             .with_username(username)
@@ -180,7 +202,7 @@ impl Client {
 ///
 /// In an `async fn` you may want to use this with `use
 /// futures::stream::StreamExt;`
-impl Stream for Client {
+impl<C: ServerConnector> Stream for Client<C> {
     type Item = Event;
 
     /// Low-level read on the XMPP stream, allowing the underlying
@@ -297,7 +319,7 @@ impl Stream for Client {
 /// Outgoing XMPP packets
 ///
 /// See `send_stanza()` for an `async fn`
-impl Sink<Packet> for Client {
+impl<C: ServerConnector> Sink<Packet> for Client<C> {
     type Error = Error;
 
     fn start_send(mut self: Pin<&mut Self>, item: Packet) -> Result<(), Self::Error> {

tokio-xmpp/src/lib.rs 🔗

@@ -19,8 +19,11 @@ mod happy_eyeballs;
 pub mod stream_features;
 pub mod xmpp_stream;
 pub use client::{
-    async_client::Client as AsyncClient, async_client::Config as AsyncConfig,
-    async_client::ServerConfig as AsyncServerConfig, simple_client::Client as SimpleClient,
+    async_client::{
+        AsyncReadAndWrite, Client as AsyncClient, Config as AsyncConfig,
+        ServerConfig as AsyncServerConfig, ServerConnector as AsyncServerConnector,
+    },
+    simple_client::Client as SimpleClient,
 };
 mod component;
 pub use crate::component::Component;

xmpp/Cargo.toml 🔗

@@ -21,7 +21,7 @@ log = "0.4"
 reqwest = { version = "0.11.8", features = ["stream"] }
 tokio-util = { version = "0.7", features = ["codec"] }
 # same repository dependencies
-tokio-xmpp = { version = "3.4", path = "../tokio-xmpp" }
+tokio-xmpp = { version = "3.4", path = "../tokio-xmpp", default-features = false }
 
 [dev-dependencies]
 env_logger = { version = "0.10", default-features = false, features = ["auto-color", "humantime"] }
@@ -31,5 +31,7 @@ name = "hello_bot"
 required-features = ["avatars"]
 
 [features]
-default = ["avatars"]
+default = ["avatars", "tls-native"]
+tls-native = ["tokio-xmpp/tls-native"]
+tls-rust = ["tokio-xmpp/tls-rust"]
 avatars = []

xmpp/src/agent.rs 🔗

@@ -8,10 +8,9 @@ use std::path::{Path, PathBuf};
 use std::sync::{Arc, RwLock};
 pub use tokio_xmpp::parsers;
 use tokio_xmpp::parsers::{disco::DiscoInfoResult, message::MessageType};
-use tokio_xmpp::AsyncClient as TokioXmppClient;
 pub use tokio_xmpp::{BareJid, Element, FullJid, Jid};
 
-use crate::{event_loop, message, muc, upload, Error, Event, RoomNick};
+use crate::{event_loop, message, muc, upload, Error, Event, RoomNick, TokioXmppClient};
 
 pub struct Agent {
     pub(crate) client: TokioXmppClient,

xmpp/src/builder.rs 🔗

@@ -10,10 +10,10 @@ use tokio_xmpp::{
         disco::{DiscoInfoResult, Feature, Identity},
         ns,
     },
-    AsyncClient as TokioXmppClient, BareJid, Jid,
+    BareJid, Jid,
 };
 
-use crate::{Agent, ClientFeature};
+use crate::{Agent, ClientFeature, TokioXmppClient};
 
 #[derive(Debug)]
 pub enum ClientType {

xmpp/src/lib.rs 🔗

@@ -7,6 +7,7 @@
 #![deny(bare_trait_objects)]
 
 pub use tokio_xmpp::parsers;
+use tokio_xmpp::{AsyncClient, AsyncServerConfig};
 pub use tokio_xmpp::{BareJid, Element, FullJid, Jid};
 #[macro_use]
 extern crate log;
@@ -31,6 +32,8 @@ pub use builder::{ClientBuilder, ClientType};
 pub use event::Event;
 pub use feature::ClientFeature;
 
+type TokioXmppClient = AsyncClient<AsyncServerConfig>;
+
 pub type Error = tokio_xmpp::Error;
 pub type Id = Option<String>;
 pub type RoomNick = String;