starttls.rs

 1use futures::{sink::SinkExt, stream::StreamExt};
 2
 3#[cfg(feature = "tls-rust")]
 4use {
 5    std::convert::TryFrom,
 6    std::sync::Arc,
 7    tokio_rustls::{
 8        client::TlsStream,
 9        rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName},
10        TlsConnector,
11    },
12    webpki_roots,
13};
14
15#[cfg(feature = "tls-native")]
16use {
17    native_tls::TlsConnector as NativeTlsConnector,
18    tokio_native_tls::{TlsConnector, TlsStream},
19};
20
21use tokio::io::{AsyncRead, AsyncWrite};
22use xmpp_parsers::{ns, Element};
23
24use crate::xmpp_codec::Packet;
25use crate::xmpp_stream::XMPPStream;
26use crate::{Error, ProtocolError};
27
/// Upgrades the transport underlying `xmpp_stream` to TLS using `native-tls`.
///
/// The domain of the stream's JID is passed to the TLS handshake as the
/// expected server name.
///
/// # Errors
///
/// Returns an error if the native TLS connector cannot be built, or if the
/// TLS handshake fails.
#[cfg(feature = "tls-native")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
    xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
    // Read the domain before `into_inner` consumes the stream; the JID is
    // cloned because `domain()` is called on an owned value here.
    let domain = xmpp_stream.jid.clone().domain();
    let stream = xmpp_stream.into_inner();
    // Building the connector is fallible; propagate the failure through the
    // same `native_tls::Error -> Error` conversion used for the handshake
    // below instead of panicking.
    let connector = TlsConnector::from(NativeTlsConnector::builder().build()?);
    let tls_stream = connector.connect(&domain, stream).await?;
    Ok(tls_stream)
}
39
/// Upgrades the transport underlying `xmpp_stream` to TLS using `rustls`.
///
/// The domain of the stream's JID is used as the expected server name, and a
/// fresh client configuration is built whose trust store is populated from
/// the `webpki_roots` server roots, with no client authentication.
///
/// # Errors
///
/// Fails if the JID domain is not accepted as a `ServerName`, or if the TLS
/// handshake fails.
#[cfg(feature = "tls-rust")]
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
    xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
    let jid_domain = xmpp_stream.jid.clone().domain();
    let server_name = ServerName::try_from(jid_domain.as_str())?;
    let inner = xmpp_stream.into_inner();

    // Convert every bundled `webpki_roots` anchor into an owned rustls anchor.
    let anchors = webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|anchor| {
        OwnedTrustAnchor::from_subject_spki_name_constraints(
            anchor.subject,
            anchor.spki,
            anchor.name_constraints,
        )
    });
    let mut roots = RootCertStore::empty();
    roots.add_server_trust_anchors(anchors);

    let client_config = ClientConfig::builder()
        .with_safe_defaults()
        .with_root_certificates(roots)
        .with_no_client_auth();

    let connector = TlsConnector::from(Arc::new(client_config));
    let secured = connector.connect(server_name, inner).await?;
    Ok(secured)
}
64
65/// Performs `<starttls/>` on an XMPPStream and returns a binary
66/// TlsStream.
67pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
68    mut xmpp_stream: XMPPStream<S>,
69) -> Result<TlsStream<S>, Error> {
70    let nonza = Element::builder("starttls", ns::TLS).build();
71    let packet = Packet::Stanza(nonza);
72    xmpp_stream.send(packet).await?;
73
74    loop {
75        match xmpp_stream.next().await {
76            Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break,
77            Some(Ok(Packet::Text(_))) => {}
78            Some(Err(e)) => return Err(e.into()),
79            _ => {
80                return Err(ProtocolError::NoTls.into());
81            }
82        }
83    }
84
85    get_tls_stream(xmpp_stream).await
86}