stream_start.rs

  1use std::mem::replace;
  2use std::io::{Error, ErrorKind};
  3use std::collections::HashMap;
  4use futures::*;
  5use tokio_io::{AsyncRead, AsyncWrite};
  6use tokio_io::codec::Framed;
  7
  8use xmpp_codec::*;
  9use xmpp_stream::*;
 10
 11const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
 12
 13pub struct StreamStart<S: AsyncWrite> {
 14    state: StreamStartState<S>,
 15}
 16
 17enum StreamStartState<S: AsyncWrite> {
 18    SendStart(sink::Send<Framed<S, XMPPCodec>>),
 19    RecvStart(Framed<S, XMPPCodec>),
 20    RecvFeatures(Framed<S, XMPPCodec>, HashMap<String, String>),
 21    Invalid,
 22}
 23
 24impl<S: AsyncWrite> StreamStart<S> {
 25    pub fn from_stream(stream: Framed<S, XMPPCodec>, to: String) -> Self {
 26        let attrs = [("to".to_owned(), to),
 27                     ("version".to_owned(), "1.0".to_owned()),
 28                     ("xmlns".to_owned(), "jabber:client".to_owned()),
 29                     ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()),
 30        ].iter().cloned().collect();
 31        let send = stream.send(Packet::StreamStart(attrs));
 32
 33        StreamStart {
 34            state: StreamStartState::SendStart(send),
 35        }
 36    }
 37}
 38
 39impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
 40    type Item = XMPPStream<S>;
 41    type Error = Error;
 42
 43    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 44        let old_state = replace(&mut self.state, StreamStartState::Invalid);
 45        let mut retry = false;
 46
 47        let (new_state, result) = match old_state {
 48            StreamStartState::SendStart(mut send) =>
 49                match send.poll() {
 50                    Ok(Async::Ready(stream)) => {
 51                        retry = true;
 52                        (StreamStartState::RecvStart(stream), Ok(Async::NotReady))
 53                    },
 54                    Ok(Async::NotReady) =>
 55                        (StreamStartState::SendStart(send), Ok(Async::NotReady)),
 56                    Err(e) =>
 57                        (StreamStartState::Invalid, Err(e)),
 58                },
 59            StreamStartState::RecvStart(mut stream) =>
 60                match stream.poll() {
 61                    Ok(Async::Ready(Some(Packet::StreamStart(stream_attrs)))) => {
 62                        retry = true;
 63                        // TODO: skip RecvFeatures for version < 1.0
 64                        (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
 65                    },
 66                    Ok(Async::Ready(_)) =>
 67                        return Err(Error::from(ErrorKind::InvalidData)),
 68                    Ok(Async::NotReady) =>
 69                        (StreamStartState::RecvStart(stream), Ok(Async::NotReady)),
 70                    Err(e) =>
 71                        return Err(e),
 72                },
 73            StreamStartState::RecvFeatures(mut stream, stream_attrs) =>
 74                match stream.poll() {
 75                    Ok(Async::Ready(Some(Packet::Stanza(stanza)))) =>
 76                        if stanza.name == "features"
 77                        && stanza.ns == Some(NS_XMPP_STREAM.to_owned()) {
 78                            (StreamStartState::Invalid, Ok(Async::Ready(XMPPStream { stream, stream_attrs, stream_features: stanza })))
 79                        } else {
 80                            (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
 81                        },
 82                    Ok(Async::Ready(item)) => {
 83                        println!("StreamStart skip {:?}", item);
 84                        (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
 85                    },
 86                    Ok(Async::NotReady) =>
 87                        (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady)),
 88                    Err(e) =>
 89                        return Err(e),
 90                },
 91            StreamStartState::Invalid =>
 92                unreachable!(),
 93        };
 94
 95        self.state = new_state;
 96        if retry {
 97            self.poll()
 98        } else {
 99            result
100        }
101    }
102}