diff --git a/tokio-xmpp/ChangeLog b/tokio-xmpp/ChangeLog index 892096f8064cc4a474cb1c0894fbe2a88771cca5..2b07cded0f7cc12fb002af209f70fcf5fdf5e26e 100644 --- a/tokio-xmpp/ChangeLog +++ b/tokio-xmpp/ChangeLog @@ -35,7 +35,7 @@ XXXX-YY-ZZ RELEASER Please refer to the crate docs for details. (!581) * Added: - - Add new directTLS connection method to the `Client`. (Placeholder for PR number) + - Add new `direct-tls` connection method to the `Client`. (!585) - Support for sending IQ requests while tracking their responses in a Future. - `rustls` is now re-exported if it is enabled, to allow applications to diff --git a/tokio-xmpp/src/connect/direct_tls.rs b/tokio-xmpp/src/connect/direct_tls.rs index 9a0f56621cc4f9e7cf081db03fe3ea814f8c30dd..70893f63f21367569d9ef003d66a42e898dc6954 100644 --- a/tokio-xmpp/src/connect/direct_tls.rs +++ b/tokio-xmpp/src/connect/direct_tls.rs @@ -7,50 +7,15 @@ //! `direct_tls::ServerConfig` provides a `ServerConnector` for direct TLS connections use alloc::borrow::Cow; -use core::{error::Error as StdError, fmt}; -#[cfg(feature = "native-tls")] -use native_tls::Error as TlsError; -#[cfg(feature = "rustls-any-backend")] -use tokio_rustls::rustls::pki_types::InvalidDnsNameError; -// Note: feature = "rustls-any-backend" and feature = "native-tls" are -// mutually exclusive during normal compiles, but we allow it for rustdoc -// builds. Thus, we have to make sure that the compilation still succeeds in -// such a case. -#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))] -use tokio_rustls::rustls::Error as TlsError; - -#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))] -use { - alloc::sync::Arc, - tokio_rustls::{ - rustls::pki_types::ServerName, - rustls::{ClientConfig, RootCertStore}, - TlsConnector, - }, -}; - -#[cfg(all( - feature = "rustls-any-backend", - not(feature = "ktls"), - not(feature = "native-tls") -))] -use tokio_rustls::client::TlsStream; - -#[cfg(all(feature = "ktls", not(feature = "native-tls")))] -type TlsStream = ktls::KtlsStream; - -#[cfg(feature = "native-tls")] -use { - native_tls::TlsConnector as NativeTlsConnector, - tokio_native_tls::{TlsConnector, TlsStream}, -}; - use sasl::common::ChannelBinding; use tokio::{io::BufStream, net::TcpStream}; use xmpp_parsers::jid::Jid; use crate::{ - connect::{DnsConfig, ServerConnector, ServerConnectorError}, + connect::{ + tls_common::{establish_tls_connection, TlsConnectorError, TlsStream}, + DnsConfig, ServerConnector, + }, error::Error, xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts}, }; @@ -78,7 +43,7 @@ impl ServerConnector for DirectTlsServerConnector { // Immediately establish TLS connection let (tls_stream, channel_binding) = - establish_tls(tcp_stream, jid.domain().as_str()).await?; + establish_tls_connection(tcp_stream, jid.domain().as_str()).await?; // Establish XMPP stream over TLS Ok(( @@ -100,110 +65,5 @@ impl ServerConnector for DirectTlsServerConnector { } } -#[cfg(feature = "native-tls")] -async fn establish_tls( - tcp_stream: TcpStream, - domain: &str, -) -> Result<(TlsStream, ChannelBinding), Error> { - let domain = domain.to_owned(); - let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) - .connect(&domain, tcp_stream) - .await - .map_err(|e| DirectTlsError::Tls(e))?; - log::warn!( - "tls-native doesn't support channel binding, please use tls-rust if you want this feature!" - ); - Ok((tls_stream, ChannelBinding::None)) -} - -#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))] -async fn establish_tls( - tcp_stream: TcpStream, - domain: &str, -) -> Result<(TlsStream, ChannelBinding), Error> { - let domain = ServerName::try_from(domain.to_owned()).map_err(DirectTlsError::DnsNameError)?; - let mut root_store = RootCertStore::empty(); - #[cfg(feature = "webpki-roots")] - { - root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - } - #[cfg(feature = "rustls-native-certs")] - { - root_store.add_parsable_certificates(rustls_native_certs::load_native_certs()?); - } - #[allow(unused_mut, reason = "This config is mutable when using ktls")] - let mut config = ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - #[cfg(feature = "ktls")] - let tcp_stream = { - config.enable_secret_extraction = true; - ktls::CorkStream::new(tcp_stream) - }; - let tls_stream = TlsConnector::from(Arc::new(config)) - .connect(domain, tcp_stream) - .await - .map_err(crate::Error::Io)?; - - // Extract the channel-binding information before we hand the stream over to ktls. - let (_, connection) = tls_stream.get_ref(); - let channel_binding = match connection.protocol_version() { - // TODO: Add support for TLS 1.2 and earlier. - Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => { - let data = vec![0u8; 32]; - let data = connection - .export_keying_material(data, b"EXPORTER-Channel-Binding", None) - .map_err(DirectTlsError::Tls)?; - ChannelBinding::TlsExporter(data) - } - _ => ChannelBinding::None, - }; - - #[cfg(feature = "ktls")] - let tls_stream = ktls::config_ktls_client(tls_stream) - .await - .map_err(DirectTlsError::KtlsError)?; - Ok((tls_stream, channel_binding)) -} - -/// Direct TLS ServerConnector Error -#[derive(Debug)] -pub enum DirectTlsError { - /// TLS error - Tls(TlsError), - #[cfg(feature = "rustls-any-backend")] - /// DNS name parsing error - DnsNameError(InvalidDnsNameError), - #[cfg(feature = "ktls")] - /// Error while setting up kernel TLS - KtlsError(ktls::Error), -} - -impl ServerConnectorError for DirectTlsError {} - -impl fmt::Display for DirectTlsError { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Tls(e) => write!(fmt, "TLS error: {}", e), - #[cfg(feature = "rustls-any-backend")] - Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e), - #[cfg(feature = "ktls")] - Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e), - } - } -} - -impl StdError for DirectTlsError {} - -impl From for DirectTlsError { - fn from(e: TlsError) -> Self { - Self::Tls(e) - } -} - -#[cfg(feature = "rustls-any-backend")] -impl From for DirectTlsError { - fn from(e: InvalidDnsNameError) -> Self { - Self::DnsNameError(e) - } -} +/// Direct TLS ServerConnector Error - now just an alias to the common error type +pub type DirectTlsError = TlsConnectorError; diff --git a/tokio-xmpp/src/connect/mod.rs b/tokio-xmpp/src/connect/mod.rs index 9bf9990aef0f591592c6e58c32e3b7e2ba1b9e92..fbf4ed1eb64ff9fffd692db1a7f3c4a0d36c1f61 100644 --- a/tokio-xmpp/src/connect/mod.rs +++ b/tokio-xmpp/src/connect/mod.rs @@ -22,6 +22,9 @@ pub mod tcp; #[cfg(feature = "insecure-tcp")] pub use tcp::TcpServerConnector; +#[cfg(any(feature = "direct-tls", feature = "starttls"))] +pub mod tls_common; + mod dns; pub use dns::DnsConfig; diff --git a/tokio-xmpp/src/connect/starttls.rs b/tokio-xmpp/src/connect/starttls.rs index f3a7fede6d32da296dfc5eb20b81a7384799f289..aa6772d67f977b8f6db6184b72d9c7846096c76a 100644 --- a/tokio-xmpp/src/connect/starttls.rs +++ b/tokio-xmpp/src/connect/starttls.rs @@ -1,48 +1,10 @@ //! `starttls::ServerConfig` provides a `ServerConnector` for starttls connections use alloc::borrow::Cow; -use core::{error::Error as StdError, fmt}; -#[cfg(feature = "native-tls")] -use native_tls::Error as TlsError; use std::io; use std::os::fd::AsRawFd; -#[cfg(feature = "rustls-any-backend")] -use tokio_rustls::rustls::pki_types::InvalidDnsNameError; -// Note: feature = "rustls-any-backend" and feature = "native-tls" are -// mutually exclusive during normal compiles, but we allow it for rustdoc -// builds. Thus, we have to make sure that the compilation still succeeds in -// such a case. -#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))] -use tokio_rustls::rustls::Error as TlsError; use futures::{sink::SinkExt, stream::StreamExt}; - -#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))] -use { - alloc::sync::Arc, - tokio_rustls::{ - rustls::pki_types::ServerName, - rustls::{ClientConfig, RootCertStore}, - TlsConnector, - }, -}; - -#[cfg(all( - feature = "rustls-any-backend", - not(feature = "ktls"), - not(feature = "native-tls") -))] -use tokio_rustls::client::TlsStream; - -#[cfg(all(feature = "ktls", not(feature = "native-tls")))] -type TlsStream = ktls::KtlsStream; - -#[cfg(feature = "native-tls")] -use { - native_tls::TlsConnector as NativeTlsConnector, - tokio_native_tls::{TlsConnector, TlsStream}, -}; - use sasl::common::ChannelBinding; use tokio::{ io::{AsyncRead, AsyncWrite, BufStream}, @@ -54,7 +16,10 @@ use xmpp_parsers::{ }; use crate::{ - connect::{DnsConfig, ServerConnector, ServerConnectorError}, + connect::{ + tls_common::{establish_tls_connection, TlsConnectorError, TlsStream}, + DnsConfig, ServerConnector, + }, error::{Error, ProtocolError}, xmlstream::{ initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream, @@ -127,74 +92,6 @@ impl ServerConnector for StartTlsServerConnector { } } -#[cfg(feature = "native-tls")] -async fn get_tls_stream( - xmpp_stream: XmppStream>, - domain: &str, -) -> Result<(TlsStream, ChannelBinding), Error> { - let domain = domain.to_owned(); - let stream = xmpp_stream.into_inner().into_inner(); - let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) - .connect(&domain, stream) - .await - .map_err(|e| StartTlsError::Tls(e))?; - log::warn!( - "tls-native doesn’t support channel binding, please use tls-rust if you want this feature!" - ); - Ok((tls_stream, ChannelBinding::None)) -} - -#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))] -async fn get_tls_stream( - xmpp_stream: XmppStream>, - domain: &str, -) -> Result<(TlsStream, ChannelBinding), Error> { - let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?; - let stream = xmpp_stream.into_inner().into_inner(); - let mut root_store = RootCertStore::empty(); - #[cfg(feature = "webpki-roots")] - { - root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - } - #[cfg(feature = "rustls-native-certs")] - { - root_store.add_parsable_certificates(rustls_native_certs::load_native_certs()?); - } - #[allow(unused_mut, reason = "This config is mutable when using ktls")] - let mut config = ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - #[cfg(feature = "ktls")] - let stream = { - config.enable_secret_extraction = true; - ktls::CorkStream::new(stream) - }; - let tls_stream = TlsConnector::from(Arc::new(config)) - .connect(domain, stream) - .await - .map_err(crate::Error::Io)?; - - // Extract the channel-binding information before we hand the stream over to ktls. - let (_, connection) = tls_stream.get_ref(); - let channel_binding = match connection.protocol_version() { - // TODO: Add support for TLS 1.2 and earlier. - Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => { - let data = vec![0u8; 32]; - let data = connection - .export_keying_material(data, b"EXPORTER-Channel-Binding", None) - .map_err(StartTlsError::Tls)?; - ChannelBinding::TlsExporter(data) - } - _ => ChannelBinding::None, - }; - - #[cfg(feature = "ktls")] - let tls_stream = ktls::config_ktls_client(tls_stream) - .await - .map_err(StartTlsError::KtlsError)?; - Ok((tls_stream, channel_binding)) -} - /// Performs `` on an XmppStream and returns a binary /// TlsStream. pub async fn starttls( @@ -224,47 +121,9 @@ pub async fn starttls( } } - get_tls_stream(stream, domain).await + let inner_stream = stream.into_inner().into_inner(); + establish_tls_connection(inner_stream, domain).await } -/// StartTLS ServerConnector Error -#[derive(Debug)] -pub enum StartTlsError { - /// TLS error - Tls(TlsError), - #[cfg(feature = "rustls-any-backend")] - /// DNS name parsing error - DnsNameError(InvalidDnsNameError), - #[cfg(feature = "ktls")] - /// Error while setting up kernel TLS - KtlsError(ktls::Error), -} - -impl ServerConnectorError for StartTlsError {} - -impl fmt::Display for StartTlsError { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Tls(e) => write!(fmt, "TLS error: {}", e), - #[cfg(feature = "rustls-any-backend")] - Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e), - #[cfg(feature = "ktls")] - Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e), - } - } -} - -impl StdError for StartTlsError {} - -impl From for StartTlsError { - fn from(e: TlsError) -> Self { - Self::Tls(e) - } -} - -#[cfg(feature = "rustls-any-backend")] -impl From for StartTlsError { - fn from(e: InvalidDnsNameError) -> Self { - Self::DnsNameError(e) - } -} +/// StartTLS ServerConnector Error - now just an alias to the common error type +pub type StartTlsError = TlsConnectorError; diff --git a/tokio-xmpp/src/connect/tls_common.rs b/tokio-xmpp/src/connect/tls_common.rs new file mode 100644 index 0000000000000000000000000000000000000000..b0ac8b374c1ea9409b39b9fc83dcebbb2d5d9470 --- /dev/null +++ b/tokio-xmpp/src/connect/tls_common.rs @@ -0,0 +1,170 @@ +// Copyright (c) 2025 Saarko +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//! Common TLS functionality shared between direct_tls and starttls modules + +use core::{error::Error as StdError, fmt}; +use std::os::fd::AsRawFd; +use tokio::io::{AsyncRead, AsyncWrite}; + +#[cfg(feature = "native-tls")] +use native_tls::Error as TlsError; +#[cfg(feature = "rustls-any-backend")] +use tokio_rustls::rustls::pki_types::InvalidDnsNameError; +#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))] +use tokio_rustls::rustls::Error as TlsError; + +#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))] +use { + alloc::sync::Arc, + tokio_rustls::{ + rustls::pki_types::ServerName, + rustls::{ClientConfig, RootCertStore}, + TlsConnector, + }, +}; + +#[cfg(all( + feature = "rustls-any-backend", + not(feature = "ktls"), + not(feature = "native-tls") +))] +pub use tokio_rustls::client::TlsStream; + +#[cfg(all(feature = "ktls", not(feature = "native-tls")))] +pub type TlsStream = ktls::KtlsStream; + +#[cfg(feature = "native-tls")] +pub use tokio_native_tls::TlsStream; + +#[cfg(feature = "native-tls")] +use {native_tls::TlsConnector as NativeTlsConnector, tokio_native_tls::TlsConnector}; + +use crate::{connect::ServerConnectorError, error::Error}; +use sasl::common::ChannelBinding; + +/// Common TLS error type used by both direct_tls and starttls +#[derive(Debug)] +pub enum TlsConnectorError { + /// TLS error + Tls(TlsError), + #[cfg(feature = "rustls-any-backend")] + /// DNS name parsing error + DnsNameError(InvalidDnsNameError), + #[cfg(feature = "ktls")] + /// Error while setting up kernel TLS + KtlsError(ktls::Error), +} + +impl ServerConnectorError for TlsConnectorError {} + +impl fmt::Display for TlsConnectorError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Tls(e) => write!(fmt, "TLS error: {}", e), + #[cfg(feature = "rustls-any-backend")] + Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e), + #[cfg(feature = "ktls")] + Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e), + } + } +} + +impl StdError for TlsConnectorError {} + +impl From for TlsConnectorError { + fn from(e: TlsError) -> Self { + Self::Tls(e) + } +} + +#[cfg(feature = "rustls-any-backend")] +impl From for TlsConnectorError { + fn from(e: InvalidDnsNameError) -> Self { + Self::DnsNameError(e) + } +} + +/// Establish TLS connection using native-tls +#[cfg(feature = "native-tls")] +pub async fn establish_tls_connection( + stream: S, + domain: &str, +) -> Result<(TlsStream, ChannelBinding), Error> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + let domain = domain.to_owned(); + let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap()) + .connect(&domain, stream) + .await + .map_err(|e| TlsConnectorError::Tls(e))?; + log::warn!( + "tls-native doesn't support channel binding, please use tls-rust if you want this feature!" + ); + Ok((tls_stream, ChannelBinding::None)) +} + +/// Establish TLS connection using rustls +#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))] +pub async fn establish_tls_connection( + stream: S, + domain: &str, +) -> Result<(TlsStream, ChannelBinding), Error> +where + S: AsyncRead + AsyncWrite + Unpin + AsRawFd, +{ + let domain = + ServerName::try_from(domain.to_owned()).map_err(TlsConnectorError::DnsNameError)?; + let mut root_store = RootCertStore::empty(); + + #[cfg(feature = "webpki-roots")] + { + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + } + + #[cfg(feature = "rustls-native-certs")] + { + root_store.add_parsable_certificates(rustls_native_certs::load_native_certs()?); + } + + #[allow(unused_mut, reason = "This config is mutable when using ktls")] + let mut config = ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + + #[cfg(feature = "ktls")] + let stream = { + config.enable_secret_extraction = true; + ktls::CorkStream::new(stream) + }; + + let tls_stream = TlsConnector::from(Arc::new(config)) + .connect(domain, stream) + .await + .map_err(crate::Error::Io)?; + + // Extract the channel-binding information before we hand the stream over to ktls. + let (_, connection) = tls_stream.get_ref(); + let channel_binding = match connection.protocol_version() { + // TODO: Add support for TLS 1.2 and earlier. + Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => { + let data = vec![0u8; 32]; + let data = connection + .export_keying_material(data, b"EXPORTER-Channel-Binding", None) + .map_err(TlsConnectorError::Tls)?; + ChannelBinding::TlsExporter(data) + } + _ => ChannelBinding::None, + }; + + #[cfg(feature = "ktls")] + let tls_stream = ktls::config_ktls_client(tls_stream) + .await + .map_err(TlsConnectorError::KtlsError)?; + + Ok((tls_stream, channel_binding)) +}