stream_start.rs

  1use futures::{sink, Async, Future, Poll, Sink, Stream};
  2use xmpp_parsers::{Jid, Element};
  3use std::mem::replace;
  4use tokio_codec::Framed;
  5use tokio_io::{AsyncRead, AsyncWrite};
  6
  7use crate::xmpp_codec::{Packet, XMPPCodec};
  8use crate::xmpp_stream::XMPPStream;
  9use crate::{Error, ProtocolError};
 10
 11const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
 12
 13pub struct StreamStart<S: AsyncWrite> {
 14    state: StreamStartState<S>,
 15    jid: Jid,
 16    ns: String,
 17}
 18
 19enum StreamStartState<S: AsyncWrite> {
 20    SendStart(sink::Send<Framed<S, XMPPCodec>>),
 21    RecvStart(Framed<S, XMPPCodec>),
 22    RecvFeatures(Framed<S, XMPPCodec>, String),
 23    Invalid,
 24}
 25
 26impl<S: AsyncWrite> StreamStart<S> {
 27    pub fn from_stream(stream: Framed<S, XMPPCodec>, jid: Jid, ns: String) -> Self {
 28        let attrs = [
 29            ("to".to_owned(), jid.clone().domain()),
 30            ("version".to_owned(), "1.0".to_owned()),
 31            ("xmlns".to_owned(), ns.clone()),
 32            ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()),
 33        ]
 34        .iter()
 35        .cloned()
 36        .collect();
 37        let send = stream.send(Packet::StreamStart(attrs));
 38
 39        StreamStart {
 40            state: StreamStartState::SendStart(send),
 41            jid,
 42            ns,
 43        }
 44    }
 45}
 46
 47impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
 48    type Item = XMPPStream<S>;
 49    type Error = Error;
 50
 51    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 52        let old_state = replace(&mut self.state, StreamStartState::Invalid);
 53        let mut retry = false;
 54
 55        let (new_state, result) = match old_state {
 56            StreamStartState::SendStart(mut send) => match send.poll() {
 57                Ok(Async::Ready(stream)) => {
 58                    retry = true;
 59                    (StreamStartState::RecvStart(stream), Ok(Async::NotReady))
 60                }
 61                Ok(Async::NotReady) => (StreamStartState::SendStart(send), Ok(Async::NotReady)),
 62                Err(e) => (StreamStartState::Invalid, Err(e.into())),
 63            },
 64            StreamStartState::RecvStart(mut stream) => match stream.poll() {
 65                Ok(Async::Ready(Some(Packet::StreamStart(stream_attrs)))) => {
 66                    let stream_ns = stream_attrs
 67                        .get("xmlns")
 68                        .ok_or(ProtocolError::NoStreamNamespace)?
 69                        .clone();
 70                    if self.ns == "jabber:client" {
 71                        retry = true;
 72                        // TODO: skip RecvFeatures for version < 1.0
 73                        (
 74                            StreamStartState::RecvFeatures(stream, stream_ns),
 75                            Ok(Async::NotReady),
 76                        )
 77                    } else {
 78                        let id = stream_attrs
 79                            .get("id")
 80                            .ok_or(ProtocolError::NoStreamId)?
 81                            .clone();
 82                        // FIXME: huge hack, shouldn’t be an element!
 83                        let stream = XMPPStream::new(
 84                            self.jid.clone(),
 85                            stream,
 86                            self.ns.clone(),
 87                            Element::builder(id).build(),
 88                        );
 89                        (StreamStartState::Invalid, Ok(Async::Ready(stream)))
 90                    }
 91                }
 92                Ok(Async::Ready(_)) => return Err(ProtocolError::InvalidToken.into()),
 93                Ok(Async::NotReady) => (StreamStartState::RecvStart(stream), Ok(Async::NotReady)),
 94                Err(e) => return Err(ProtocolError::from(e).into()),
 95            },
 96            StreamStartState::RecvFeatures(mut stream, stream_ns) => match stream.poll() {
 97                Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
 98                    if stanza.is("features", NS_XMPP_STREAM) {
 99                        let stream =
100                            XMPPStream::new(self.jid.clone(), stream, self.ns.clone(), stanza);
101                        (StreamStartState::Invalid, Ok(Async::Ready(stream)))
102                    } else {
103                        (
104                            StreamStartState::RecvFeatures(stream, stream_ns),
105                            Ok(Async::NotReady),
106                        )
107                    }
108                }
109                Ok(Async::Ready(_)) | Ok(Async::NotReady) => (
110                    StreamStartState::RecvFeatures(stream, stream_ns),
111                    Ok(Async::NotReady),
112                ),
113                Err(e) => return Err(ProtocolError::from(e).into()),
114            },
115            StreamStartState::Invalid => unreachable!(),
116        };
117
118        self.state = new_state;
119        if retry {
120            self.poll()
121        } else {
122            result
123        }
124    }
125}