starttls.rs

  1use std::mem::replace;
  2use futures::{Future, Sink, Poll, Async};
  3use futures::stream::Stream;
  4use futures::sink;
  5use tokio_io::{AsyncRead, AsyncWrite};
  6use tokio_tls::*;
  7use native_tls::TlsConnector;
  8use xml;
  9use jid::Jid;
 10
 11use xmpp_codec::*;
 12use xmpp_stream::*;
 13use stream_start::StreamStart;
 14
 15
 16pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
 17
 18
 19pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
 20    state: StartTlsClientState<S>,
 21    jid: Jid,
 22}
 23
 24enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
 25    Invalid,
 26    SendStartTls(sink::Send<XMPPStream<S>>),
 27    AwaitProceed(XMPPStream<S>),
 28    StartingTls(ConnectAsync<S>),
 29    Start(StreamStart<TlsStream<S>>),
 30}
 31
 32impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
 33    /// Waits for <stream:features>
 34    pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
 35        let jid = xmpp_stream.jid.clone();
 36
 37        let nonza = xml::Element::new(
 38            "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()),
 39            vec![]
 40        );
 41        let packet = Packet::Stanza(nonza);
 42        let send = xmpp_stream.send(packet);
 43
 44        StartTlsClient {
 45            state: StartTlsClientState::SendStartTls(send),
 46            jid,
 47        }
 48    }
 49}
 50
 51impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
 52    type Item = XMPPStream<TlsStream<S>>;
 53    type Error = String;
 54
 55    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 56        let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
 57        let mut retry = false;
 58        
 59        let (new_state, result) = match old_state {
 60            StartTlsClientState::SendStartTls(mut send) =>
 61                match send.poll() {
 62                    Ok(Async::Ready(xmpp_stream)) => {
 63                        let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
 64                        retry = true;
 65                        (new_state, Ok(Async::NotReady))
 66                    },
 67                    Ok(Async::NotReady) =>
 68                        (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
 69                    Err(e) =>
 70                        (StartTlsClientState::SendStartTls(send), Err(format!("{}", e))),
 71                },
 72            StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
 73                match xmpp_stream.poll() {
 74                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
 75                        if stanza.name == "proceed" =>
 76                    {
 77                        let stream = xmpp_stream.stream.into_inner();
 78                        let connect = TlsConnector::builder().unwrap()
 79                            .build().unwrap()
 80                            .connect_async(&self.jid.domain, stream);
 81                        let new_state = StartTlsClientState::StartingTls(connect);
 82                        retry = true;
 83                        (new_state, Ok(Async::NotReady))
 84                    },
 85                    Ok(Async::Ready(value)) => {
 86                        println!("StartTlsClient ignore {:?}", value);
 87                        (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
 88                    },
 89                    Ok(_) =>
 90                        (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
 91                    Err(e) =>
 92                        (StartTlsClientState::AwaitProceed(xmpp_stream),  Err(format!("{}", e))),
 93                },
 94            StartTlsClientState::StartingTls(mut connect) =>
 95                match connect.poll() {
 96                    Ok(Async::Ready(tls_stream)) => {
 97                        println!("TLS stream established");
 98                        let start = XMPPStream::from_stream(tls_stream, self.jid.clone());
 99                        let new_state = StartTlsClientState::Start(start);
100                        retry = true;
101                        (new_state, Ok(Async::NotReady))
102                    },
103                    Ok(Async::NotReady) =>
104                        (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
105                    Err(e) =>
106                        (StartTlsClientState::StartingTls(connect),  Err(format!("{}", e))),
107                },
108            StartTlsClientState::Start(mut start) =>
109                match start.poll() {
110                    Ok(Async::Ready(xmpp_stream)) =>
111                        (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream))),
112                    Ok(Async::NotReady) =>
113                        (StartTlsClientState::Start(start), Ok(Async::NotReady)),
114                    Err(e) =>
115                        (StartTlsClientState::Invalid, Err(format!("{}", e))),
116                },
117            StartTlsClientState::Invalid =>
118                unreachable!(),
119        };
120
121        self.state = new_state;
122        if retry {
123            self.poll()
124        } else {
125            result
126        }
127    }
128}