tokio-xmpp: Refactor to provide channel-binding for ktls

Emmanuel Gil Peyrot created

Instead of having a second method to fetch channel-binding from the
TlsStream, do it directly in the connect() method, since after that we
don’t have enough information to fetch it any longer when using ktls.

Change summary

tokio-xmpp/src/client/login.rs     |  4 -
tokio-xmpp/src/component/login.rs  |  2 
tokio-xmpp/src/connect/mod.rs      | 11 +--
tokio-xmpp/src/connect/starttls.rs | 92 +++++++++++++------------------
tokio-xmpp/src/connect/tcp.rs      | 29 +++++----
5 files changed, 61 insertions(+), 77 deletions(-)

Detailed changes

tokio-xmpp/src/client/login.rs 🔗

@@ -112,11 +112,9 @@ pub async fn client_auth<C: ServerConnector>(
     let username = jid.node().unwrap().as_str();
     let password = password;
 
-    let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?;
+    let (xmpp_stream, channel_binding) = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?;
     let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
 
-    let channel_binding = C::channel_binding(xmpp_stream.get_stream())?;
-
     let creds = Credentials::default()
         .with_username(username)
         .with_password(password)

tokio-xmpp/src/component/login.rs 🔗

@@ -16,7 +16,7 @@ pub async fn component_login<C: ServerConnector>(
     timeouts: Timeouts,
 ) -> Result<XmppStream<C::Stream>, Error> {
     let password = password;
-    let mut stream = connector.connect(&jid, ns::COMPONENT, timeouts).await?;
+    let (mut stream, _) = connector.connect(&jid, ns::COMPONENT, timeouts).await?;
     let header = stream.take_header();
     let mut stream = stream.skip_features();
     let stream_id = match header.id {

tokio-xmpp/src/connect/mod.rs 🔗

@@ -37,12 +37,7 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
         jid: &Jid,
         ns: &'static str,
         timeouts: Timeouts,
-    ) -> impl std::future::Future<Output = Result<PendingFeaturesRecv<Self::Stream>, Error>> + Send;
-
-    /// Return channel binding data if available
-    /// do not fail if channel binding is simply unavailable, just return Ok(None)
-    /// this should only be called after the TLS handshake is finished
-    fn channel_binding(_stream: &Self::Stream) -> Result<ChannelBinding, Error> {
-        Ok(ChannelBinding::None)
-    }
+    ) -> impl std::future::Future<
+        Output = Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error>,
+    > + Send;
 }

tokio-xmpp/src/connect/starttls.rs 🔗

@@ -82,7 +82,7 @@ impl ServerConnector for StartTlsServerConnector {
         jid: &Jid,
         ns: &'static str,
         timeouts: Timeouts,
-    ) -> Result<PendingFeaturesRecv<Self::Stream>, Error> {
+    ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
         let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
 
         // Unencryped XmppStream
@@ -101,78 +101,51 @@ impl ServerConnector for StartTlsServerConnector {
 
         if features.can_starttls() {
             // TlsStream
-            let tls_stream = starttls(xmpp_stream, jid.domain().as_str()).await?;
+            let (tls_stream, channel_binding) =
+                starttls(xmpp_stream, jid.domain().as_str()).await?;
             // Encrypted XmppStream
-            Ok(initiate_stream(
-                tokio::io::BufStream::new(tls_stream),
-                ns,
-                StreamHeader {
-                    to: Some(Cow::Borrowed(jid.domain().as_str())),
-                    from: None,
-                    id: None,
-                },
-                timeouts,
-            )
-            .await?)
+            Ok((
+                initiate_stream(
+                    tokio::io::BufStream::new(tls_stream),
+                    ns,
+                    StreamHeader {
+                        to: Some(Cow::Borrowed(jid.domain().as_str())),
+                        from: None,
+                        id: None,
+                    },
+                    timeouts,
+                )
+                .await?,
+                channel_binding,
+            ))
         } else {
             Err(crate::Error::Protocol(ProtocolError::NoTls).into())
         }
     }
-
-    fn channel_binding(
-        #[allow(unused_variables)] stream: &Self::Stream,
-    ) -> Result<sasl::common::ChannelBinding, Error> {
-        #[cfg(feature = "tls-native")]
-        {
-            log::warn!("tls-native doesn’t support channel binding, please use tls-rust if you want this feature!");
-            Ok(ChannelBinding::None)
-        }
-        #[cfg(all(feature = "tls-rust-ktls", not(feature = "tls-native")))]
-        {
-            log::warn!("Kernel TLS doesn’t support channel binding yet, we would have to extract the secrets in the rustls TlsStream before converting it into a KtlsStream.");
-            Ok(ChannelBinding::None)
-        }
-        #[cfg(all(
-            feature = "tls-rust",
-            not(feature = "tls-native"),
-            not(feature = "tls-rust-ktls")
-        ))]
-        {
-            let (_, connection) = stream.get_ref().get_ref();
-            Ok(match connection.protocol_version() {
-                // TODO: Add support for TLS 1.2 and earlier.
-                Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
-                    let data = vec![0u8; 32];
-                    let data = connection
-                        .export_keying_material(data, b"EXPORTER-Channel-Binding", None)
-                        .map_err(|e| StartTlsError::Tls(e))?;
-                    ChannelBinding::TlsExporter(data)
-                }
-                _ => ChannelBinding::None,
-            })
-        }
-    }
 }
 
 #[cfg(feature = "tls-native")]
 async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
     xmpp_stream: XmppStream<BufStream<S>>,
     domain: &str,
