tokio-xmpp: remove code duplicates for direct tls and starttls connectors

Saarko created

Change summary

tokio-xmpp/ChangeLog                 |   2 
tokio-xmpp/src/connect/direct_tls.rs | 154 +-------------------------
tokio-xmpp/src/connect/mod.rs        |   3 
tokio-xmpp/src/connect/starttls.rs   | 157 +--------------------------
tokio-xmpp/src/connect/tls_common.rs | 170 ++++++++++++++++++++++++++++++
5 files changed, 189 insertions(+), 297 deletions(-)

Detailed changes

tokio-xmpp/ChangeLog πŸ”—

@@ -35,7 +35,7 @@ XXXX-YY-ZZ RELEASER <admin@example.com>
 
         Please refer to the crate docs for details. (!581)
     * Added:
-      - Add new directTLS connection method to the `Client`. (Placeholder for PR number)
+      - Add new `direct-tls` connection method to the `Client`. (!585)
       - Support for sending IQ requests while tracking their responses in a
         Future.
       - `rustls` is now re-exported if it is enabled, to allow applications to

tokio-xmpp/src/connect/direct_tls.rs πŸ”—

@@ -7,50 +7,15 @@
 //! `direct_tls::ServerConfig` provides a `ServerConnector` for direct TLS connections
 
 use alloc::borrow::Cow;
-use core::{error::Error as StdError, fmt};
-#[cfg(feature = "native-tls")]
-use native_tls::Error as TlsError;
-#[cfg(feature = "rustls-any-backend")]
-use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
-// Note: feature = "rustls-any-backend" and feature = "native-tls" are
-// mutually exclusive during normal compiles, but we allow it for rustdoc
-// builds. Thus, we have to make sure that the compilation still succeeds in
-// such a case.
-#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
-use tokio_rustls::rustls::Error as TlsError;
-
-#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
-use {
-    alloc::sync::Arc,
-    tokio_rustls::{
-        rustls::pki_types::ServerName,
-        rustls::{ClientConfig, RootCertStore},
-        TlsConnector,
-    },
-};
-
-#[cfg(all(
-    feature = "rustls-any-backend",
-    not(feature = "ktls"),
-    not(feature = "native-tls")
-))]
-use tokio_rustls::client::TlsStream;
-
-#[cfg(all(feature = "ktls", not(feature = "native-tls")))]
-type TlsStream<S> = ktls::KtlsStream<S>;
-
-#[cfg(feature = "native-tls")]
-use {
-    native_tls::TlsConnector as NativeTlsConnector,
-    tokio_native_tls::{TlsConnector, TlsStream},
-};
-
 use sasl::common::ChannelBinding;
 use tokio::{io::BufStream, net::TcpStream};
 use xmpp_parsers::jid::Jid;
 
 use crate::{
-    connect::{DnsConfig, ServerConnector, ServerConnectorError},
+    connect::{
+        tls_common::{establish_tls_connection, TlsConnectorError, TlsStream},
+        DnsConfig, ServerConnector,
+    },
     error::Error,
     xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts},
 };
