1use std::mem::replace;
2use std::io::Error;
3use std::sync::Arc;
4use futures::{Future, Sink, Poll, Async};
5use futures::stream::Stream;
6use futures::sink;
7use tokio_io::{AsyncRead, AsyncWrite};
8use rustls::*;
9use tokio_rustls::*;
10use xml;
11
12use xmpp_codec::*;
13use xmpp_stream::*;
14use stream_start::StreamStart;
15
16
/// XML namespace given to the `<starttls/>` element that
/// `StartTlsClient::from_stream` sends.
pub const NS_XMPP_TLS: &str = "urn:ietf:params:xml:ns:xmpp-tls";
18
19pub struct StartTlsClient<S: AsyncRead + AsyncWrite> {
20 state: StartTlsClientState<S>,
21 arc_config: Arc<ClientConfig>,
22}
23
24enum StartTlsClientState<S: AsyncRead + AsyncWrite> {
25 Invalid,
26 SendStartTls(sink::Send<XMPPStream<S>>),
27 AwaitProceed(XMPPStream<S>),
28 StartingTls(ConnectAsync<S>),
29 Start(StreamStart<TlsStream<S, ClientSession>>),
30}
31
32impl<S: AsyncRead + AsyncWrite> StartTlsClient<S> {
33 /// Waits for <stream:features>
34 pub fn from_stream(xmpp_stream: XMPPStream<S>, arc_config: Arc<ClientConfig>) -> Self {
35 let nonza = xml::Element::new(
36 "starttls".to_owned(), Some(NS_XMPP_TLS.to_owned()),
37 vec![]
38 );
39 println!("send {}", nonza);
40 let packet = Packet::Stanza(nonza);
41 let send = xmpp_stream.send(packet);
42
43 StartTlsClient {
44 state: StartTlsClientState::SendStartTls(send),
45 arc_config: arc_config,
46 }
47 }
48}
49
50impl<S: AsyncRead + AsyncWrite> Future for StartTlsClient<S> {
51 type Item = XMPPStream<TlsStream<S, ClientSession>>;
52 type Error = Error;
53
54 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
55 let old_state = replace(&mut self.state, StartTlsClientState::Invalid);
56 let mut retry = false;
57
58 let (new_state, result) = match old_state {
59 StartTlsClientState::SendStartTls(mut send) =>
60 match send.poll() {
61 Ok(Async::Ready(xmpp_stream)) => {
62 println!("starttls sent");
63 let new_state = StartTlsClientState::AwaitProceed(xmpp_stream);
64 retry = true;
65 (new_state, Ok(Async::NotReady))
66 },
67 Ok(Async::NotReady) =>
68 (StartTlsClientState::SendStartTls(send), Ok(Async::NotReady)),
69 Err(e) =>
70 (StartTlsClientState::SendStartTls(send), Err(e)),
71 },
72 StartTlsClientState::AwaitProceed(mut xmpp_stream) =>
73 match xmpp_stream.poll() {
74 Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
75 if stanza.name == "proceed" =>
76 {
77 println!("* proceed *");
78 let stream = xmpp_stream.into_inner();
79 let connect = self.arc_config.connect_async("spaceboyz.net", stream);
80 let new_state = StartTlsClientState::StartingTls(connect);
81 retry = true;
82 (new_state, Ok(Async::NotReady))
83 },
84 Ok(Async::Ready(value)) => {
85 println!("StartTlsClient ignore {:?}", value);
86 (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady))
87 },
88 Ok(_) =>
89 (StartTlsClientState::AwaitProceed(xmpp_stream), Ok(Async::NotReady)),
90 Err(e) =>
91 (StartTlsClientState::AwaitProceed(xmpp_stream), Err(e)),
92 },
93 StartTlsClientState::StartingTls(mut connect) =>
94 match connect.poll() {
95 Ok(Async::Ready(tls_stream)) => {
96 println!("Got a TLS stream!");
97 let start = XMPPStream::from_stream(tls_stream, "spaceboyz.net".to_owned());
98 let new_state = StartTlsClientState::Start(start);
99 retry = true;
100 (new_state, Ok(Async::NotReady))
101 },
102 Ok(Async::NotReady) =>
103 (StartTlsClientState::StartingTls(connect), Ok(Async::NotReady)),
104 Err(e) =>
105 (StartTlsClientState::StartingTls(connect), Err(e)),
106 },
107 StartTlsClientState::Start(mut start) =>
108 match start.poll() {
109 Ok(Async::Ready(xmpp_stream)) =>
110 (StartTlsClientState::Invalid, Ok(Async::Ready(xmpp_stream))),
111 Ok(Async::NotReady) =>
112 (StartTlsClientState::Start(start), Ok(Async::NotReady)),
113 Err(e) =>
114 (StartTlsClientState::Invalid, Err(e)),
115 },
116 StartTlsClientState::Invalid =>
117 unreachable!(),
118 };
119
120 self.state = new_state;
121 if retry {
122 self.poll()
123 } else {
124 result
125 }
126 }
127}