1use futures::{sink::SinkExt, stream::StreamExt};
2
3#[cfg(feature = "tls-rust")]
4use {
5 std::convert::TryFrom,
6 std::sync::Arc,
7 tokio_rustls::{
8 client::TlsStream,
9 rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore, ServerName},
10 TlsConnector,
11 },
12 webpki_roots,
13};
14
15#[cfg(feature = "tls-native")]
16use {
17 native_tls::TlsConnector as NativeTlsConnector,
18 tokio_native_tls::{TlsConnector, TlsStream},
19};
20
21use tokio::io::{AsyncRead, AsyncWrite};
22use xmpp_parsers::{ns, Element};
23
24use crate::xmpp_codec::Packet;
25use crate::xmpp_stream::XMPPStream;
26use crate::{Error, ProtocolError};
27
#[cfg(feature = "tls-native")]
/// Upgrades the transport underneath `xmpp_stream` to TLS using `native-tls`.
///
/// The domain part of the stream's JID is passed to the connector as the
/// expected server name.
///
/// # Errors
///
/// Returns an error if the native TLS connector cannot be built or if the
/// TLS handshake fails.
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
    xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
    // NOTE(review): the JID is cloned before `domain()`, presumably because
    // `domain()` takes the JID by value — confirm against the jid crate.
    let domain = xmpp_stream.jid.clone().domain();
    let stream = xmpp_stream.into_inner();
    // Building the connector can fail (e.g. platform TLS backend errors);
    // propagate that instead of panicking inside a library call.
    let connector = TlsConnector::from(NativeTlsConnector::builder().build()?);
    let tls_stream = connector.connect(&domain, stream).await?;
    Ok(tls_stream)
}
39
#[cfg(feature = "tls-rust")]
/// Upgrades the transport underneath `xmpp_stream` to TLS using `rustls`.
///
/// The trust store is populated from the bundled `webpki_roots` anchors.
/// The domain part of the stream's JID is used as the expected server name.
///
/// # Errors
///
/// Returns an error if the JID domain cannot be turned into a `ServerName`
/// or if the TLS handshake fails.
async fn get_tls_stream<S: AsyncRead + AsyncWrite + Unpin>(
    xmpp_stream: XMPPStream<S>,
) -> Result<TlsStream<S>, Error> {
    let jid_domain = xmpp_stream.jid.clone().domain();
    let server_name = ServerName::try_from(jid_domain.as_str())?;
    let transport = xmpp_stream.into_inner();

    // Convert every bundled webpki root into an owned rustls trust anchor.
    let anchors = webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|anchor| {
        OwnedTrustAnchor::from_subject_spki_name_constraints(
            anchor.subject,
            anchor.spki,
            anchor.name_constraints,
        )
    });
    let mut roots = RootCertStore::empty();
    roots.add_server_trust_anchors(anchors);

    let tls_config = Arc::new(
        ClientConfig::builder()
            .with_safe_defaults()
            .with_root_certificates(roots)
            .with_no_client_auth(),
    );
    let connector = TlsConnector::from(tls_config);
    Ok(connector.connect(server_name, transport).await?)
}
64
65/// Performs `<starttls/>` on an XMPPStream and returns a binary
66/// TlsStream.
67pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin>(
68 mut xmpp_stream: XMPPStream<S>,
69) -> Result<TlsStream<S>, Error> {
70 let nonza = Element::builder("starttls", ns::TLS).build();
71 let packet = Packet::Stanza(nonza);
72 xmpp_stream.send(packet).await?;
73
74 loop {
75 match xmpp_stream.next().await {
76 Some(Ok(Packet::Stanza(ref stanza))) if stanza.name() == "proceed" => break,
77 Some(Ok(Packet::Text(_))) => {}
78 Some(Err(e)) => return Err(e.into()),
79 _ => {
80 return Err(ProtocolError::NoTls.into());
81 }
82 }
83 }
84
85 get_tls_stream(xmpp_stream).await
86}