1use std::mem::replace;
2use futures::{Future, Sink, Poll, Async};
3use futures::stream::Stream;
4use futures::sink;
5use tokio_io::{AsyncRead, AsyncWrite};
6use tokio_tls::*;
7use native_tls::TlsConnector;
8use xml;
9use jid::Jid;
10
11use xmpp_codec::*;
12use xmpp_stream::*;
13use stream_start::StreamStart;
14
15
/// XML namespace of the STARTTLS negotiation elements
/// (`<starttls/>`, `<proceed/>`, `<failure/>`).
pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
17
18
19pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
20 state: StartTlsClientState<S>,
21 jid: Jid,
22}
23
24enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
25 Invalid,
26 SendStartTls(sink::Send<XMPPStream<S>>),
27 AwaitProceed(XMPPStream<S>),
28 StartingTls(ConnectAsync<S>),
29 Start(StreamStart<TlsStream<S>>),
30}
31
32impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
33 /// Waits for <stream:features>
34 pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
35 let jid = xmpp_stream.jid.clone();
36
37 let nonza = xml::Element::new(
38 "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()),
39 vec![]
40 );
41 let packet = Packet::Stanza(nonza);
42 let send = xmpp_stream.send(packet);
43
44 StartTlsClient {
45 state: StartTlsClientState::SendStartTls(send),
46 jid,
47 }
48 }
49}
50
51impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
52 type Item = XMPPStream<TlsStream<S>>;
53 type Error = String;
54
55 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
56 let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
57 let mut retry = false;
58
59 let (new_state, result) = match old_state {
60 StartTlsClientState::SendStartTls(mut send) =>
61 match send.poll() {
62 Ok(Async::Ready(xmpp_stream)) => {
63 let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
64 retry = true;
65 (new_state, Ok(Async::NotReady))
66 },
67 Ok(Async::NotReady) =>
68 (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
69 Err(e) =>
70 (StartTlsClientState::SendStartTls(send), Err(format!("{}", e))),
71 },
72 StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
73 match xmpp_stream.poll() {
74 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
75 if stanza.name == "proceed" =>
76 {
77 let stream = xmpp_stream.stream.into_inner();
78 let connect = TlsConnector::builder().unwrap()
79 .build().unwrap()
80 .connect_async(&self.jid.domain, stream);
81 let new_state = StartTlsClientState::StartingTls(connect);
82 retry = true;
83 (new_state, Ok(Async::NotReady))
84 },
85 Ok(Async::Ready(value)) => {
86 println!("StartTlsClient ignore {:?}", value);
87 (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
88 },
89 Ok(_) =>
90 (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
91 Err(e) =>
92 (StartTlsClientState::AwaitProceed(xmpp_stream), Err(format!("{}", e))),
93 },
94 StartTlsClientState::StartingTls(mut connect) =>
95 match connect.poll() {
96 Ok(Async::Ready(tls_stream)) => {
97 println!("TLS stream established");
98 let start = XMPPStream::from_stream(tls_stream, self.jid.clone());
99 let new_state = StartTlsClientState::Start(start);
100 retry = true;
101 (new_state, Ok(Async::NotReady))
102 },
103 Ok(Async::NotReady) =>
104 (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
105 Err(e) =>
106 (StartTlsClientState::StartingTls(connect), Err(format!("{}", e))),
107 },
108 StartTlsClientState::Start(mut start) =>
109 match start.poll() {
110 Ok(Async::Ready(xmpp_stream)) =>
111 (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream))),
112 Ok(Async::NotReady) =>
113 (StartTlsClientState::Start(start), Ok(Async::NotReady)),
114 Err(e) =>
115 (StartTlsClientState::Invalid, Err(format!("{}", e))),
116 },
117 StartTlsClientState::Invalid =>
118 unreachable!(),
119 };
120
121 self.state = new_state;
122 if retry {
123 self.poll()
124 } else {
125 result
126 }
127 }
128}