1use std::mem::replace;
2use std::io::{Error, ErrorKind};
3use std::collections::HashMap;
4use futures::*;
5use tokio_io::{AsyncRead, AsyncWrite};
6use tokio_io::codec::Framed;
7
8use xmpp_codec::*;
9use xmpp_stream::*;
10
/// Namespace URI bound to the `stream` prefix: it is sent as the
/// `xmlns:stream` attribute of the opening stream tag and is the namespace the
/// `features` stanza must carry before the handshake is considered complete.
const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
12
13pub struct StreamStart<S: AsyncWrite> {
14 state: StreamStartState<S>,
15}
16
17enum StreamStartState<S: AsyncWrite> {
18 SendStart(sink::Send<Framed<S, XMPPCodec>>),
19 RecvStart(Framed<S, XMPPCodec>),
20 RecvFeatures(Framed<S, XMPPCodec>, HashMap<String, String>),
21 Invalid,
22}
23
24impl<S: AsyncWrite> StreamStart<S> {
25 pub fn from_stream(stream: Framed<S, XMPPCodec>, to: String) -> Self {
26 let attrs = [("to".to_owned(), to),
27 ("version".to_owned(), "1.0".to_owned()),
28 ("xmlns".to_owned(), "jabber:client".to_owned()),
29 ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()),
30 ].iter().cloned().collect();
31 let send = stream.send(Packet::StreamStart(attrs));
32
33 StreamStart {
34 state: StreamStartState::SendStart(send),
35 }
36 }
37}
38
39impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
40 type Item = XMPPStream<S>;
41 type Error = Error;
42
43 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
44 let old_state = replace(&mut self.state, StreamStartState::Invalid);
45 let mut retry = false;
46
47 let (new_state, result) = match old_state {
48 StreamStartState::SendStart(mut send) =>
49 match send.poll() {
50 Ok(Async::Ready(stream)) => {
51 retry = true;
52 (StreamStartState::RecvStart(stream), Ok(Async::NotReady))
53 },
54 Ok(Async::NotReady) =>
55 (StreamStartState::SendStart(send), Ok(Async::NotReady)),
56 Err(e) =>
57 (StreamStartState::Invalid, Err(e)),
58 },
59 StreamStartState::RecvStart(mut stream) =>
60 match stream.poll() {
61 Ok(Async::Ready(Some(Packet::StreamStart(stream_attrs)))) => {
62 retry = true;
63 // TODO: skip RecvFeatures for version < 1.0
64 (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
65 },
66 Ok(Async::Ready(_)) =>
67 return Err(Error::from(ErrorKind::InvalidData)),
68 Ok(Async::NotReady) =>
69 (StreamStartState::RecvStart(stream), Ok(Async::NotReady)),
70 Err(e) =>
71 return Err(e),
72 },
73 StreamStartState::RecvFeatures(mut stream, stream_attrs) =>
74 match stream.poll() {
75 Ok(Async::Ready(Some(Packet::Stanza(stanza)))) =>
76 if stanza.name == "features"
77 && stanza.ns == Some(NS_XMPP_STREAM.to_owned()) {
78 (StreamStartState::Invalid, Ok(Async::Ready(XMPPStream { stream, stream_attrs, stream_features: stanza })))
79 } else {
80 (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
81 },
82 Ok(Async::Ready(item)) => {
83 println!("StreamStart skip {:?}", item);
84 (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
85 },
86 Ok(Async::NotReady) =>
87 (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady)),
88 Err(e) =>
89 return Err(e),
90 },
91 StreamStartState::Invalid =>
92 unreachable!(),
93 };
94
95 self.state = new_state;
96 if retry {
97 self.poll()
98 } else {
99 result
100 }
101 }
102}