@@ -35,6 +35,7 @@ idna = { version = "1.0", optional = true}
native-tls = { version = "0.2", optional = true }
tokio-native-tls = { version = "0.3", optional = true }
tokio-rustls = { version = "0.26", optional = true }
+ktls = { version = "6", optional = true }
[dev-dependencies]
env_logger = { version = "0.11", default-features = false, features = ["auto-color", "humantime"] }
@@ -46,6 +47,7 @@ tokio-xmpp = { path = ".", features = ["insecure-tcp"]}
default = ["starttls-rust", "rustls-native-certs"]
starttls = ["dns"]
tls-rust = ["tokio-rustls"]
+tls-rust-ktls = ["tls-rust", "ktls"]
tls-rust-native-certs = ["tls-rust", "rustls-native-certs"]
tls-rust-webpki-roots = ["tls-rust", "webpki-roots"]
tls-native = ["tokio-native-tls", "native-tls"]
@@ -6,6 +6,7 @@ use std::borrow::Cow;
use std::error::Error as StdError;
use std::fmt;
use std::io;
+use std::os::fd::AsRawFd;
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
@@ -17,13 +18,22 @@ use futures::{sink::SinkExt, stream::StreamExt};
use {
std::sync::Arc,
tokio_rustls::{
- client::TlsStream,
rustls::pki_types::ServerName,
rustls::{ClientConfig, RootCertStore},
TlsConnector,
},
};
+#[cfg(all(
+ feature = "tls-rust",
+ not(feature = "tls-native"),
+ not(feature = "tls-rust-ktls")
+))]
+use tokio_rustls::client::TlsStream;
+
+#[cfg(all(feature = "tls-rust-ktls", not(feature = "tls-native")))]
+type TlsStream<S> = ktls::KtlsStream<S>;
+
#[cfg(feature = "tls-native")]
use {
native_tls::TlsConnector as NativeTlsConnector,
@@ -116,7 +126,16 @@ impl ServerConnector for StartTlsServerConnector {
log::warn!("tls-native doesn’t support channel binding, please use tls-rust if you want this feature!");
Ok(ChannelBinding::None)
}
- #[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
+ #[cfg(all(feature = "tls-rust-ktls", not(feature = "tls-native")))]
+ {
+ log::warn!("Kernel TLS doesn’t support channel binding yet, we would have to extract the secrets in the rustls TlsStream before converting it into a KtlsStream.");
+ Ok(ChannelBinding::None)
+ }
+ #[cfg(all(
+ feature = "tls-rust",
+ not(feature = "tls-native"),
+ not(feature = "tls-rust-ktls")
+ ))]
{
let (_, connection) = stream.get_ref().get_ref();
Ok(match connection.protocol_version() {
@@ -149,7 +168,7 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
}
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
-async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
+async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
xmpp_stream: XmppStream<BufStream<S>>,
domain: &str,
) -> Result<TlsStream<S>, Error> {
@@ -164,19 +183,29 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
{
root_store.add_parsable_certificates(rustls_native_certs::load_native_certs()?);
}
- let config = ClientConfig::builder()
+ #[allow(unused_mut, reason = "This config is mutable when using ktls")]
+ let mut config = ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
+ #[cfg(feature = "tls-rust-ktls")]
+ let stream = {
+ config.enable_secret_extraction = true;
+ ktls::CorkStream::new(stream)
+ };
let tls_stream = TlsConnector::from(Arc::new(config))
.connect(domain, stream)
.await
.map_err(|e| Error::from(crate::Error::Io(e)))?;
+ #[cfg(feature = "tls-rust-ktls")]
+ let tls_stream = ktls::config_ktls_client(tls_stream)
+ .await
+ .map_err(StartTlsError::KtlsError)?;
Ok(tls_stream)
}
/// Performs `<starttls/>` on an XmppStream and returns a binary
/// TlsStream.
-pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
+pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
mut stream: XmppStream<BufStream<S>>,
domain: &str,
) -> Result<TlsStream<S>, Error> {
@@ -214,6 +243,9 @@ pub enum StartTlsError {
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
/// DNS name parsing error
DnsNameError(InvalidDnsNameError),
+ #[cfg(feature = "tls-rust-ktls")]
+ /// Error while setting up kernel TLS
+ KtlsError(ktls::Error),
}
impl ServerConnectorError for StartTlsError {}
@@ -224,6 +256,8 @@ impl fmt::Display for StartTlsError {
Self::Tls(e) => write!(fmt, "TLS error: {}", e),
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
+ #[cfg(feature = "tls-rust-ktls")]
+ Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
}
}
}