@@ -78,7 +43,7 @@ impl ServerConnector for DirectTlsServerConnector {
 
         // Immediately establish TLS connection
         let (tls_stream, channel_binding) =
-            establish_tls(tcp_stream, jid.domain().as_str()).await?;
+            establish_tls_connection(tcp_stream, jid.domain().as_str()).await?;
 
         // Establish XMPP stream over TLS
         Ok((
@@ -100,110 +65,5 @@ impl ServerConnector for DirectTlsServerConnector {
     }
 }
 
-#[cfg(feature = "native-tls")]
-async fn establish_tls(
-    tcp_stream: TcpStream,
-    domain: &str,
-) -> Result<(TlsStream<TcpStream>, ChannelBinding), Error> {
-    let domain = domain.to_owned();
-    let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
-        .connect(&domain, tcp_stream)
-        .await
-        .map_err(|e| DirectTlsError::Tls(e))?;
-    log::warn!(
-        "tls-native doesn't support channel binding, please use tls-rust if you want this feature!"
-    );
-    Ok((tls_stream, ChannelBinding::None))
-}
-
-#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
-async fn establish_tls(
-    tcp_stream: TcpStream,
-    domain: &str,
-) -> Result<(TlsStream<TcpStream>, ChannelBinding), Error> {
-    let domain = ServerName::try_from(domain.to_owned()).map_err(DirectTlsError::DnsNameError)?;
-    let mut root_store = RootCertStore::empty();
-    #[cfg(feature = "webpki-roots")]
-    {
-        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
-    }
-    #[cfg(feature = "rustls-native-certs")]
-    {
-        root_store.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
-    }
-    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
-    let mut config = ClientConfig::builder()
-        .with_root_certificates(root_store)
-        .with_no_client_auth();
-    #[cfg(feature = "ktls")]
-    let tcp_stream = {
-        config.enable_secret_extraction = true;
-        ktls::CorkStream::new(tcp_stream)
-    };
-    let tls_stream = TlsConnector::from(Arc::new(config))
-        .connect(domain, tcp_stream)
-        .await
-        .map_err(crate::Error::Io)?;
-
-    // Extract the channel-binding information before we hand the stream over to ktls.
-    let (_, connection) = tls_stream.get_ref();
-    let channel_binding = match connection.protocol_version() {
-        // TODO: Add support for TLS 1.2 and earlier.
-        Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
-            let data = vec![0u8; 32];
-            let data = connection
-                .export_keying_material(data, b"EXPORTER-Channel-Binding", None)
-                .map_err(DirectTlsError::Tls)?;
-            ChannelBinding::TlsExporter(data)
-        }
-        _ => ChannelBinding::None,
-    };
-
-    #[cfg(feature = "ktls")]
-    let tls_stream = ktls::config_ktls_client(tls_stream)
-        .await
-        .map_err(DirectTlsError::KtlsError)?;
-    Ok((tls_stream, channel_binding))
-}
-
-/// Direct TLS ServerConnector Error
-#[derive(Debug)]
-pub enum DirectTlsError {
-    /// TLS error
-    Tls(TlsError),
-    #[cfg(feature = "rustls-any-backend")]
-    /// DNS name parsing error
-    DnsNameError(InvalidDnsNameError),
-    #[cfg(feature = "ktls")]
-    /// Error while setting up kernel TLS
-    KtlsError(ktls::Error),
-}
-
-impl ServerConnectorError for DirectTlsError {}
-
-impl fmt::Display for DirectTlsError {
-    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
-        match self {
-            Self::Tls(e) => write!(fmt, "TLS error: {}", e),
-            #[cfg(feature = "rustls-any-backend")]
-            Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
-            #[cfg(feature = "ktls")]
-            Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
-        }
-    }
-}
-
-impl StdError for DirectTlsError {}
-
-impl From<TlsError> for DirectTlsError {
-    fn from(e: TlsError) -> Self {
-        Self::Tls(e)
-    }
-}
-
-#[cfg(feature = "rustls-any-backend")]
-impl From<InvalidDnsNameError> for DirectTlsError {
-    fn from(e: InvalidDnsNameError) -> Self {
-        Self::DnsNameError(e)
-    }
-}
+/// Direct TLS ServerConnector Error - now just an alias to the common error type
+pub type DirectTlsError = TlsConnectorError;

tokio-xmpp/src/connect/mod.rs πŸ”—

@@ -22,6 +22,9 @@ pub mod tcp;
 #[cfg(feature = "insecure-tcp")]
 pub use tcp::TcpServerConnector;
 
+#[cfg(any(feature = "direct-tls", feature = "starttls"))]
+pub mod tls_common;
+
 mod dns;
 pub use dns::DnsConfig;
 

tokio-xmpp/src/connect/starttls.rs πŸ”—

