starttls.rs

  1use std::mem::replace;
  2use futures::{Future, Sink, Poll, Async};
  3use futures::stream::Stream;
  4use futures::sink;
  5use tokio_io::{AsyncRead, AsyncWrite};
  6use tokio_tls::{TlsStream, TlsConnectorExt, ConnectAsync};
  7use native_tls::TlsConnector;
  8use minidom::Element;
  9use jid::Jid;
 10
 11use xmpp_codec::Packet;
 12use xmpp_stream::XMPPStream;
 13
 14/// XMPP TLS XML namespace
 15pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
 16
 17
 18/// XMPP stream that switches to TLS if available in received features
 19pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
 20    state: StartTlsClientState<S>,
 21    jid: Jid,
 22}
 23
 24enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
 25    Invalid,
 26    SendStartTls(sink::Send<XMPPStream<S>>),
 27    AwaitProceed(XMPPStream<S>),
 28    StartingTls(ConnectAsync<S>),
 29}
 30
 31impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
 32    /// Waits for <stream:features>
 33    pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
 34        let jid = xmpp_stream.jid.clone();
 35
 36        let nonza = Element::builder("starttls")
 37            .ns(NS_XMPP_TLS)
 38            .build();
 39        let packet = Packet::Stanza(nonza);
 40        let send = xmpp_stream.send(packet);
 41
 42        StartTlsClient {
 43            state: StartTlsClientState::SendStartTls(send),
 44            jid,
 45        }
 46    }
 47}
 48
 49impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
 50    type Item = TlsStream<S>;
 51    type Error = String;
 52
 53    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 54        let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
 55        let mut retry = false;
 56        
 57        let (new_state, result) = match old_state {
 58            StartTlsClientState::SendStartTls(mut send) =>
 59                match send.poll() {
 60                    Ok(Async::Ready(xmpp_stream)) => {
 61                        let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
 62                        retry = true;
 63                        (new_state, Ok(Async::NotReady))
 64                    },
 65                    Ok(Async::NotReady) =>
 66                        (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
 67                    Err(e) =>
 68                        (StartTlsClientState::SendStartTls(send), Err(format!("{}", e))),
 69                },
 70            StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
 71                match xmpp_stream.poll() {
 72                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
 73                        if stanza.name() == "proceed" =>
 74                    {
 75                        let stream = xmpp_stream.stream.into_inner();
 76                        let connect = TlsConnector::builder().unwrap()
 77                            .build().unwrap()
 78                            .connect_async(&self.jid.domain, stream);
 79                        let new_state = StartTlsClientState::StartingTls(connect);
 80                        retry = true;
 81                        (new_state, Ok(Async::NotReady))
 82                    },
 83                    Ok(Async::Ready(value)) => {
 84                        println!("StartTlsClient ignore {:?}", value);
 85                        (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
 86                    },
 87                    Ok(_) =>
 88                        (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
 89                    Err(e) =>
 90                        (StartTlsClientState::AwaitProceed(xmpp_stream),  Err(format!("{}", e))),
 91                },
 92            StartTlsClientState::StartingTls(mut connect) =>
 93                match connect.poll() {
 94                    Ok(Async::Ready(tls_stream)) =>
 95                        (StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream))),
 96                    Ok(Async::NotReady) =>
 97                        (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
 98                    Err(e) =>
 99                        (StartTlsClientState::Invalid, Err(format!("{}", e))),
100                },
101            StartTlsClientState::Invalid =>
102                unreachable!(),
103        };
104
105        self.state = new_state;
106        if retry {
107            self.poll()
108        } else {
109            result
110        }
111    }
112}