stream_start.rs

  1use std::mem::replace;
  2use std::io::{Error, ErrorKind};
  3use futures::{Future, Async, Poll, Stream, sink, Sink};
  4use tokio_io::{AsyncRead, AsyncWrite};
  5use tokio_io::codec::Framed;
  6use jid::Jid;
  7
  8use xmpp_codec::{XMPPCodec, Packet};
  9use xmpp_stream::XMPPStream;
 10
 11const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
 12
 13pub struct StreamStart<S: AsyncWrite> {
 14    state: StreamStartState<S>,
 15    jid: Jid,
 16    ns: String,
 17}
 18
 19enum StreamStartState<S: AsyncWrite> {
 20    SendStart(sink::Send<Framed<S, XMPPCodec>>),
 21    RecvStart(Framed<S, XMPPCodec>),
 22    RecvFeatures(Framed<S, XMPPCodec>, String),
 23    Invalid,
 24}
 25
 26impl<S: AsyncWrite> StreamStart<S> {
 27    pub fn from_stream(stream: Framed<S, XMPPCodec>, jid: Jid, ns: String) -> Self {
 28        let attrs = [("to".to_owned(), jid.domain.clone()),
 29                     ("version".to_owned(), "1.0".to_owned()),
 30                     ("xmlns".to_owned(), ns.clone()),
 31                     ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()),
 32        ].iter().cloned().collect();
 33        let send = stream.send(Packet::StreamStart(attrs));
 34
 35        StreamStart {
 36            state: StreamStartState::SendStart(send),
 37            jid,
 38            ns,
 39        }
 40    }
 41}
 42
 43impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
 44    type Item = XMPPStream<S>;
 45    type Error = Error;
 46
 47    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 48        let old_state = replace(&mut self.state, StreamStartState::Invalid);
 49        let mut retry = false;
 50
 51        let (new_state, result) = match old_state {
 52            StreamStartState::SendStart(mut send) =>
 53                match send.poll() {
 54                    Ok(Async::Ready(stream)) => {
 55                        retry = true;
 56                        (StreamStartState::RecvStart(stream), Ok(Async::NotReady))
 57                    },
 58                    Ok(Async::NotReady) =>
 59                        (StreamStartState::SendStart(send), Ok(Async::NotReady)),
 60                    Err(e) =>
 61                        (StreamStartState::Invalid, Err(e)),
 62                },
 63            StreamStartState::RecvStart(mut stream) =>
 64                match stream.poll() {
 65                    Ok(Async::Ready(Some(Packet::StreamStart(stream_attrs)))) => {
 66                        retry = true;
 67                        let stream_ns = match stream_attrs.get("xmlns") {
 68                            Some(ns) => ns.clone(),
 69                            None =>
 70                                return Err(Error::from(ErrorKind::InvalidData)),
 71                        };
 72                        // TODO: skip RecvFeatures for version < 1.0
 73                        (StreamStartState::RecvFeatures(stream, stream_ns), Ok(Async::NotReady))
 74                    },
 75                    Ok(Async::Ready(_)) =>
 76                        return Err(Error::from(ErrorKind::InvalidData)),
 77                    Ok(Async::NotReady) =>
 78                        (StreamStartState::RecvStart(stream), Ok(Async::NotReady)),
 79                    Err(e) =>
 80                        return Err(e),
 81                },
 82            StreamStartState::RecvFeatures(mut stream, stream_ns) =>
 83                match stream.poll() {
 84                    Ok(Async::Ready(Some(Packet::Stanza(stanza)))) =>
 85                        if stanza.name() == "features"
 86                        && stanza.ns() == Some(NS_XMPP_STREAM) {
 87                            let stream = XMPPStream::new(self.jid.clone(), stream, self.ns.clone(), stanza);
 88                            (StreamStartState::Invalid, Ok(Async::Ready(stream)))
 89                        } else {
 90                            (StreamStartState::RecvFeatures(stream, stream_ns), Ok(Async::NotReady))
 91                        },
 92                    Ok(Async::Ready(_)) | Ok(Async::NotReady) =>
 93                        (StreamStartState::RecvFeatures(stream, stream_ns), Ok(Async::NotReady)),
 94                    Err(e) =>
 95                        return Err(e),
 96                },
 97            StreamStartState::Invalid =>
 98                unreachable!(),
 99        };
100
101        self.state = new_state;
102        if retry {
103            self.poll()
104        } else {
105            result
106        }
107    }
108}