@@ -1,48 +1,10 @@
 //! `starttls::ServerConfig` provides a `ServerConnector` for starttls connections
 
 use alloc::borrow::Cow;
-use core::{error::Error as StdError, fmt};
-#[cfg(feature = "native-tls")]
-use native_tls::Error as TlsError;
 use std::io;
 use std::os::fd::AsRawFd;
-#[cfg(feature = "rustls-any-backend")]
-use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
-// Note: feature = "rustls-any-backend" and feature = "native-tls" are
-// mutually exclusive during normal compiles, but we allow it for rustdoc
-// builds. Thus, we have to make sure that the compilation still succeeds in
-// such a case.
-#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
-use tokio_rustls::rustls::Error as TlsError;
 
 use futures::{sink::SinkExt, stream::StreamExt};
-
-#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
-use {
-    alloc::sync::Arc,
-    tokio_rustls::{
-        rustls::pki_types::ServerName,
-        rustls::{ClientConfig, RootCertStore},
-        TlsConnector,
-    },
-};
-
-#[cfg(all(
-    feature = "rustls-any-backend",
-    not(feature = "ktls"),
-    not(feature = "native-tls")
-))]
-use tokio_rustls::client::TlsStream;
-
-#[cfg(all(feature = "ktls", not(feature = "native-tls")))]
-type TlsStream<S> = ktls::KtlsStream<S>;
-
-#[cfg(feature = "native-tls")]
-use {
-    native_tls::TlsConnector as NativeTlsConnector,
-    tokio_native_tls::{TlsConnector, TlsStream},
-};
-
 use sasl::common::ChannelBinding;
 use tokio::{
     io::{AsyncRead, AsyncWrite, BufStream},
@@ -54,7 +16,10 @@ use xmpp_parsers::{
 };
 
 use crate::{
-    connect::{DnsConfig, ServerConnector, ServerConnectorError},
+    connect::{
+        tls_common::{establish_tls_connection, TlsConnectorError, TlsStream},
+        DnsConfig, ServerConnector,
+    },
     error::{Error, ProtocolError},
     xmlstream::{
         initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream,
@@ -127,74 +92,6 @@ impl ServerConnector for StartTlsServerConnector {
     }
 }
 
-#[cfg(feature = "native-tls")]
-async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
-    xmpp_stream: XmppStream<BufStream<S>>,
-    domain: &str,
-) -> Result<(TlsStream<S>, ChannelBinding), Error> {
-    let domain = domain.to_owned();
-    let stream = xmpp_stream.into_inner().into_inner();
-    let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
-        .connect(&domain, stream)
-        .await
-        .map_err(|e| StartTlsError::Tls(e))?;
-    log::warn!(
-        "tls-native doesn’t support channel binding, please use tls-rust if you want this feature!"
-    );
-    Ok((tls_stream, ChannelBinding::None))
-}
-
-#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
-async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
-    xmpp_stream: XmppStream<BufStream<S>>,
-    domain: &str,
-) -> Result<(TlsStream<S>, ChannelBinding), Error> {
-    let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?;
-    let stream = xmpp_stream.into_inner().into_inner();
-    let mut root_store = RootCertStore::empty();
-    #[cfg(feature = "webpki-roots")]
-    {
-        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
-    }
-    #[cfg(feature = "rustls-native-certs")]
-    {
-        root_store.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
-    }
-    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
-    let mut config = ClientConfig::builder()
-        .with_root_certificates(root_store)
-        .with_no_client_auth();
-    #[cfg(feature = "ktls")]
-    let stream = {
-        config.enable_secret_extraction = true;
-        ktls::CorkStream::new(stream)
-    };
-    let tls_stream = TlsConnector::from(Arc::new(config))
-        .connect(domain, stream)
-        .await
-        .map_err(crate::Error::Io)?;
-
-    // Extract the channel-binding information before we hand the stream over to ktls.
-    let (_, connection) = tls_stream.get_ref();
-    let channel_binding = match connection.protocol_version() {
-        // TODO: Add support for TLS 1.2 and earlier.
-        Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
-            let data = vec![0u8; 32];
-            let data = connection
-                .export_keying_material(data, b"EXPORTER-Channel-Binding", None)
-                .map_err(StartTlsError::Tls)?;
-            ChannelBinding::TlsExporter(data)
-        }
-        _ => ChannelBinding::None,
-    };
-
-    #[cfg(feature = "ktls")]
-    let tls_stream = ktls::config_ktls_client(tls_stream)
-        .await
-        .map_err(StartTlsError::KtlsError)?;
-    Ok((tls_stream, channel_binding))
-}
-
 /// Performs `<starttls/>` on an XmppStream and returns a binary
 /// TlsStream.
 pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
