1use std::mem::replace;
2use std::io::{Error, ErrorKind};
3use std::collections::HashMap;
4use futures::*;
5use tokio_io::{AsyncRead, AsyncWrite};
6use tokio_io::codec::Framed;
7use jid::Jid;
8
9use xmpp_codec::*;
10use xmpp_stream::*;
11
12const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
13
14pub struct StreamStart<S: AsyncWrite> {
15 state: StreamStartState<S>,
16 jid: Jid,
17}
18
19enum StreamStartState<S: AsyncWrite> {
20 SendStart(sink::Send<Framed<S, XMPPCodec>>),
21 RecvStart(Framed<S, XMPPCodec>),
22 RecvFeatures(Framed<S, XMPPCodec>, HashMap<String, String>),
23 Invalid,
24}
25
26impl<S: AsyncWrite> StreamStart<S> {
27 pub fn from_stream(stream: Framed<S, XMPPCodec>, jid: Jid) -> Self {
28 let attrs = [("to".to_owned(), jid.domain.clone()),
29 ("version".to_owned(), "1.0".to_owned()),
30 ("xmlns".to_owned(), "jabber:client".to_owned()),
31 ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()),
32 ].iter().cloned().collect();
33 let send = stream.send(Packet::StreamStart(attrs));
34
35 StreamStart {
36 state: StreamStartState::SendStart(send),
37 jid,
38 }
39 }
40}
41
42impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
43 type Item = XMPPStream<S>;
44 type Error = Error;
45
46 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
47 let old_state = replace(&mut self.state, StreamStartState::Invalid);
48 let mut retry = false;
49
50 let (new_state, result) = match old_state {
51 StreamStartState::SendStart(mut send) =>
52 match send.poll() {
53 Ok(Async::Ready(stream)) => {
54 retry = true;
55 (StreamStartState::RecvStart(stream), Ok(Async::NotReady))
56 },
57 Ok(Async::NotReady) =>
58 (StreamStartState::SendStart(send), Ok(Async::NotReady)),
59 Err(e) =>
60 (StreamStartState::Invalid, Err(e)),
61 },
62 StreamStartState::RecvStart(mut stream) =>
63 match stream.poll() {
64 Ok(Async::Ready(Some(Packet::StreamStart(stream_attrs)))) => {
65 retry = true;
66 // TODO: skip RecvFeatures for version < 1.0
67 (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
68 },
69 Ok(Async::Ready(_)) =>
70 return Err(Error::from(ErrorKind::InvalidData)),
71 Ok(Async::NotReady) =>
72 (StreamStartState::RecvStart(stream), Ok(Async::NotReady)),
73 Err(e) =>
74 return Err(e),
75 },
76 StreamStartState::RecvFeatures(mut stream, stream_attrs) =>
77 match stream.poll() {
78 Ok(Async::Ready(Some(Packet::Stanza(stanza)))) =>
79 if stanza.name == "features"
80 && stanza.ns == Some(NS_XMPP_STREAM.to_owned()) {
81 let stream = XMPPStream::new(self.jid.clone(), stream, stream_attrs, stanza);
82 (StreamStartState::Invalid, Ok(Async::Ready(stream)))
83 } else {
84 (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
85 },
86 Ok(Async::Ready(item)) => {
87 println!("StreamStart skip {:?}", item);
88 (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady))
89 },
90 Ok(Async::NotReady) =>
91 (StreamStartState::RecvFeatures(stream, stream_attrs), Ok(Async::NotReady)),
92 Err(e) =>
93 return Err(e),
94 },
95 StreamStartState::Invalid =>
96 unreachable!(),
97 };
98
99 self.state = new_state;
100 if retry {
101 self.poll()
102 } else {
103 result
104 }
105 }
106}