starttls.rs

  1use futures::sink;
  2use futures::stream::Stream;
  3use futures::{Async, Future, Poll, Sink};
  4use jid::Jid;
  5use minidom::Element;
  6use native_tls::TlsConnector as NativeTlsConnector;
  7use std::mem::replace;
  8use tokio_io::{AsyncRead, AsyncWrite};
  9use tokio_tls::{Connect, TlsConnector, TlsStream};
 10
 11use crate::xmpp_codec::Packet;
 12use crate::xmpp_stream::XMPPStream;
 13use crate::Error;
 14
 15/// XMPP TLS XML namespace
 16pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
 17
 18/// XMPP stream that switches to TLS if available in received features
 19pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
 20    state: StartTlsClientState<S>,
 21    jid: Jid,
 22}
 23
 24enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
 25    Invalid,
 26    SendStartTls(sink::Send<XMPPStream<S>>),
 27    AwaitProceed(XMPPStream<S>),
 28    StartingTls(Connect<S>),
 29}
 30
 31impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
 32    /// Waits for <stream:features>
 33    pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
 34        let jid = xmpp_stream.jid.clone();
 35
 36        let nonza = Element::builder("starttls").ns(NS_XMPP_TLS).build();
 37        let packet = Packet::Stanza(nonza);
 38        let send = xmpp_stream.send(packet);
 39
 40        StartTlsClient {
 41            state: StartTlsClientState::SendStartTls(send),
 42            jid,
 43        }
 44    }
 45}
 46
 47impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
 48    type Item = TlsStream<S>;
 49    type Error = Error;
 50
 51    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 52        let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
 53        let mut retry = false;
 54
 55        let (new_state, result) = match old_state {
 56            StartTlsClientState::SendStartTls(mut send) => match send.poll() {
 57                Ok(Async::Ready(xmpp_stream)) => {
 58                    let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
 59                    retry = true;
 60                    (new_state, Ok(Async::NotReady))
 61                }
 62                Ok(Async::NotReady) => {
 63                    (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady))
 64                }
 65                Err(e) => (StartTlsClientState::SendStartTls(send), Err(e.into())),
 66            },
 67            StartTlsClientState::AwaitProceed(mut xmpp_stream) => match xmpp_stream.poll() {
 68                Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
 69                    if stanza.name() == "proceed" =>
 70                {
 71                    let stream = xmpp_stream.stream.into_inner();
 72                    let connect =
 73                        TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
 74                            .connect(&self.jid.domain, stream);
 75                    let new_state = StartTlsClientState::StartingTls(connect);
 76                    retry = true;
 77                    (new_state, Ok(Async::NotReady))
 78                }
 79                Ok(Async::Ready(value)) => {
 80                    println!("StartTlsClient ignore {:?}", value);
 81                    (
 82                        StartTlsClientState::AwaitProceed(xmpp_stream),
 83                        Ok(Async::NotReady),
 84                    )
 85                }
 86                Ok(_) => (
 87                    StartTlsClientState::AwaitProceed(xmpp_stream),
 88                    Ok(Async::NotReady),
 89                ),
 90                Err(e) => (
 91                    StartTlsClientState::AwaitProceed(xmpp_stream),
 92                    Err(Error::Protocol(e.into())),
 93                ),
 94            },
 95            StartTlsClientState::StartingTls(mut connect) => match connect.poll() {
 96                Ok(Async::Ready(tls_stream)) => {
 97                    (StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream)))
 98                }
 99                Ok(Async::NotReady) => (
100                    StartTlsClientState::StartingTls(connect),
101                    Ok(Async::NotReady),
102                ),
103                Err(e) => (StartTlsClientState::Invalid, Err(e.into())),
104            },
105            StartTlsClientState::Invalid => unreachable!(),
106        };
107
108        self.state = new_state;
109        if retry {
110            self.poll()
111        } else {
112            result
113        }
114    }
115}