1use futures::sink;
2use futures::stream::Stream;
3use futures::{Async, Future, Poll, Sink};
4use jid::Jid;
5use minidom::Element;
6use native_tls::TlsConnector as NativeTlsConnector;
7use std::mem::replace;
8use tokio_io::{AsyncRead, AsyncWrite};
9use tokio_tls::{Connect, TlsConnector, TlsStream};
10
11use crate::xmpp_codec::Packet;
12use crate::xmpp_stream::XMPPStream;
13use crate::Error;
14
/// XML namespace of the STARTTLS negotiation elements (`<starttls/>`,
/// `<proceed/>`, ...), as defined by RFC 6120.
pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
17
18/// XMPP stream that switches to TLS if available in received features
19pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
20 state: StartTlsClientState<S>,
21 jid: Jid,
22}
23
24enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
25 Invalid,
26 SendStartTls(sink::Send<XMPPStream<S>>),
27 AwaitProceed(XMPPStream<S>),
28 StartingTls(Connect<S>),
29}
30
31impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
32 /// Waits for <stream:features>
33 pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
34 let jid = xmpp_stream.jid.clone();
35
36 let nonza = Element::builder("starttls").ns(NS_XMPP_TLS).build();
37 let packet = Packet::Stanza(nonza);
38 let send = xmpp_stream.send(packet);
39
40 StartTlsClient {
41 state: StartTlsClientState::SendStartTls(send),
42 jid,
43 }
44 }
45}
46
47impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
48 type Item = TlsStream<S>;
49 type Error = Error;
50
51 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
52 let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
53 let mut retry = false;
54
55 let (new_state, result) = match old_state {
56 StartTlsClientState::SendStartTls(mut send) => match send.poll() {
57 Ok(Async::Ready(xmpp_stream)) => {
58 let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
59 retry = true;
60 (new_state, Ok(Async::NotReady))
61 }
62 Ok(Async::NotReady) => {
63 (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady))
64 }
65 Err(e) => (StartTlsClientState::SendStartTls(send), Err(e.into())),
66 },
67 StartTlsClientState::AwaitProceed(mut xmpp_stream) => match xmpp_stream.poll() {
68 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
69 if stanza.name() == "proceed" =>
70 {
71 let stream = xmpp_stream.stream.into_inner();
72 let connect =
73 TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
74 .connect(&self.jid.domain, stream);
75 let new_state = StartTlsClientState::StartingTls(connect);
76 retry = true;
77 (new_state, Ok(Async::NotReady))
78 }
79 Ok(Async::Ready(value)) => {
80 println!("StartTlsClient ignore {:?}", value);
81 (
82 StartTlsClientState::AwaitProceed(xmpp_stream),
83 Ok(Async::NotReady),
84 )
85 }
86 Ok(_) => (
87 StartTlsClientState::AwaitProceed(xmpp_stream),
88 Ok(Async::NotReady),
89 ),
90 Err(e) => (
91 StartTlsClientState::AwaitProceed(xmpp_stream),
92 Err(Error::Protocol(e.into())),
93 ),
94 },
95 StartTlsClientState::StartingTls(mut connect) => match connect.poll() {
96 Ok(Async::Ready(tls_stream)) => {
97 (StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream)))
98 }
99 Ok(Async::NotReady) => (
100 StartTlsClientState::StartingTls(connect),
101 Ok(Async::NotReady),
102 ),
103 Err(e) => (StartTlsClientState::Invalid, Err(e.into())),
104 },
105 StartTlsClientState::Invalid => unreachable!(),
106 };
107
108 self.state = new_state;
109 if retry {
110 self.poll()
111 } else {
112 result
113 }
114 }
115}