1use futures::sink;
2use futures::stream::Stream;
3use futures::{Async, Future, Poll, Sink};
4use xmpp_parsers::{Jid, Element};
5use native_tls::TlsConnector as NativeTlsConnector;
6use std::mem::replace;
7use tokio_io::{AsyncRead, AsyncWrite};
8use tokio_tls::{Connect, TlsConnector, TlsStream};
9
10use crate::xmpp_codec::Packet;
11use crate::xmpp_stream::XMPPStream;
12use crate::Error;
13
/// XML namespace of the XMPP STARTTLS nonzas (`<starttls/>`, `<proceed/>`).
pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
16
17/// XMPP stream that switches to TLS if available in received features
18pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
19 state: StartTlsClientState<S>,
20 jid: Jid,
21}
22
23enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
24 Invalid,
25 SendStartTls(sink::Send<XMPPStream<S>>),
26 AwaitProceed(XMPPStream<S>),
27 StartingTls(Connect<S>),
28}
29
30impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
31 /// Waits for <stream:features>
32 pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
33 let jid = xmpp_stream.jid.clone();
34
35 let nonza = Element::builder("starttls").ns(NS_XMPP_TLS).build();
36 let packet = Packet::Stanza(nonza);
37 let send = xmpp_stream.send(packet);
38
39 StartTlsClient {
40 state: StartTlsClientState::SendStartTls(send),
41 jid,
42 }
43 }
44}
45
46impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
47 type Item = TlsStream<S>;
48 type Error = Error;
49
50 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
51 let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
52 let mut retry = false;
53
54 let (new_state, result) = match old_state {
55 StartTlsClientState::SendStartTls(mut send) => match send.poll() {
56 Ok(Async::Ready(xmpp_stream)) => {
57 let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
58 retry = true;
59 (new_state, Ok(Async::NotReady))
60 }
61 Ok(Async::NotReady) => {
62 (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady))
63 }
64 Err(e) => (StartTlsClientState::SendStartTls(send), Err(e.into())),
65 },
66 StartTlsClientState::AwaitProceed(mut xmpp_stream) => match xmpp_stream.poll() {
67 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
68 if stanza.name() == "proceed" =>
69 {
70 let stream = xmpp_stream.stream.into_inner();
71 let connect =
72 TlsConnector::from(NativeTlsConnector::builder().build().unwrap())
73 .connect(&self.jid.clone().domain(), stream);
74 let new_state = StartTlsClientState::StartingTls(connect);
75 retry = true;
76 (new_state, Ok(Async::NotReady))
77 }
78 Ok(Async::Ready(_value)) => {
79 // println!("StartTlsClient ignore {:?}", _value);
80 (
81 StartTlsClientState::AwaitProceed(xmpp_stream),
82 Ok(Async::NotReady),
83 )
84 }
85 Ok(_) => (
86 StartTlsClientState::AwaitProceed(xmpp_stream),
87 Ok(Async::NotReady),
88 ),
89 Err(e) => (
90 StartTlsClientState::AwaitProceed(xmpp_stream),
91 Err(Error::Protocol(e.into())),
92 ),
93 },
94 StartTlsClientState::StartingTls(mut connect) => match connect.poll() {
95 Ok(Async::Ready(tls_stream)) => {
96 (StartTlsClientState::Invalid, Ok(Async::Ready(tls_stream)))
97 }
98 Ok(Async::NotReady) => (
99 StartTlsClientState::StartingTls(connect),
100 Ok(Async::NotReady),
101 ),
102 Err(e) => (StartTlsClientState::Invalid, Err(e.into())),
103 },
104 StartTlsClientState::Invalid => unreachable!(),
105 };
106
107 self.state = new_state;
108 if retry {
109 self.poll()
110 } else {
111 result
112 }
113 }
114}