starttls.rs

  1use std::mem::replace;
  2use futures::{Future, Sink, Poll, Async};
  3use futures::stream::Stream;
  4use futures::sink;
  5use tokio_io::{AsyncRead, AsyncWrite};
  6use tokio_tls::{TlsStream, TlsConnectorExt, ConnectAsync};
  7use native_tls::TlsConnector;
  8use minidom::Element;
  9use jid::Jid;
 10
 11use xmpp_codec::Packet;
 12use xmpp_stream::XMPPStream;
 13
 14
 /// XML namespace of the STARTTLS negotiation elements; the `<starttls/>`
/// request sent by `StartTlsClient` is built in this namespace.
pub const NS_XMPP_TLS: &'static str = "urn:ietf:params:xml:ns:xmpp-tls";
 16
 17
 18pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
 19    state: StartTlsClientState<S>,
 20    jid: Jid,
 21}
 22
 23enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
 24    Invalid,
 25    SendStartTls(sink::Send<XMPPStream<S>>),
 26    AwaitProceed(XMPPStream<S>),
 27    StartingTls(ConnectAsync<S>),
 28}
 29
 30impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
 31    /// Waits for <stream:features>
 32    pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
 33        let jid = xmpp_stream.jid.clone();
 34
 35        let nonza = Element::builder("starttls")
 36            .ns(NS_XMPP_TLS)
 37            .build();
 38        let packet = Packet::Stanza(nonza);
 39        let send = xmpp_stream.send(packet);
 40
 41        StartTlsClient {
 42            state: StartTlsClientState::SendStartTls(send),
 43            jid,
 44        }
 45    }
 46}
 47
 48impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
 49    type Item = TlsStream<S>;
 50    type Error = String;
 51
 52    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 53        let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
 54        let mut retry = false;
 55        
 56        let (new_state, result) = match old_state {
 57            StartTlsClientState::SendStartTls(mut send) =>
 58                match send.poll() {
 59                    Ok(Async::Ready(xmpp_stream)) => {
 60                        let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
 61                        retry = true;
 62                        (new_state, Ok(Async::NotReady))
 63                    },
 64                    Ok(Async::NotReady) =>
 65                        (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
 66                    Err(e) =>
 67                        (StartTlsClientState::SendStartTls(send), Err(format!("{}", e))),
 68                },
 69            StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
 70                match xmpp_stream.poll() {
 71                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
 72                        if stanza.name() == "proceed" =>
 73                    {
 74                        let stream = xmpp_stream.stream.into_inner();
 75                        let connect = TlsConnector::builder().unwrap()
 76                            .build().unwrap()
 77                            .connect_async(&self.jid.domain, stream);
 78                        let new_state = StartTlsClientState::StartingTls(connect);
 79                        retry = true;
 80                        (new_state, Ok(Async::NotReady))
 81                    },
 82                    Ok(Async::Ready(value)) => {
 83                        println!("StartTlsClient ignore {:?}", value);
 84                        (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
 85                    },
 86                    Ok(_) =>
 87                        (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
 88                    Err(e) =>
 89                        (StartTlsClientState::AwaitProceed(xmpp_stream),  Err(format!("{}", e))),
 90                },
 91            StartTlsClientState::StartingTls(mut connect) =>
 92                match connect.poll() {
 93                    Ok(Async::Ready(tls_stream)) =>
 94                        (StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream))),
 95                    Ok(Async::NotReady) =>
 96                        (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
 97                    Err(e) =>
 98                        (StartTlsClientState::Invalid, Err(format!("{}", e))),
 99                },
100            StartTlsClientState::Invalid =>
101                unreachable!(),
102        };
103
104        self.state = new_state;
105        if retry {
106            self.poll()
107        } else {
108            result
109        }
110    }
111}