starttls.rs

  1use futures::sink;
  2use futures::stream::Stream;
  3use futures::{Async, Future, Poll, Sink};
  4use xmpp_parsers::{Jid, Element};
  5use native_tls::TlsConnector as NativeTlsConnector;
  6use std::mem::replace;
  7use tokio_io::{AsyncRead, AsyncWrite};
  8use tokio_tls::{Connect, TlsConnector, TlsStream};
  9
 10use crate::xmpp_codec::Packet;
 11use crate::xmpp_stream::XMPPStream;
 12use crate::Error;
 13
 14/// XMPP TLS XML namespace
 15pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
 16
 17/// XMPP stream that switches to TLS if available in received features
 18pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
 19    state: StartTlsClientState<S>,
 20    jid: Jid,
 21}
 22
 23enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
 24    Invalid,
 25    SendStartTls(sink::Send<XMPPStream<S>>),
 26    AwaitProceed(XMPPStream<S>),
 27    StartingTls(Connect<S>),
 28}
 29
 30impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
 31    /// Waits for <stream:features>
 32    pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
 33        let jid = xmpp_stream.jid.clone();
 34
 35        let nonza = Element::builder("starttls").ns(NS_XMPP_TLS).build();
 36        let packet = Packet::Stanza(nonza);
 37        let send = xmpp_stream.send(packet);
 38
 39        StartTlsClient {
 40            state: StartTlsClientState::SendStartTls(send),
 41            jid,
 42        }
 43    }
 44}
 45
 46impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
 47    type Item = TlsStream<S>;
 48    type Error = Error;
 49
 50    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 51        let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
 52        let mut retry = false;
 53
 54        let (new_state, result) = match old_state {
 55            StartTlsClientState::SendStartTls(mut send) => match send.poll() {
 56                Ok(Async::Ready(xmpp_stream)) => {
 57                    let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
 58                    retry = true;
 59                    (new_state, Ok(Async::NotReady))
 60                }
 61                Ok(Async::NotReady) => {
 62                    (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady))
 63                }
 64                Err(e) => (StartTlsClientState::SendStartTls(send), Err(e.into())),
 65            },
 66            StartTlsClientState::AwaitProceed(mut xmpp_stream) => match xmpp_stream.poll() {
 67                Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
 68                    if stanza.name() == "proceed" =>
 69                {
 70                    let stream = xmpp_stream.stream.into_inner();
 71                    let connect =
 72                        TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
 73                            .connect(&self.jid.clone().domain(), stream);
 74                    let new_state = StartTlsClientState::StartingTls(connect);
 75                    retry = true;
 76                    (new_state, Ok(Async::NotReady))
 77                }
 78                Ok(Async::Ready(_value)) => {
 79                    // println!("StartTlsClient ignore {:?}", _value);
 80                    (
 81                        StartTlsClientState::AwaitProceed(xmpp_stream),
 82                        Ok(Async::NotReady),
 83                    )
 84                }
 85                Ok(_) => (
 86                    StartTlsClientState::AwaitProceed(xmpp_stream),
 87                    Ok(Async::NotReady),
 88                ),
 89                Err(e) => (
 90                    StartTlsClientState::AwaitProceed(xmpp_stream),
 91                    Err(Error::Protocol(e.into())),
 92                ),
 93            },
 94            StartTlsClientState::StartingTls(mut connect) => match connect.poll() {
 95                Ok(Async::Ready(tls_stream)) => {
 96                    (StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream)))
 97                }
 98                Ok(Async::NotReady) => (
 99                    StartTlsClientState::StartingTls(connect),
100                    Ok(Async::NotReady),
101                ),
102                Err(e) => (StartTlsClientState::Invalid, Err(e.into())),
103            },
104            StartTlsClientState::Invalid => unreachable!(),
105        };
106
107        self.state = new_state;
108        if retry {
109            self.poll()
110        } else {
111            result
112        }
113    }
114}