starttls.rs

  1use std::mem::replace;
  2use futures::{Future, Sink, Poll, Async};
  3use futures::stream::Stream;
  4use futures::sink;
  5use tokio_io::{AsyncRead, AsyncWrite};
  6use tokio_tls::{TlsStream, TlsConnector, Connect};
  7use native_tls::TlsConnector as NativeTlsConnector;
  8use minidom::Element;
  9use jid::Jid;
 10
 11use xmpp_codec::Packet;
 12use xmpp_stream::XMPPStream;
 13use Error;
 14
/// XMPP TLS XML namespace.
///
/// Set as the namespace of the `starttls` element that
/// `StartTlsClient::from_stream` sends to request the TLS upgrade.
pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
 17
 18
 19/// XMPP stream that switches to TLS if available in received features
 20pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
 21    state: StartTlsClientState<S>,
 22    jid: Jid,
 23}
 24
 25enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
 26    Invalid,
 27    SendStartTls(sink::Send<XMPPStream<S>>),
 28    AwaitProceed(XMPPStream<S>),
 29    StartingTls(Connect<S>),
 30}
 31
 32impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
 33    /// Waits for <stream:features>
 34    pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
 35        let jid = xmpp_stream.jid.clone();
 36
 37        let nonza = Element::builder("starttls")
 38            .ns(NS_XMPP_TLS)
 39            .build();
 40        let packet = Packet::Stanza(nonza);
 41        let send = xmpp_stream.send(packet);
 42
 43        StartTlsClient {
 44            state: StartTlsClientState::SendStartTls(send),
 45            jid,
 46        }
 47    }
 48}
 49
 50impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
 51    type Item = TlsStream<S>;
 52    type Error = Error;
 53
 54    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 55        let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
 56        let mut retry = false;
 57
 58        let (new_state, result) = match old_state {
 59            StartTlsClientState::SendStartTls(mut send) =>
 60                match send.poll() {
 61                    Ok(Async::Ready(xmpp_stream)) => {
 62                        let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
 63                        retry = true;
 64                        (new_state, Ok(Async::NotReady))
 65                    },
 66                    Ok(Async::NotReady) =>
 67                        (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
 68                    Err(e) =>
 69                        (StartTlsClientState::SendStartTls(send), Err(e.into())),
 70                },
 71            StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
 72                match xmpp_stream.poll() {
 73                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
 74                        if stanza.name() == "proceed" =>
 75                    {
 76                        let stream = xmpp_stream.stream.into_inner();
 77                        let connect = TlsConnector::from(NativeTlsConnector::builder()
 78                            .build().unwrap())
 79                            .connect(&self.jid.domain, stream);
 80                        let new_state = StartTlsClientState::StartingTls(connect);
 81                        retry = true;
 82                        (new_state, Ok(Async::NotReady))
 83                    },
 84                    Ok(Async::Ready(value)) => {
 85                        println!("StartTlsClient ignore {:?}", value);
 86                        (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
 87                    },
 88                    Ok(_) =>
 89                        (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
 90                    Err(e) =>
 91                        (StartTlsClientState::AwaitProceed(xmpp_stream),  Err(Error::Protocol(e.into()))),
 92                },
 93            StartTlsClientState::StartingTls(mut connect) =>
 94                match connect.poll() {
 95                    Ok(Async::Ready(tls_stream)) =>
 96                        (StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream))),
 97                    Ok(Async::NotReady) =>
 98                        (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
 99                    Err(e) =>
100                        (StartTlsClientState::Invalid, Err(e.into())),
101                },
102            StartTlsClientState::Invalid =>
103                unreachable!(),
104        };
105
106        self.state = new_state;
107        if retry {
108            self.poll()
109        } else {
110            result
111        }
112    }
113}