-) -> Result<TlsStream<S>, Error> {
+) -> Result<(TlsStream<S>, ChannelBinding), Error> {
     let domain = domain.to_owned();
     let stream = xmpp_stream.into_inner().into_inner();
     let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
         .connect(&domain, stream)
         .await
         .map_err(|e| StartTlsError::Tls(e))?;
-    Ok(tls_stream)
+    log::warn!(
+        "tls-native doesn’t support channel binding, please use tls-rust if you want this feature!"
+    );
+    Ok((tls_stream, ChannelBinding::None))
 }
 
 #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
 async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
     xmpp_stream: XmppStream<BufStream<S>>,
     domain: &str,
-) -> Result<TlsStream<S>, Error> {
+) -> Result<(TlsStream<S>, ChannelBinding), Error> {
     let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?;
     let stream = xmpp_stream.into_inner().into_inner();
     let mut root_store = RootCertStore::empty();
@@ -197,11 +170,26 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
         .connect(domain, stream)
         .await
         .map_err(|e| Error::from(crate::Error::Io(e)))?;
+
+    // Extract the channel-binding information before we hand the stream over to ktls.
+    let (_, connection) = tls_stream.get_ref();
+    let channel_binding = match connection.protocol_version() {
+        // TODO: Add support for TLS 1.2 and earlier.
+        Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
+            let data = vec![0u8; 32];
+            let data = connection
+                .export_keying_material(data, b"EXPORTER-Channel-Binding", None)
+                .map_err(|e| StartTlsError::Tls(e))?;
+            ChannelBinding::TlsExporter(data)
+        }
+        _ => ChannelBinding::None,
+    };
+
     #[cfg(feature = "tls-rust-ktls")]
     let tls_stream = ktls::config_ktls_client(tls_stream)
         .await
         .map_err(StartTlsError::KtlsError)?;
-    Ok(tls_stream)
+    Ok((tls_stream, channel_binding))
 }
 
 /// Performs `<starttls/>` on an XmppStream and returns a binary
@@ -209,7 +197,7 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
 pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
     mut stream: XmppStream<BufStream<S>>,
     domain: &str,
-) -> Result<TlsStream<S>, Error> {
+) -> Result<(TlsStream<S>, ChannelBinding), Error> {
     stream
         .send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
             Request,

tokio-xmpp/src/connect/tcp.rs 🔗

@@ -5,7 +5,7 @@ use std::borrow::Cow;
 use tokio::{io::BufStream, net::TcpStream};
 
 use crate::{
-    connect::{DnsConfig, ServerConnector},
+    connect::{ChannelBinding, DnsConfig, ServerConnector},
     xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts},
     Client, Component, Error,
 };
@@ -37,18 +37,21 @@ impl ServerConnector for TcpServerConnector {
         jid: &xmpp_parsers::jid::Jid,
         ns: &'static str,
         timeouts: Timeouts,
-    ) -> Result<PendingFeaturesRecv<Self::Stream>, Error> {
+    ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
         let stream = BufStream::new(self.0.resolve().await?);
-        Ok(initiate_stream(
-            stream,
-            ns,
-            StreamHeader {
-                to: Some(Cow::Borrowed(jid.domain().as_str())),
-                from: None,
-                id: None,
-            },
-            timeouts,
-        )
-        .await?)
+        Ok((
+            initiate_stream(
+                stream,
+                ns,
+                StreamHeader {
+                    to: Some(Cow::Borrowed(jid.domain().as_str())),
+                    from: None,
+                    id: None,
+                },
+                timeouts,
+            )
+            .await?,
+            ChannelBinding::None,
+        ))
     }
 }