1use std::mem::replace;
2use std::io::{Error, ErrorKind};
3use futures::{Future, Async, Poll, Stream, sink, Sink};
4use tokio_io::{AsyncRead, AsyncWrite};
5use tokio_codec::Framed;
6use jid::Jid;
7use minidom::Element;
8
9use xmpp_codec::{XMPPCodec, Packet};
10use xmpp_stream::XMPPStream;
11
/// Namespace bound to the `stream:` prefix; it qualifies the stream header
/// attributes we send and the `<stream:features/>` element we wait for.
const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
13
14pub struct StreamStart<S: AsyncWrite> {
15 state: StreamStartState<S>,
16 jid: Jid,
17 ns: String,
18}
19
20enum StreamStartState<S: AsyncWrite> {
21 SendStart(sink::Send<Framed<S, XMPPCodec>>),
22 RecvStart(Framed<S, XMPPCodec>),
23 RecvFeatures(Framed<S, XMPPCodec>, String),
24 Invalid,
25}
26
27impl<S: AsyncWrite> StreamStart<S> {
28 pub fn from_stream(stream: Framed<S, XMPPCodec>, jid: Jid, ns: String) -> Self {
29 let attrs = [("to".to_owned(), jid.domain.clone()),
30 ("version".to_owned(), "1.0".to_owned()),
31 ("xmlns".to_owned(), ns.clone()),
32 ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()),
33 ].iter().cloned().collect();
34 let send = stream.send(Packet::StreamStart(attrs));
35
36 StreamStart {
37 state: StreamStartState::SendStart(send),
38 jid,
39 ns,
40 }
41 }
42}
43
44impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
45 type Item = XMPPStream<S>;
46 type Error = Error;
47
48 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
49 let old_state = replace(&mut self.state, StreamStartState::Invalid);
50 let mut retry = false;
51
52 let (new_state, result) = match old_state {
53 StreamStartState::SendStart(mut send) =>
54 match send.poll() {
55 Ok(Async::Ready(stream)) => {
56 retry = true;
57 (StreamStartState::RecvStart(stream), Ok(Async::NotReady))
58 },
59 Ok(Async::NotReady) =>
60 (StreamStartState::SendStart(send), Ok(Async::NotReady)),
61 Err(e) =>
62 (StreamStartState::Invalid, Err(e)),
63 },
64 StreamStartState::RecvStart(mut stream) =>
65 match stream.poll() {
66 Ok(Async::Ready(Some(Packet::StreamStart(stream_attrs)))) => {
67 let stream_ns = match stream_attrs.get("xmlns") {
68 Some(ns) => ns.clone(),
69 None =>
70 return Err(Error::from(ErrorKind::InvalidData)),
71 };
72 if self.ns == "jabber:client" {
73 retry = true;
74 // TODO: skip RecvFeatures for version < 1.0
75 (StreamStartState::RecvFeatures(stream, stream_ns), Ok(Async::NotReady))
76 } else {
77 let id = match stream_attrs.get("id") {
78 Some(id) => id.clone(),
79 None =>
80 return Err(Error::from(ErrorKind::InvalidData)),
81 };
82 // FIXME: huge hack, shouldn’t be an element!
83 let stream = XMPPStream::new(self.jid.clone(), stream, self.ns.clone(), Element::builder(id).build());
84 (StreamStartState::Invalid, Ok(Async::Ready(stream)))
85 }
86 },
87 Ok(Async::Ready(_)) =>
88 return Err(Error::from(ErrorKind::InvalidData)),
89 Ok(Async::NotReady) =>
90 (StreamStartState::RecvStart(stream), Ok(Async::NotReady)),
91 Err(e) =>
92 return Err(e),
93 },
94 StreamStartState::RecvFeatures(mut stream, stream_ns) =>
95 match stream.poll() {
96 Ok(Async::Ready(Some(Packet::Stanza(stanza)))) =>
97 if stanza.is("features", NS_XMPP_STREAM) {
98 let stream = XMPPStream::new(self.jid.clone(), stream, self.ns.clone(), stanza);
99 (StreamStartState::Invalid, Ok(Async::Ready(stream)))
100 } else {
101 (StreamStartState::RecvFeatures(stream, stream_ns), Ok(Async::NotReady))
102 },
103 Ok(Async::Ready(_)) | Ok(Async::NotReady) =>
104 (StreamStartState::RecvFeatures(stream, stream_ns), Ok(Async::NotReady)),
105 Err(e) =>
106 return Err(e),
107 },
108 StreamStartState::Invalid =>
109 unreachable!(),
110 };
111
112 self.state = new_state;
113 if retry {
114 self.poll()
115 } else {
116 result
117 }
118 }
119}