1use futures::{sink, Async, Future, Poll, Sink, Stream};
2use xmpp_parsers::{Jid, Element};
3use std::mem::replace;
4use tokio_codec::Framed;
5use tokio_io::{AsyncRead, AsyncWrite};
6
7use crate::xmpp_codec::{Packet, XMPPCodec};
8use crate::xmpp_stream::XMPPStream;
9use crate::{Error, ProtocolError};
10
/// XML namespace bound to the `stream:` prefix on stream headers; also the
/// namespace the `<stream:features>` element is expected to live in.
const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
12
/// Future that opens an XMPP stream on an already-framed transport.
///
/// It sends our `<stream:stream>` header, waits for the peer's header and,
/// when `ns` is `"jabber:client"`, additionally waits for the
/// `<stream:features>` element. It resolves to an [`XMPPStream`] wrapping the
/// same transport.
pub struct StreamStart<S: AsyncWrite> {
    /// Current step of the negotiation state machine.
    state: StreamStartState<S>,
    /// JID whose domain is sent as the `to` attribute of our stream header;
    /// handed on to the resulting [`XMPPStream`].
    jid: Jid,
    /// Default namespace (`xmlns`) of our stream header, e.g. `"jabber:client"`.
    ns: String,
}
18
/// Internal states of [`StreamStart`].
enum StreamStartState<S: AsyncWrite> {
    /// Our stream header is being written to the transport.
    SendStart(sink::Send<Framed<S, XMPPCodec>>),
    /// Header sent; waiting for the peer's `<stream:stream>` header.
    RecvStart(Framed<S, XMPPCodec>),
    /// Peer header received; waiting for `<stream:features>`. The `String` is
    /// the `xmlns` the peer declared on its header.
    /// NOTE(review): the namespace is carried along but not otherwise used in
    /// this file — confirm whether it is still needed.
    RecvFeatures(Framed<S, XMPPCodec>, String),
    /// Placeholder while the real state is moved out during `poll`, and the
    /// terminal state once the future has completed or failed.
    Invalid,
}
25
26impl<S: AsyncWrite> StreamStart<S> {
27 pub fn from_stream(stream: Framed<S, XMPPCodec>, jid: Jid, ns: String) -> Self {
28 let attrs = [
29 ("to".to_owned(), jid.clone().domain()),
30 ("version".to_owned(), "1.0".to_owned()),
31 ("xmlns".to_owned(), ns.clone()),
32 ("xmlns:stream".to_owned(), NS_XMPP_STREAM.to_owned()),
33 ]
34 .iter()
35 .cloned()
36 .collect();
37 let send = stream.send(Packet::StreamStart(attrs));
38
39 StreamStart {
40 state: StreamStartState::SendStart(send),
41 jid,
42 ns,
43 }
44 }
45}
46
47impl<S: AsyncRead + AsyncWrite> Future for StreamStart<S> {
48 type Item = XMPPStream<S>;
49 type Error = Error;
50
51 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
52 let old_state = replace(&mut self.state, StreamStartState::Invalid);
53 let mut retry = false;
54
55 let (new_state, result) = match old_state {
56 StreamStartState::SendStart(mut send) => match send.poll() {
57 Ok(Async::Ready(stream)) => {
58 retry = true;
59 (StreamStartState::RecvStart(stream), Ok(Async::NotReady))
60 }
61 Ok(Async::NotReady) => (StreamStartState::SendStart(send), Ok(Async::NotReady)),
62 Err(e) => (StreamStartState::Invalid, Err(e.into())),
63 },
64 StreamStartState::RecvStart(mut stream) => match stream.poll() {
65 Ok(Async::Ready(Some(Packet::StreamStart(stream_attrs)))) => {
66 let stream_ns = stream_attrs
67 .get("xmlns")
68 .ok_or(ProtocolError::NoStreamNamespace)?
69 .clone();
70 if self.ns == "jabber:client" {
71 retry = true;
72 // TODO: skip RecvFeatures for version < 1.0
73 (
74 StreamStartState::RecvFeatures(stream, stream_ns),
75 Ok(Async::NotReady),
76 )
77 } else {
78 let id = stream_attrs
79 .get("id")
80 .ok_or(ProtocolError::NoStreamId)?
81 .clone();
82 // FIXME: huge hack, shouldn’t be an element!
83 let stream = XMPPStream::new(
84 self.jid.clone(),
85 stream,
86 self.ns.clone(),
87 Element::builder(id).build(),
88 );
89 (StreamStartState::Invalid, Ok(Async::Ready(stream)))
90 }
91 }
92 Ok(Async::Ready(_)) => return Err(ProtocolError::InvalidToken.into()),
93 Ok(Async::NotReady) => (StreamStartState::RecvStart(stream), Ok(Async::NotReady)),
94 Err(e) => return Err(ProtocolError::from(e).into()),
95 },
96 StreamStartState::RecvFeatures(mut stream, stream_ns) => match stream.poll() {
97 Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
98 if stanza.is("features", NS_XMPP_STREAM) {
99 let stream =
100 XMPPStream::new(self.jid.clone(), stream, self.ns.clone(), stanza);
101 (StreamStartState::Invalid, Ok(Async::Ready(stream)))
102 } else {
103 (
104 StreamStartState::RecvFeatures(stream, stream_ns),
105 Ok(Async::NotReady),
106 )
107 }
108 }
109 Ok(Async::Ready(_)) | Ok(Async::NotReady) => (
110 StreamStartState::RecvFeatures(stream, stream_ns),
111 Ok(Async::NotReady),
112 ),
113 Err(e) => return Err(ProtocolError::from(e).into()),
114 },
115 StreamStartState::Invalid => unreachable!(),
116 };
117
118 self.state = new_state;
119 if retry {
120 self.poll()
121 } else {
122 result
123 }
124 }
125}