starttls.rs

  1use std::mem::replace;
  2use futures::{Future, Sink, Poll, Async};
  3use futures::stream::Stream;
  4use futures::sink;
  5use tokio_io::{AsyncRead, AsyncWrite};
  6use tokio_tls::*;
  7use native_tls::TlsConnector;
  8use xml;
  9use jid::Jid;
 10
 11use xmpp_codec::*;
 12use xmpp_stream::*;
 13use stream_start::StreamStart;
 14
 15
 16pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
 17
 18pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
 19    state: StartTlsClientState<S>,
 20    jid: Jid,
 21}
 22
 23enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
 24    Invalid,
 25    SendStartTls(sink::Send<XMPPStream<S>>),
 26    AwaitProceed(XMPPStream<S>),
 27    StartingTls(ConnectAsync<S>),
 28    Start(StreamStart<TlsStream<S>>),
 29}
 30
 31impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
 32    /// Waits for <stream:features>
 33    pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
 34        let jid = xmpp_stream.jid.clone();
 35
 36        let nonza = xml::Element::new(
 37            "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()),
 38            vec![]
 39        );
 40        println!("send {}", nonza);
 41        let packet = Packet::Stanza(nonza);
 42        let send = xmpp_stream.send(packet);
 43
 44        StartTlsClient {
 45            state: StartTlsClientState::SendStartTls(send),
 46            jid,
 47        }
 48    }
 49}
 50
 51impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
 52    type Item = XMPPStream<TlsStream<S>>;
 53    type Error = String;
 54
 55    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 56        let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
 57        let mut retry = false;
 58        
 59        let (new_state, result) = match old_state {
 60            StartTlsClientState::SendStartTls(mut send) =>
 61                match send.poll() {
 62                    Ok(Async::Ready(xmpp_stream)) => {
 63                        println!("starttls sent");
 64                        let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
 65                        retry = true;
 66                        (new_state, Ok(Async::NotReady))
 67                    },
 68                    Ok(Async::NotReady) =>
 69                        (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
 70                    Err(e) =>
 71                        (StartTlsClientState::SendStartTls(send), Err(format!("{}", e))),
 72                },
 73            StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
 74                match xmpp_stream.poll() {
 75                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
 76                        if stanza.name == "proceed" =>
 77                    {
 78                        println!("* proceed *");
 79                        let stream = xmpp_stream.stream.into_inner();
 80                        let connect = TlsConnector::builder().unwrap()
 81                            .build().unwrap()
 82                            .connect_async(&self.jid.domain, stream);
 83                        let new_state = StartTlsClientState::StartingTls(connect);
 84                        retry = true;
 85                        (new_state, Ok(Async::NotReady))
 86                    },
 87                    Ok(Async::Ready(value)) => {
 88                        println!("StartTlsClient ignore {:?}", value);
 89                        (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
 90                    },
 91                    Ok(_) =>
 92                        (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
 93                    Err(e) =>
 94                        (StartTlsClientState::AwaitProceed(xmpp_stream),  Err(format!("{}", e))),
 95                },
 96            StartTlsClientState::StartingTls(mut connect) =>
 97                match connect.poll() {
 98                    Ok(Async::Ready(tls_stream)) => {
 99                        println!("Got a TLS stream!");
100                        let start = XMPPStream::from_stream(tls_stream, self.jid.clone());
101                        let new_state = StartTlsClientState::Start(start);
102                        retry = true;
103                        (new_state, Ok(Async::NotReady))
104                    },
105                    Ok(Async::NotReady) =>
106                        (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
107                    Err(e) =>
108                        (StartTlsClientState::StartingTls(connect),  Err(format!("{}", e))),
109                },
110            StartTlsClientState::Start(mut start) =>
111                match start.poll() {
112                    Ok(Async::Ready(xmpp_stream)) =>
113                        (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream))),
114                    Ok(Async::NotReady) =>
115                        (StartTlsClientState::Start(start), Ok(Async::NotReady)),
116                    Err(e) =>
117                        (StartTlsClientState::Invalid, Err(format!("{}", e))),
118                },
119            StartTlsClientState::Invalid =>
120                unreachable!(),
121        };
122
123        self.state = new_state;
124        if retry {
125            self.poll()
126        } else {
127            result
128        }
129    }
130}