@@ -112,11 +112,9 @@ pub async fn client_auth<C: ServerConnector>(
let username = jid.node().unwrap().as_str();
let password = password;
- let xmpp_stream = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?;
+ let (xmpp_stream, channel_binding) = server.connect(&jid, ns::JABBER_CLIENT, timeouts).await?;
let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
- let channel_binding = C::channel_binding(xmpp_stream.get_stream())?;
-
let creds = Credentials::default()
.with_username(username)
.with_password(password)
@@ -16,7 +16,7 @@ pub async fn component_login<C: ServerConnector>(
timeouts: Timeouts,
) -> Result<XmppStream<C::Stream>, Error> {
let password = password;
- let mut stream = connector.connect(&jid, ns::COMPONENT, timeouts).await?;
+ let (mut stream, _) = connector.connect(&jid, ns::COMPONENT, timeouts).await?;
let header = stream.take_header();
let mut stream = stream.skip_features();
let stream_id = match header.id {
@@ -37,12 +37,7 @@ pub trait ServerConnector: Clone + core::fmt::Debug + Send + Unpin + 'static {
jid: &Jid,
ns: &'static str,
timeouts: Timeouts,
- ) -> impl std::future::Future<Output = Result<PendingFeaturesRecv<Self::Stream>, Error>> + Send;
-
- /// Return channel binding data if available
- /// do not fail if channel binding is simply unavailable, just return Ok(None)
- /// this should only be called after the TLS handshake is finished
- fn channel_binding(_stream: &Self::Stream) -> Result<ChannelBinding, Error> {
- Ok(ChannelBinding::None)
- }
+ ) -> impl std::future::Future<
+ Output = Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error>,
+ > + Send;
}
@@ -82,7 +82,7 @@ impl ServerConnector for StartTlsServerConnector {
jid: &Jid,
ns: &'static str,
timeouts: Timeouts,
- ) -> Result<PendingFeaturesRecv<Self::Stream>, Error> {
+ ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
// Unencryped XmppStream
@@ -101,78 +101,51 @@ impl ServerConnector for StartTlsServerConnector {
if features.can_starttls() {
// TlsStream
- let tls_stream = starttls(xmpp_stream, jid.domain().as_str()).await?;
+ let (tls_stream, channel_binding) =
+ starttls(xmpp_stream, jid.domain().as_str()).await?;
// Encrypted XmppStream
- Ok(initiate_stream(
- tokio::io::BufStream::new(tls_stream),
- ns,
- StreamHeader {
- to: Some(Cow::Borrowed(jid.domain().as_str())),
- from: None,
- id: None,
- },
- timeouts,
- )
- .await?)
+ Ok((
+ initiate_stream(
+ tokio::io::BufStream::new(tls_stream),
+ ns,
+ StreamHeader {
+ to: Some(Cow::Borrowed(jid.domain().as_str())),
+ from: None,
+ id: None,
+ },
+ timeouts,
+ )
+ .await?,
+ channel_binding,
+ ))
} else {
Err(crate::Error::Protocol(ProtocolError::NoTls).into())
}
}
-
- fn channel_binding(
- #[allow(unused_variables)] stream: &Self::Stream,
- ) -> Result<sasl::common::ChannelBinding, Error> {
- #[cfg(feature = "tls-native")]
- {
- log::warn!("tls-native doesn’t support channel binding, please use tls-rust if you want this feature!");
- Ok(ChannelBinding::None)
- }
- #[cfg(all(feature = "tls-rust-ktls", not(feature = "tls-native")))]
- {
- log::warn!("Kernel TLS doesn’t support channel binding yet, we would have to extract the secrets in the rustls TlsStream before converting it into a KtlsStream.");
- Ok(ChannelBinding::None)
- }
- #[cfg(all(
- feature = "tls-rust",
- not(feature = "tls-native"),
- not(feature = "tls-rust-ktls")
- ))]
- {
- let (_, connection) = stream.get_ref().get_ref();
- Ok(match connection.protocol_version() {
- // TODO: Add support for TLS 1.2 and earlier.
- Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
- let data = vec![0u8; 32];
- let data = connection
- .export_keying_material(data, b"EXPORTER-Channel-Binding", None)
- .map_err(|e| StartTlsError::Tls(e))?;
- ChannelBinding::TlsExporter(data)
- }
- _ => ChannelBinding::None,
- })
- }
- }
}
#[cfg(feature = "tls-native")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
xmpp_stream: XmppStream<BufStream<S>>,
domain: &str,
-) -> Result<TlsStream<S>, Error> {
+) -> Result<(TlsStream<S>, ChannelBinding), Error> {
let domain = domain.to_owned();
let stream = xmpp_stream.into_inner().into_inner();
let tls_stream = TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
.connect(&domain, stream)
.await
.map_err(|e| StartTlsError::Tls(e))?;
- Ok(tls_stream)
+ log::warn!(
+ "tls-native doesn’t support channel binding, please use tls-rust if you want this feature!"
+ );
+ Ok((tls_stream, ChannelBinding::None))
}
#[cfg(all(feature = "tls-rust", not(feature = "tls-native")))]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
xmpp_stream: XmppStream<BufStream<S>>,
domain: &str,
-) -> Result<TlsStream<S>, Error> {
+) -> Result<(TlsStream<S>, ChannelBinding), Error> {
let domain = ServerName::try_from(domain.to_owned()).map_err(StartTlsError::DnsNameError)?;
let stream = xmpp_stream.into_inner().into_inner();
let mut root_store = RootCertStore::empty();
@@ -197,11 +170,26 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
.connect(domain, stream)
.await
.map_err(|e| Error::from(crate::Error::Io(e)))?;
+
+ // Extract the channel-binding information before we hand the stream over to ktls.
+ let (_, connection) = tls_stream.get_ref();
+ let channel_binding = match connection.protocol_version() {
+ // TODO: Add support for TLS 1.2 and earlier.
+ Some(tokio_rustls::rustls::ProtocolVersion::TLSv1_3) => {
+ let data = vec![0u8; 32];
+ let data = connection
+ .export_keying_material(data, b"EXPORTER-Channel-Binding", None)
+ .map_err(|e| StartTlsError::Tls(e))?;
+ ChannelBinding::TlsExporter(data)
+ }
+ _ => ChannelBinding::None,
+ };
+
#[cfg(feature = "tls-rust-ktls")]
let tls_stream = ktls::config_ktls_client(tls_stream)
.await
.map_err(StartTlsError::KtlsError)?;
- Ok(tls_stream)
+ Ok((tls_stream, channel_binding))
}
/// Performs `<starttls/>` on an XmppStream and returns a binary
@@ -209,7 +197,7 @@ async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
mut stream: XmppStream<BufStream<S>>,
domain: &str,
-) -> Result<TlsStream<S>, Error> {
+) -> Result<(TlsStream<S>, ChannelBinding), Error> {
stream
.send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
Request,
@@ -5,7 +5,7 @@ use std::borrow::Cow;
use tokio::{io::BufStream, net::TcpStream};
use crate::{
- connect::{DnsConfig, ServerConnector},
+ connect::{ChannelBinding, DnsConfig, ServerConnector},
xmlstream::{initiate_stream, PendingFeaturesRecv, StreamHeader, Timeouts},
Client, Component, Error,
};
@@ -37,18 +37,21 @@ impl ServerConnector for TcpServerConnector {
jid: &xmpp_parsers::jid::Jid,
ns: &'static str,
timeouts: Timeouts,
- ) -> Result<PendingFeaturesRecv<Self::Stream>, Error> {
+ ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
let stream = BufStream::new(self.0.resolve().await?);
- Ok(initiate_stream(
- stream,
- ns,
- StreamHeader {
- to: Some(Cow::Borrowed(jid.domain().as_str())),
- from: None,
- id: None,
- },
- timeouts,
- )
- .await?)
+ Ok((
+ initiate_stream(
+ stream,
+ ns,
+ StreamHeader {
+ to: Some(Cow::Borrowed(jid.domain().as_str())),
+ from: None,
+ id: None,
+ },
+ timeouts,
+ )
+ .await?,
+ ChannelBinding::None,
+ ))
}
}