@@ -224,47 +121,9 @@ pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
         }
     }
 
-    get_tls_stream(stream, domain).await
+    let inner_stream = stream.into_inner().into_inner();
+    establish_tls_connection(inner_stream, domain).await
 }
 
-/// StartTLS ServerConnector Error
-#[derive(Debug)]
-pub enum StartTlsError {
-    /// TLS error
-    Tls(TlsError),
-    #[cfg(feature = "rustls-any-backend")]
-    /// DNS name parsing error
-    DnsNameError(InvalidDnsNameError),
-    #[cfg(feature = "ktls")]
-    /// Error while setting up kernel TLS
-    KtlsError(ktls::Error),
-}
-
-impl ServerConnectorError for StartTlsError {}
-
-impl fmt::Display for StartTlsError {
-    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
-        match self {
-            Self::Tls(e) => write!(fmt, "TLS error: {}", e),
-            #[cfg(feature = "rustls-any-backend")]
-            Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
-            #[cfg(feature = "ktls")]
-            Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
-        }
-    }
-}
-
-impl StdError for StartTlsError {}
-
-impl From<TlsError> for StartTlsError {
-    fn from(e: TlsError) -> Self {
-        Self::Tls(e)
-    }
-}
-
-#[cfg(feature = "rustls-any-backend")]
-impl From<InvalidDnsNameError> for StartTlsError {
-    fn from(e: InvalidDnsNameError) -> Self {
-        Self::DnsNameError(e)
-    }
-}
+/// StartTLS ServerConnector Error - now just an alias to the common error type
+pub type StartTlsError = TlsConnectorError;

tokio-xmpp/src/connect/tls_common.rs πŸ”—

