starttls.rs

//! `starttls::StartTlsServerConnector` provides a [`ServerConnector`] for STARTTLS connections.
  2
  3use alloc::borrow::Cow;
  4use std::io;
  5use std::os::fd::AsRawFd;
  6
  7use futures::{sink::SinkExt, stream::StreamExt};
  8use sasl::common::ChannelBinding;
  9use tokio::{
 10    io::{AsyncRead, AsyncWrite, BufStream},
 11    net::TcpStream,
 12};
 13use xmpp_parsers::{
 14    jid::Jid,
 15    starttls::{self, Request},
 16};
 17
 18use crate::{
 19    connect::{
 20        tls_common::{establish_tls_connection, TlsConnectorError, TlsStream},
 21        DnsConfig, ServerConnector,
 22    },
 23    error::{Error, ProtocolError},
 24    xmlstream::{
 25        initiate_stream, PendingFeaturesRecv, ReadError, StreamHeader, Timeouts, XmppStream,
 26        XmppStreamElement,
 27    },
 28    Client,
 29};
 30
 31/// Client that connects over StartTls
 32#[deprecated(since = "5.0.0", note = "use tokio_xmpp::Client instead")]
 33pub type StartTlsClient = Client;
 34
 35/// Connect via TCP+StartTLS to an XMPP server
 36#[derive(Debug, Clone)]
 37pub struct StartTlsServerConnector(pub DnsConfig);
 38
 39impl From<DnsConfig> for StartTlsServerConnector {
 40    fn from(dns_config: DnsConfig) -> StartTlsServerConnector {
 41        Self(dns_config)
 42    }
 43}
 44
 45impl ServerConnector for StartTlsServerConnector {
 46    type Stream = BufStream<TlsStream<TcpStream>>;
 47
 48    async fn connect(
 49        &self,
 50        jid: &Jid,
 51        ns: &'static str,
 52        timeouts: Timeouts,
 53    ) -> Result<(PendingFeaturesRecv<Self::Stream>, ChannelBinding), Error> {
 54        let tcp_stream = tokio::io::BufStream::new(self.0.resolve().await?);
 55
 56        // Unencryped XmppStream
 57        let xmpp_stream = initiate_stream(
 58            tcp_stream,
 59            ns,
 60            StreamHeader {
 61                to: Some(Cow::Borrowed(jid.domain().as_str())),
 62                from: None,
 63                id: None,
 64            },
 65            timeouts,
 66        )
 67        .await?;
 68        let (features, xmpp_stream) = xmpp_stream.recv_features().await?;
 69
 70        if features.can_starttls() {
 71            // TlsStream
 72            let (tls_stream, channel_binding) =
 73                starttls(xmpp_stream, jid.domain().as_str()).await?;
 74            // Encrypted XmppStream
 75            Ok((
 76                initiate_stream(
 77                    tokio::io::BufStream::new(tls_stream),
 78                    ns,
 79                    StreamHeader {
 80                        to: Some(Cow::Borrowed(jid.domain().as_str())),
 81                        from: None,
 82                        id: None,
 83                    },
 84                    timeouts,
 85                )
 86                .await?,
 87                channel_binding,
 88            ))
 89        } else {
 90            Err(crate::Error::Protocol(ProtocolError::NoTls))
 91        }
 92    }
 93}
 94
 95/// Performs `<starttls/>` on an XmppStream and returns a binary
 96/// TlsStream.
 97pub async fn starttls<S: AsyncRead + AsyncWrite + Unpin + AsRawFd>(
 98    mut stream: XmppStream<BufStream<S>>,
 99    domain: &str,
100) -> Result<(TlsStream<S>, ChannelBinding), Error> {
101    stream
102        .send(&XmppStreamElement::Starttls(starttls::Nonza::Request(
103            Request,
104        )))
105        .await?;
106
107    loop {
108        match stream.next().await {
109            Some(Ok(XmppStreamElement::Starttls(starttls::Nonza::Proceed(_)))) => {
110                break;
111            }
112            Some(Ok(_)) => (),
113            Some(Err(ReadError::SoftTimeout)) => (),
114            Some(Err(ReadError::HardError(e))) => return Err(e.into()),
115            Some(Err(ReadError::ParseError(e))) => {
116                return Err(io::Error::new(io::ErrorKind::InvalidData, e).into())
117            }
118            None | Some(Err(ReadError::StreamFooterReceived)) => {
119                return Err(crate::Error::Disconnected)
120            }
121        }
122    }
123
124    let inner_stream = stream.into_inner().into_inner();
125    establish_tls_connection(inner_stream, domain).await
126}
127
128/// StartTLS ServerConnector Error - now just an alias to the common error type
129pub type StartTlsError = TlsConnectorError;