tls_common.rs

  1// Copyright (c) 2025 Saarko <saarko@tutanota.com>
  2//
  3// This Source Code Form is subject to the terms of the Mozilla Public
  4// License, v. 2.0. If a copy of the MPL was not distributed with this
  5// file, You can obtain one at http://mozilla.org/MPL/2.0/.
  6
  7//! Common TLS functionality shared between direct_tls and starttls modules
  8
  9use core::{error::Error as StdError, fmt};
 10#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
 11use std::os::fd::AsRawFd;
 12use tokio::io::{AsyncRead, AsyncWrite};
 13
 14#[cfg(feature = "native-tls")]
 15use native_tls::Error as TlsError;
 16#[cfg(feature = "rustls-any-backend")]
 17use tokio_rustls::rustls::pki_types::InvalidDnsNameError;
 18#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
 19use tokio_rustls::rustls::Error as TlsError;
 20
 21#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
 22use {
 23    alloc::sync::Arc,
 24    tokio_rustls::{
 25        rustls::pki_types::ServerName,
 26        rustls::{ClientConfig, RootCertStore},
 27        TlsConnector,
 28    },
 29};
 30
 31#[cfg(all(
 32    feature = "rustls-any-backend",
 33    not(feature = "ktls"),
 34    not(feature = "native-tls")
 35))]
 36pub use tokio_rustls::client::TlsStream;
 37
 38#[cfg(all(feature = "ktls", not(feature = "native-tls")))]
 39/// Tls Stream type based on Ktls
 40pub type TlsStream<S> = ktls::KtlsStream<S>;
 41
 42#[cfg(feature = "native-tls")]
 43pub use tokio_native_tls::TlsStream;
 44
 45#[cfg(feature = "native-tls")]
 46use {native_tls::TlsConnector as NativeTlsConnector, tokio_native_tls::TlsConnector};
 47
 48use crate::{connect::ServerConnectorError, error::Error};
 49use sasl::common::ChannelBinding;
 50
 51/// Common TLS error type used by both direct_tls and starttls
 52#[derive(Debug)]
 53pub enum TlsConnectorError {
 54    /// TLS error
 55    Tls(TlsError),
 56    #[cfg(feature = "rustls-any-backend")]
 57    /// DNS name parsing error
 58    DnsNameError(InvalidDnsNameError),
 59    #[cfg(feature = "ktls")]
 60    /// Error while setting up kernel TLS
 61    KtlsError(ktls::Error),
 62}
 63
 64impl ServerConnectorError for TlsConnectorError {}
 65
 66impl fmt::Display for TlsConnectorError {
 67    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
 68        match self {
 69            Self::Tls(e) => write!(fmt, "TLS error: {}", e),
 70            #[cfg(feature = "rustls-any-backend")]
 71            Self::DnsNameError(e) => write!(fmt, "DNS name error: {}", e),
 72            #[cfg(feature = "ktls")]
 73            Self::KtlsError(e) => write!(fmt, "Kernel TLS error: {}", e),
 74        }
 75    }
 76}
 77
 78impl StdError for TlsConnectorError {}
 79
 80impl From<TlsError> for TlsConnectorError {
 81    fn from(e: TlsError) -> Self {
 82        Self::Tls(e)
 83    }
 84}
 85
 86#[cfg(feature = "rustls-any-backend")]
 87impl From<InvalidDnsNameError> for TlsConnectorError {
 88    fn from(e: InvalidDnsNameError) -> Self {
 89        Self::DnsNameError(e)
 90    }
 91}
 92
 93/// Establish TLS connection using native-tls
 94#[cfg(feature = "native-tls")]
 95pub async fn establish_tls_connection<S>(
 96    stream: S,
 97    domain: &str,
 98) -> Result<(TlsStream<S>, ChannelBinding), Error>
 99where
100    S: AsyncRead + AsyncWrite + Unpin,
101{
102    let domain = domain.to_owned();
103    let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
104        .connect(&domain, stream)
105        .await
106        .map_err(|e| TlsConnectorError::Tls(e))?;
107    log::warn!(
108        "tls-native doesn't support channel binding, please use tls-rust if you want this feature!"
109    );
110    Ok((tls_stream, ChannelBinding::None))
111}
112
/// Establish a TLS connection over `stream` using rustls.
///
/// The root store is populated from `webpki-roots` and/or the platform
/// certificate store (`rustls-native-certs`), depending on enabled features.
/// With the `ktls` feature, the handshake runs over a `ktls::CorkStream`, and
/// the resulting stream is then passed to `ktls::config_ktls_client`.
///
/// When TLS 1.3 is negotiated, a 32-byte `EXPORTER-Channel-Binding` keying
/// material export is returned as [`ChannelBinding::TlsExporter`]. For any
/// other protocol version, [`ChannelBinding::None`] is returned.
///
/// # Errors
///
/// Returns an error if any of the following fails:
/// - `domain` is not a valid server name;
/// - the handshake fails;
/// - exporting the keying material fails;
/// - (with `ktls`) kernel TLS cannot be configured.
#[cfg(all(feature = "rustls-any-backend", not(feature = "native-tls")))]
pub async fn establish_tls_connection<S>(
    stream: S,
    domain: &str,
) -> Result<(TlsStream<S>, ChannelBinding), Error>
where
    S: AsyncRead + AsyncWrite + Unpin + AsRawFd,
{
    // `TlsConnector::connect` needs an owned `ServerName<'static>`.
    let domain =
        ServerName::try_from(domain.to_owned()).map_err(TlsConnectorError::DnsNameError)?;
    let mut root_store = RootCertStore::empty();

    #[cfg(feature = "webpki-roots")]
    {
        root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    }

    #[cfg(feature = "rustls-native-certs")]
    {
        // Loading the platform store can partially fail: keep whatever was
        // loaded, but surface the failures so missing roots are diagnosable
        // instead of silently producing certificate-verification errors later.
        let native = rustls_native_certs::load_native_certs();
        for err in &native.errors {
            log::warn!("Failed to load native root certificates: {err}");
        }
        let (_, ignored) = root_store.add_parsable_certificates(native.certs);
        if ignored > 0 {
            log::warn!("Ignored {ignored} unparsable native root certificate(s)");
        }
    }

    #[allow(unused_mut, reason = "This config is mutable when using ktls")]
    let mut config = ClientConfig::builder()
        .with_root_certificates(root_store)
        .with_no_client_auth();

    #[cfg(feature = "ktls")]
    let stream = {
        config.enable_secret_extraction = true;
        ktls::CorkStream::new(stream)
    };

    let tls_stream = TlsConnector::from(Arc::new(config))
        .connect(domain, stream)
        .await
        .map_err(crate::Error::Io)?;

    // Extract the channel-binding information before we hand the stream over to ktls.
    let (_, connection) = tls_stream.get_ref();
    let channel_binding = match connection.protocol_version() {
        // TODO: Add support for TLS 1.2 and earlier.
        Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
            let data = vec![0u8; 32];
            let data = connection
                .export_keying_material(data, b"EXPORTER-Channel-Binding", None)
                .map_err(TlsConnectorError::Tls)?;
            ChannelBinding::TlsExporter(data)
        }
        _ => ChannelBinding::None,
    };

    #[cfg(feature = "ktls")]
    let tls_stream = ktls::config_ktls_client(tls_stream)
        .await
        .map_err(TlsConnectorError::KtlsError)?;

    Ok((tls_stream, channel_binding))
}