@@ -0,0 +1,170 @@
+// Copyright (c) 2025 Saarko <saarko@tutanota.com>
+//
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//! Common TLS functionality shared between direct_tls and starttls modules
+
+use core::{error::Error as StdError, fmt};
+use std::os::fd::AsRawFd;
+use tokio::io::{AsyncRead, AsyncWrite};
+
+#[cfg(feature = "native-tls")]
+use native_tls::Error as TlsError;
+#[cfg(feature = "rustls-any-backend")]
+use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
+#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
+use tokio_rustls::rustls::Error as TlsError;
+
+#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
+use {
+    alloc::sync::Arc,
+    tokio_rustls::{
+        rustls::pki_types::ServerName,
+        rustls::{ClientConfig, RootCertStore},
+        TlsConnector,
+    },
+};
+
+#[cfg(all(
+    feature = "rustls-any-backend",
+    not(feature = "ktls"),
+    not(feature = "native-tls")
+))]
+pub use tokio_rustls::client::TlsStream;
+
+#[cfg(all(feature = "ktls", not(feature = "native-tls")))]
+pub type TlsStream<S> = ktls::KtlsStream<S>;
+
+#[cfg(feature = "native-tls")]
+pub use tokio_native_tls::TlsStream;
+
+#[cfg(feature = "native-tls")]
+use {native_tls::TlsConnector as NativeTlsConnector, tokio_native_tls::TlsConnector};
+
+use crate::{connect::ServerConnectorError, error::Error};
+use sasl::common::ChannelBinding;
+
+/// Common TLS error type used by both direct_tls and starttls
+#[derive(Debug)]
+pub enum TlsConnectorError {
+    /// TLS error
+    Tls(TlsError),
+    #[cfg(feature = "rustls-any-backend")]
+    /// DNS name parsing error
+    DnsNameError(InvalidDnsNameError),
+    #[cfg(feature = "ktls")]
+    /// Error while setting up kernel TLS
+    KtlsError(ktls::Error),
+}
+
+impl ServerConnectorError for TlsConnectorError {}
+
+impl fmt::Display for TlsConnectorError {
+    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Self::Tls(e) => write!(fmt, "TLS error: {}", e),
+            #[cfg(feature = "rustls-any-backend")]
+            Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
+            #[cfg(feature = "ktls")]
+            Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
+        }
+    }
+}
+
+impl StdError for TlsConnectorError {}
+
+impl From<TlsError> for TlsConnectorError {
+    fn from(e: TlsError) -> Self {
+        Self::Tls(e)
+    }
+}
+
+#[cfg(feature = "rustls-any-backend")]
+impl From<InvalidDnsNameError> for TlsConnectorError {
+    fn from(e: InvalidDnsNameError) -> Self {
+        Self::DnsNameError(e)
+    }
+}
+
+/// Establish TLS connection using native-tls
+#[cfg(feature = "native-tls")]
+pub async fn establish_tls_connection<S>(
+    stream: S,
+    domain: &str,
+) -> Result<(TlsStream<S>, ChannelBinding), Error>
+where
+    S: AsyncRead + AsyncWrite + Unpin,
+{
+    let domain = domain.to_owned();
+    let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
+        .connect(&domain, stream)
+        .await
+        .map_err(|e| TlsConnectorError::Tls(e))?;
+    log::warn!(
+        "tls-native doesn't support channel binding, please use tls-rust if you want this feature!"
+    );
+    Ok((tls_stream, ChannelBinding::None))
+}
+
+/// Establish TLS connection using rustls
+#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
+pub async fn establish_tls_connection<S>(
+    stream: S,
+    domain: &str,
+) -> Result<(TlsStream<S>, ChannelBinding), Error>
+where
+    S: AsyncRead + AsyncWrite + Unpin + AsRawFd,
+{
+    let domain =
+        ServerName::try_from(domain.to_owned()).map_err(TlsConnectorError::DnsNameError)?;
+    let mut root_store = RootCertStore::empty();
+
+    #[cfg(feature = "webpki-roots")]
+    {
+        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
+    }
+
+    #[cfg(feature = "rustls-native-certs")]
+    {
+        root_store.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
+    }
+
+    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
+    let mut config = ClientConfig::builder()
+        .with_root_certificates(root_store)
+        .with_no_client_auth();
+
+    #[cfg(feature = "ktls")]
+    let stream = {
+        config.enable_secret_extraction = true;
+        ktls::CorkStream::new(stream)
+    };
+
+    let tls_stream = TlsConnector::from(Arc::new(config))
+        .connect(domain, stream)
+        .await
+        .map_err(crate::Error::Io)?;
+
+    // Extract the channel-binding information before we hand the stream over to ktls.
+    let (_, connection) = tls_stream.get_ref();
+    let channel_binding = match connection.protocol_version() {
+        // TODO: Add support for TLS 1.2 and earlier.
+        Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
+            let data = vec![0u8; 32];
+            let data = connection
+                .export_keying_material(data, b"EXPORTER-Channel-Binding", None)
+                .map_err(TlsConnectorError::Tls)?;
+            ChannelBinding::TlsExporter(data)
+        }
+        _ => ChannelBinding::None,
+    };
+
+    #[cfg(feature = "ktls")]
+    let tls_stream = ktls::config_ktls_client(tls_stream)
+        .await
+        .map_err(TlsConnectorError::KtlsError)?;
+
+    Ok((tls_stream, channel_binding))
+}