diff --git a/tokio-xmpp/Cargo.toml b/tokio-xmpp/Cargo.toml index 1a07dd3563c0c8c1b02f98851547ad187facf092..9d36acb6f5fcb42b8b22869c5012607b5034800e 100644 --- a/tokio-xmpp/Cargo.toml +++ b/tokio-xmpp/Cargo.toml @@ -35,6 +35,7 @@ idna = { version = "1.0", optional = true} native-tls = { version = "0.2", optional = true } tokio-native-tls = { version = "0.3", optional = true } tokio-rustls = { version = "0.26", optional = true } +ktls = { version = "6", optional = true } [dev-dependencies] env_logger = { version = "0.11", default-features = false, features = ["auto-color", "humantime"] } @@ -46,6 +47,7 @@ tokio-xmpp = { path = ".", features = ["insecure-tcp"]} default = ["starttls-rust", "rustls-native-certs"] starttls = ["dns"] tls-rust = ["tokio-rustls"] +tls-rust-ktls = ["tls-rust", "ktls"] tls-rust-native-certs = ["tls-rust", "rustls-native-certs"] tls-rust-webpki-roots = ["tls-rust", "webpki-roots"] tls-native = ["tokio-native-tls", "native-tls"] diff --git a/tokio-xmpp/src/connect/starttls.rs b/tokio-xmpp/src/connect/starttls.rs index 7df476c15c46cc449a1c75dd0b1f218713d9a66f..08ba90cebbbf034361abd7da19e96f646a16ab70 100644 --- a/tokio-xmpp/src/connect/starttls.rs +++ b/tokio-xmpp/src/connect/starttls.rs @@ -6,6 +6,7 @@ use std::borrow::Cow; use std::error::Error as StdError; use std::fmt; use std::io; +use std::os::fd::AsRawFd; #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] use tokio_rustls::rustls::pki_types::InvalidDnsNameError; #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] @@ -17,13 +18,22 @@ use futures::{sink::SinkExt, stream::StreamExt}; use { std::sync::Arc, tokio_rustls::{ - client::TlsStream, rustls::pki_types::ServerName, rustls::{ClientConfig, RootCertStore}, TlsConnector, }, }; +#[cfg(all( + feature = "tls-rust", + not(feature = "tls-native"), + not(feature = "tls-rust-ktls") +))] +use tokio_rustls::client::TlsStream; + +#[cfg(all(feature = "tls-rust-ktls", not(feature = "tls-native")))] +type TlsStream = ktls::KtlsStream; + #[cfg(feature = "tls-native")] use { native_tls::TlsConnector as NativeTlsConnector, @@ -116,7 +126,16 @@ impl ServerConnector for StartTlsServerConnector { log::warn!("tls-native doesn’t support channel binding, please use tls-rust if you want this feature!"); Ok(ChannelBinding::None) } - #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] + #[cfg(all(feature = "tls-rust-ktls", not(feature = "tls-native")))] + { + log::warn!("Kernel TLS doesn’t support channel binding yet, we would have to extract the secrets in the rustls TlsStream before converting it into a KtlsStream."); + Ok(ChannelBinding::None) + } + #[cfg(all( + feature = "tls-rust", + not(feature = "tls-native"), + not(feature = "tls-rust-ktls") + ))] { let (_, connection) = stream.get_ref().get_ref(); Ok(match connection.protocol_version() { @@ -149,7 +168,7 @@ async fn get_tls_stream( } #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] -async fn get_tls_stream( +async fn get_tls_stream( xmpp_stream: XmppStream>, domain: &str, ) -> Result, Error> { @@ -164,19 +183,29 @@ async fn get_tls_stream( { root_store.add_parsable_certificates(rustls_native_certs::load_native_certs()?); } - let config = ClientConfig::builder() + #[allow(unused_mut, reason = "This config is mutable when using ktls")] + let mut config = ClientConfig::builder() .with_root_certificates(root_store) .with_no_client_auth(); + #[cfg(feature = "tls-rust-ktls")] + let stream = { + config.enable_secret_extraction = true; + ktls::CorkStream::new(stream) + }; let tls_stream = TlsConnector::from(Arc::new(config)) .connect(domain, stream) .await .map_err(|e| Error::from(crate::Error::Io(e)))?; + #[cfg(feature = "tls-rust-ktls")] + let tls_stream = ktls::config_ktls_client(tls_stream) + .await + .map_err(StartTlsError::KtlsError)?; Ok(tls_stream) } /// Performs `` on an XmppStream and returns a binary /// TlsStream. -pub async fn starttls( +pub async fn starttls( mut stream: XmppStream>, domain: &str, ) -> Result, Error> { @@ -214,6 +243,9 @@ pub enum StartTlsError { #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] /// DNS name parsing error DnsNameError(InvalidDnsNameError), + #[cfg(feature = "tls-rust-ktls")] + /// Error while setting up kernel TLS + KtlsError(ktls::Error), } impl ServerConnectorError for StartTlsError {} @@ -224,6 +256,8 @@ impl fmt::Display for StartTlsError { Self::Tls(e) => write!(fmt, "TLS error: {}", e), #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))] Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e), + #[cfg(feature = "tls-rust-ktls")] + Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e), } } }