1use std::mem::replace;
2use futures::{Future, Sink, Poll, Async};
3use futures::stream::Stream;
4use futures::sink;
5use tokio_io::{AsyncRead, AsyncWrite};
6use tokio_tls::{TlsStream, TlsConnectorExt, ConnectAsync};
7use native_tls::TlsConnector;
8use minidom::Element;
9use jid::Jid;
10
11use xmpp_codec::Packet;
12use xmpp_stream::XMPPStream;
13
/// XML namespace of the XMPP STARTTLS negotiation elements.
///
/// Used as the namespace of the `<starttls/>` nonza sent by `StartTlsClient`.
pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
16
17
18/// XMPP stream that switches to TLS if available in received features
19pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
20 state: StartTlsClientState<S>,
21 jid: Jid,
22}
23
24enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
25 Invalid,
26 SendStartTls(sink::Send<XMPPStream<S>>),
27 AwaitProceed(XMPPStream<S>),
28 StartingTls(ConnectAsync<S>),
29}
30
31impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
32 /// Waits for <stream:features>
33 pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
34 let jid = xmpp_stream.jid.clone();
35
36 let nonza = Element::builder("starttls")
37 .ns(NS_XMPP_TLS)
38 .build();
39 let packet = Packet::Stanza(nonza);
40 let send = xmpp_stream.send(packet);
41
42 StartTlsClient {
43 state: StartTlsClientState::SendStartTls(send),
44 jid,
45 }
46 }
47}
48
49impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
50 type Item = TlsStream<S>;
51 type Error = String;
52
53 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
54 let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
55 let mut retry = false;
56
57 let (new_state, result) = match old_state {
58 StartTlsClientState::SendStartTls(mut send) =>
59 match send.poll() {
60 Ok(Async::Ready(xmpp_stream)) => {
61 let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
62 retry = true;
63 (new_state, Ok(Async::NotReady))
64 },
65 Ok(Async::NotReady) =>
66 (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
67 Err(e) =>
68 (StartTlsClientState::SendStartTls(send), Err(format!("{}", e))),
69 },
70 StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
71 match xmpp_stream.poll() {
72 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
73 if stanza.name() == "proceed" =>
74 {
75 let stream = xmpp_stream.stream.into_inner();
76 let connect = TlsConnector::builder().unwrap()
77 .build().unwrap()
78 .connect_async(&self.jid.domain, stream);
79 let new_state = StartTlsClientState::StartingTls(connect);
80 retry = true;
81 (new_state, Ok(Async::NotReady))
82 },
83 Ok(Async::Ready(value)) => {
84 println!("StartTlsClient ignore {:?}", value);
85 (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
86 },
87 Ok(_) =>
88 (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
89 Err(e) =>
90 (StartTlsClientState::AwaitProceed(xmpp_stream), Err(format!("{}", e))),
91 },
92 StartTlsClientState::StartingTls(mut connect) =>
93 match connect.poll() {
94 Ok(Async::Ready(tls_stream)) =>
95 (StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream))),
96 Ok(Async::NotReady) =>
97 (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
98 Err(e) =>
99 (StartTlsClientState::Invalid, Err(format!("{}", e))),
100 },
101 StartTlsClientState::Invalid =>
102 unreachable!(),
103 };
104
105 self.state = new_state;
106 if retry {
107 self.poll()
108 } else {
109 result
110 }
111 }
112}