1use std::mem::replace;
2use futures::{Future, Sink, Poll, Async};
3use futures::stream::Stream;
4use futures::sink;
5use tokio_io::{AsyncRead, AsyncWrite};
6use tokio_tls::{TlsStream, TlsConnector, Connect};
7use native_tls::TlsConnector as NativeTlsConnector;
8use minidom::Element;
9use jid::Jid;
10
11use xmpp_codec::Packet;
12use xmpp_stream::XMPPStream;
13use Error;
14
/// XML namespace of the STARTTLS negotiation elements, used here for the
/// `<starttls/>` request sent to the server.
pub const NS_XMPP_TLS: &str =
    "urn:ietf:params:xml:ns:xmpp-tls";
17
18
19/// XMPP stream that switches to TLS if available in received features
20pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
21 state: StartTlsClientState<S>,
22 jid: Jid,
23}
24
25enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
26 Invalid,
27 SendStartTls(sink::Send<XMPPStream<S>>),
28 AwaitProceed(XMPPStream<S>),
29 StartingTls(Connect<S>),
30}
31
32impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
33 /// Waits for <stream:features>
34 pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
35 let jid = xmpp_stream.jid.clone();
36
37 let nonza = Element::builder("starttls")
38 .ns(NS_XMPP_TLS)
39 .build();
40 let packet = Packet::Stanza(nonza);
41 let send = xmpp_stream.send(packet);
42
43 StartTlsClient {
44 state: StartTlsClientState::SendStartTls(send),
45 jid,
46 }
47 }
48}
49
50impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
51 type Item = TlsStream<S>;
52 type Error = Error;
53
54 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
55 let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
56 let mut retry = false;
57
58 let (new_state, result) = match old_state {
59 StartTlsClientState::SendStartTls(mut send) =>
60 match send.poll() {
61 Ok(Async::Ready(xmpp_stream)) => {
62 let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
63 retry = true;
64 (new_state, Ok(Async::NotReady))
65 },
66 Ok(Async::NotReady) =>
67 (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
68 Err(e) =>
69 (StartTlsClientState::SendStartTls(send), Err(e.into())),
70 },
71 StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
72 match xmpp_stream.poll() {
73 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
74 if stanza.name() == "proceed" =>
75 {
76 let stream = xmpp_stream.stream.into_inner();
77 let connect = TlsConnector::from(NativeTlsConnector::builder()
78 .build().unwrap())
79 .connect(&self.jid.domain, stream);
80 let new_state = StartTlsClientState::StartingTls(connect);
81 retry = true;
82 (new_state, Ok(Async::NotReady))
83 },
84 Ok(Async::Ready(value)) => {
85 println!("StartTlsClient ignore {:?}", value);
86 (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
87 },
88 Ok(_) =>
89 (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
90 Err(e) =>
91 (StartTlsClientState::AwaitProceed(xmpp_stream), Err(Error::Protocol(e.into()))),
92 },
93 StartTlsClientState::StartingTls(mut connect) =>
94 match connect.poll() {
95 Ok(Async::Ready(tls_stream)) =>
96 (StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream))),
97 Ok(Async::NotReady) =>
98 (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
99 Err(e) =>
100 (StartTlsClientState::Invalid, Err(e.into())),
101 },
102 StartTlsClientState::Invalid =>
103 unreachable!(),
104 };
105
106 self.state = new_state;
107 if retry {
108 self.poll()
109 } else {
110 result
111 }
112 }
113}