1use std::mem::replace;
2use futures::{Future, Sink, Poll, Async};
3use futures::stream::Stream;
4use futures::sink;
5use tokio_io::{AsyncRead, AsyncWrite};
6use tokio_tls::{TlsStream, TlsConnectorExt, ConnectAsync};
7use native_tls::TlsConnector;
8use minidom::Element;
9use jid::Jid;
10
11use xmpp_codec::Packet;
12use xmpp_stream::XMPPStream;
13
14
/// XML namespace of the STARTTLS negotiation elements (`<starttls/>`,
/// `<proceed/>`, `<failure/>`), as defined by RFC 6120.
pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
16
17
18pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
19 state: StartTlsClientState<S>,
20 jid: Jid,
21}
22
23enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
24 Invalid,
25 SendStartTls(sink::Send<XMPPStream<S>>),
26 AwaitProceed(XMPPStream<S>),
27 StartingTls(ConnectAsync<S>),
28}
29
30impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
31 /// Waits for <stream:features>
32 pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
33 let jid = xmpp_stream.jid.clone();
34
35 let nonza = Element::builder("starttls")
36 .ns(NS_XMPP_TLS)
37 .build();
38 let packet = Packet::Stanza(nonza);
39 let send = xmpp_stream.send(packet);
40
41 StartTlsClient {
42 state: StartTlsClientState::SendStartTls(send),
43 jid,
44 }
45 }
46}
47
48impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
49 type Item = TlsStream<S>;
50 type Error = String;
51
52 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
53 let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
54 let mut retry = false;
55
56 let (new_state, result) = match old_state {
57 StartTlsClientState::SendStartTls(mut send) =>
58 match send.poll() {
59 Ok(Async::Ready(xmpp_stream)) => {
60 let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
61 retry = true;
62 (new_state, Ok(Async::NotReady))
63 },
64 Ok(Async::NotReady) =>
65 (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
66 Err(e) =>
67 (StartTlsClientState::SendStartTls(send), Err(format!("{}", e))),
68 },
69 StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
70 match xmpp_stream.poll() {
71 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
72 if stanza.name() == "proceed" =>
73 {
74 let stream = xmpp_stream.stream.into_inner();
75 let connect = TlsConnector::builder().unwrap()
76 .build().unwrap()
77 .connect_async(&self.jid.domain, stream);
78 let new_state = StartTlsClientState::StartingTls(connect);
79 retry = true;
80 (new_state, Ok(Async::NotReady))
81 },
82 Ok(Async::Ready(value)) => {
83 println!("StartTlsClient ignore {:?}", value);
84 (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
85 },
86 Ok(_) =>
87 (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
88 Err(e) =>
89 (StartTlsClientState::AwaitProceed(xmpp_stream), Err(format!("{}", e))),
90 },
91 StartTlsClientState::StartingTls(mut connect) =>
92 match connect.poll() {
93 Ok(Async::Ready(tls_stream)) =>
94 (StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream))),
95 Ok(Async::NotReady) =>
96 (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
97 Err(e) =>
98 (StartTlsClientState::Invalid, Err(format!("{}", e))),
99 },
100 StartTlsClientState::Invalid =>
101 unreachable!(),
102 };
103
104 self.state = new_state;
105 if retry {
106 self.poll()
107 } else {
108 result
109 }
110 }
111}