1use std::mem::replace;
2use futures::{Future, Sink, Poll, Async};
3use futures::stream::Stream;
4use futures::sink;
5use tokio_io::{AsyncRead, AsyncWrite};
6use tokio_tls::*;
7use native_tls::TlsConnector;
8use xml;
9use jid::Jid;
10
11use xmpp_codec::*;
12use xmpp_stream::*;
13use stream_start::StreamStart;
14
15
/// XML namespace placed on the `<starttls/>` element that `StartTlsClient`
/// sends to request a TLS upgrade (RFC 6120, section 5).
pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
17
18pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
19 state: StartTlsClientState<S>,
20 jid: Jid,
21}
22
23enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
24 Invalid,
25 SendStartTls(sink::Send<XMPPStream<S>>),
26 AwaitProceed(XMPPStream<S>),
27 StartingTls(ConnectAsync<S>),
28 Start(StreamStart<TlsStream<S>>),
29}
30
31impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
32 /// Waits for <stream:features>
33 pub fn from_stream(xmpp_stream: XMPPStream<S>) -> Self {
34 let jid = xmpp_stream.jid.clone();
35
36 let nonza = xml::Element::new(
37 "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()),
38 vec![]
39 );
40 println!("send {}", nonza);
41 let packet = Packet::Stanza(nonza);
42 let send = xmpp_stream.send(packet);
43
44 StartTlsClient {
45 state: StartTlsClientState::SendStartTls(send),
46 jid,
47 }
48 }
49}
50
51impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
52 type Item = XMPPStream<TlsStream<S>>;
53 type Error = String;
54
55 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
56 let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
57 let mut retry = false;
58
59 let (new_state, result) = match old_state {
60 StartTlsClientState::SendStartTls(mut send) =>
61 match send.poll() {
62 Ok(Async::Ready(xmpp_stream)) => {
63 println!("starttls sent");
64 let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
65 retry = true;
66 (new_state, Ok(Async::NotReady))
67 },
68 Ok(Async::NotReady) =>
69 (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
70 Err(e) =>
71 (StartTlsClientState::SendStartTls(send), Err(format!("{}", e))),
72 },
73 StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
74 match xmpp_stream.poll() {
75 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
76 if stanza.name == "proceed" =>
77 {
78 println!("* proceed *");
79 let stream = xmpp_stream.stream.into_inner();
80 let connect = TlsConnector::builder().unwrap()
81 .build().unwrap()
82 .connect_async(&self.jid.domain, stream);
83 let new_state = StartTlsClientState::StartingTls(connect);
84 retry = true;
85 (new_state, Ok(Async::NotReady))
86 },
87 Ok(Async::Ready(value)) => {
88 println!("StartTlsClient ignore {:?}", value);
89 (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
90 },
91 Ok(_) =>
92 (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
93 Err(e) =>
94 (StartTlsClientState::AwaitProceed(xmpp_stream), Err(format!("{}", e))),
95 },
96 StartTlsClientState::StartingTls(mut connect) =>
97 match connect.poll() {
98 Ok(Async::Ready(tls_stream)) => {
99 println!("Got a TLS stream!");
100 let start = XMPPStream::from_stream(tls_stream, self.jid.clone());
101 let new_state = StartTlsClientState::Start(start);
102 retry = true;
103 (new_state, Ok(Async::NotReady))
104 },
105 Ok(Async::NotReady) =>
106 (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
107 Err(e) =>
108 (StartTlsClientState::StartingTls(connect), Err(format!("{}", e))),
109 },
110 StartTlsClientState::Start(mut start) =>
111 match start.poll() {
112 Ok(Async::Ready(xmpp_stream)) =>
113 (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream))),
114 Ok(Async::NotReady) =>
115 (StartTlsClientState::Start(start), Ok(Async::NotReady)),
116 Err(e) =>
117 (StartTlsClientState::Invalid, Err(format!("{}", e))),
118 },
119 StartTlsClientState::Invalid =>
120 unreachable!(),
121 };
122
123 self.state = new_state;
124 if retry {
125 self.poll()
126 } else {
127 result
128 }
129 }
130}