1use std::mem::replace;
2use std::io::{Error, ErrorKind};
3use futures::{Future, Async, Poll, Stream, sink, Sink};
4use tokio_io::{AsyncRead, AsyncWrite};
5use tokio_io::codec::Framed;
6use jid::Jid;
7
8use xmpp_codec::{XMPPCodec, Packet};
9use xmpp_stream::XMPPStream;
10
/// XML namespace bound to the `stream:` prefix in the stream header, and the
/// namespace the `<stream:features/>` element must carry.
const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
12
13pub struct StreamStart<S: AsyncWrite> {
14 state: StreamStartState<S>,
15 jid: Jid,
16 ns: String,
17}
18
19enum StreamStartState<S: AsyncWrite> {
20 SendStart(sink::Send<Framed<S, XMPPCodec>>),
21 RecvStart(Framed<S, XMPPCodec>),
22 RecvFeatures(Framed<S, XMPPCodec>, String),
23 Invalid,
24}
25
26impl<S: AsyncWrite> StreamStart<S> {
27 pub fn from_stream(stream: Framed<S, XMPPCodec>, jid: Jid, ns: String) -> Self {
28 let attrs = [("to".to_owned(), jid.domain.clone()),
29 ("version".to_owned(), "1.0".to_owned()),
30 ("xmlns".to_owned(), ns.clone()),
31 ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()),
32 ].iter().cloned().collect();
33 let send = stream.send(Packet::StreamStart(attrs));
34
35 StreamStart {
36 state: StreamStartState::SendStart(send),
37 jid,
38 ns,
39 }
40 }
41}
42
43impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
44 type Item = XMPPStream<S>;
45 type Error = Error;
46
47 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
48 let old_state = replace(&mut self.state, StreamStartState::Invalid);
49 let mut retry = false;
50
51 let (new_state, result) = match old_state {
52 StreamStartState::SendStart(mut send) =>
53 match send.poll() {
54 Ok(Async::Ready(stream)) => {
55 retry = true;
56 (StreamStartState::RecvStart(stream), Ok(Async::NotReady))
57 },
58 Ok(Async::NotReady) =>
59 (StreamStartState::SendStart(send), Ok(Async::NotReady)),
60 Err(e) =>
61 (StreamStartState::Invalid, Err(e)),
62 },
63 StreamStartState::RecvStart(mut stream) =>
64 match stream.poll() {
65 Ok(Async::Ready(Some(Packet::StreamStart(stream_attrs)))) => {
66 retry = true;
67 let stream_ns = match stream_attrs.get("xmlns") {
68 Some(ns) => ns.clone(),
69 None =>
70 return Err(Error::from(ErrorKind::InvalidData)),
71 };
72 // TODO: skip RecvFeatures for version < 1.0
73 (StreamStartState::RecvFeatures(stream, stream_ns), Ok(Async::NotReady))
74 },
75 Ok(Async::Ready(_)) =>
76 return Err(Error::from(ErrorKind::InvalidData)),
77 Ok(Async::NotReady) =>
78 (StreamStartState::RecvStart(stream), Ok(Async::NotReady)),
79 Err(e) =>
80 return Err(e),
81 },
82 StreamStartState::RecvFeatures(mut stream, stream_ns) =>
83 match stream.poll() {
84 Ok(Async::Ready(Some(Packet::Stanza(stanza)))) =>
85 if stanza.name() == "features"
86 && stanza.ns() == Some(NS_XMPP_STREAM) {
87 let stream = XMPPStream::new(self.jid.clone(), stream, self.ns.clone(), stanza);
88 (StreamStartState::Invalid, Ok(Async::Ready(stream)))
89 } else {
90 (StreamStartState::RecvFeatures(stream, stream_ns), Ok(Async::NotReady))
91 },
92 Ok(Async::Ready(_)) | Ok(Async::NotReady) =>
93 (StreamStartState::RecvFeatures(stream, stream_ns), Ok(Async::NotReady)),
94 Err(e) =>
95 return Err(e),
96 },
97 StreamStartState::Invalid =>
98 unreachable!(),
99 };
100
101 self.state = new_state;
102 if retry {
103 self.poll()
104 } else {
105 result
106 }
107 }
108}