starttls.rs

  1use std::mem::replace;
  2use std::io::Error;
  3use std::sync::Arc;
  4use futures::{Future, Sink, Poll, Async};
  5use futures::stream::Stream;
  6use futures::sink;
  7use tokio_io::{AsyncRead, AsyncWrite};
  8use rustls::*;
  9use tokio_rustls::*;
 10use xml;
 11
 12use xmpp_codec::*;
 13use xmpp_stream::*;
 14use stream_start::StreamStart;
 15
 16
 17pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
 18
 19pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
 20    state: StartTlsClientState<S>,
 21    arc_config: Arc<ClientConfig>,
 22}
 23
 24enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
 25    Invalid,
 26    SendStartTls(sink::Send<XMPPStream<S>>),
 27    AwaitProceed(XMPPStream<S>),
 28    StartingTls(ConnectAsync<S>),
 29    Start(StreamStart<TlsStream<S, ClientSession>>),
 30}
 31
 32impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
 33    /// Waits for <stream:features>
 34    pub fn from_stream(xmpp_stream: XMPPStream<S>, arc_config: Arc<ClientConfig>) -> Self {
 35        let nonza = xml::Element::new(
 36            "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()),
 37            vec![]
 38        );
 39        println!("send {}", nonza);
 40        let packet = Packet::Stanza(nonza);
 41        let send = xmpp_stream.send(packet);
 42
 43        StartTlsClient {
 44            state: StartTlsClientState::SendStartTls(send),
 45            arc_config: arc_config,
 46        }
 47    }
 48}
 49
 50impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
 51    type Item = XMPPStream<TlsStream<S, ClientSession>>;
 52    type Error = Error;
 53
 54    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 55        let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
 56        let mut retry = false;
 57        
 58        let (new_state, result) = match old_state {
 59            StartTlsClientState::SendStartTls(mut send) =>
 60                match send.poll() {
 61                    Ok(Async::Ready(xmpp_stream)) => {
 62                        println!("starttls sent");
 63                        let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
 64                        retry = true;
 65                        (new_state, Ok(Async::NotReady))
 66                    },
 67                    Ok(Async::NotReady) =>
 68                        (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
 69                    Err(e) =>
 70                        (StartTlsClientState::SendStartTls(send), Err(e)),
 71                },
 72            StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
 73                match xmpp_stream.poll() {
 74                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
 75                        if stanza.name == "proceed" =>
 76                    {
 77                        println!("* proceed *");
 78                        let stream = xmpp_stream.into_inner();
 79                        let connect = self.arc_config.connect_async("spaceboyz.net", stream);
 80                        let new_state = StartTlsClientState::StartingTls(connect);
 81                        retry = true;
 82                        (new_state, Ok(Async::NotReady))
 83                    },
 84                    Ok(Async::Ready(value)) => {
 85                        println!("StartTlsClient ignore {:?}", value);
 86                        (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
 87                    },
 88                    Ok(_) =>
 89                        (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
 90                    Err(e) =>
 91                        (StartTlsClientState::AwaitProceed(xmpp_stream),  Err(e)),
 92                },
 93            StartTlsClientState::StartingTls(mut connect) =>
 94                match connect.poll() {
 95                    Ok(Async::Ready(tls_stream)) => {
 96                        println!("Got a TLS stream!");
 97                        let start = XMPPStream::from_stream(tls_stream, "spaceboyz.net".to_owned());
 98                        let new_state = StartTlsClientState::Start(start);
 99                        retry = true;
100                        (new_state, Ok(Async::NotReady))
101                    },
102                    Ok(Async::NotReady) =>
103                        (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
104                    Err(e) =>
105                        (StartTlsClientState::StartingTls(connect),  Err(e)),
106                },
107            StartTlsClientState::Start(mut start) =>
108                match start.poll() {
109                    Ok(Async::Ready(xmpp_stream)) =>
110                        (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream))),
111                    Ok(Async::NotReady) =>
112                        (StartTlsClientState::Start(start), Ok(Async::NotReady)),
113                    Err(e) =>
114                        (StartTlsClientState::Invalid, Err(e)),
115                },
116            StartTlsClientState::Invalid =>
117                unreachable!(),
118        };
119
120        self.state = new_state;
121        if retry {
122            self.poll()
123        } else {
124            result
125        }
126    }
127}