1use std::mem::replace;
2use std::str::FromStr;
3use futures::{Future, Poll, Async, sink, Stream};
4use tokio_io::{AsyncRead, AsyncWrite};
5use sasl::common::Credentials;
6use sasl::common::scram::{Sha1, Sha256};
7use sasl::client::Mechanism;
8use sasl::client::mechanisms::{Scram, Plain, Anonymous};
9use minidom::Element;
10use xmpp_parsers::sasl::{Auth, Challenge, Response, Success, Failure, Mechanism as XMPPMechanism};
11use try_from::TryFrom;
12
13use xmpp_codec::Packet;
14use xmpp_stream::XMPPStream;
15use stream_start::StreamStart;
16
/// XML namespace of the SASL negotiation elements (`<mechanisms/>`,
/// `<mechanism/>`, …) looked up in the server's stream features.
const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
18
/// Future driving client-side SASL authentication over an `XMPPStream`.
///
/// Built with `ClientAuth::new`; resolves to the restarted stream once the
/// server reports success, or to an error string otherwise.
pub struct ClientAuth<S: AsyncWrite> {
    /// Current step of the send/receive/restart state machine.
    state: ClientAuthState<S>,
    /// The SASL mechanism selected in `new`, used to answer server challenges.
    mechanism: Box<Mechanism>,
}
23
/// States of the `ClientAuth` future.
enum ClientAuthState<S: AsyncWrite> {
    /// Waiting for an outgoing nonza (`<auth/>` or `<response/>`) to be sent.
    WaitSend(sink::Send<XMPPStream<S>>),
    /// Waiting for the server's reply (challenge, success or failure).
    WaitRecv(XMPPStream<S>),
    /// Authentication succeeded; waiting for the stream restart to complete.
    Start(StreamStart<S>),
    /// Placeholder left in `state` while `poll` has moved the real state out.
    Invalid,
}
30
31impl<S: AsyncWrite> ClientAuth<S> {
32 pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, String> {
33 let mechs: Vec<Box<Mechanism>> = vec![
34 Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
35 Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
36 Box::new(Plain::from_credentials(creds).unwrap()),
37 Box::new(Anonymous::new()),
38 ];
39
40 let mech_names: Vec<String> =
41 match stream.stream_features.get_child("mechanisms", NS_XMPP_SASL) {
42 None =>
43 return Err("No auth mechanisms".to_owned()),
44 Some(mechs) =>
45 mechs.children()
46 .filter(|child| child.is("mechanism", NS_XMPP_SASL))
47 .map(|mech_el| mech_el.text())
48 .collect(),
49 };
50 // println!("SASL mechanisms offered: {:?}", mech_names);
51
52 for mut mech in mechs {
53 let name = mech.name().to_owned();
54 if mech_names.iter().any(|name1| *name1 == name) {
55 // println!("SASL mechanism selected: {:?}", name);
56 let initial = mech.initial()?;
57 let mut this = ClientAuth {
58 state: ClientAuthState::Invalid,
59 mechanism: mech,
60 };
61 let mechanism = XMPPMechanism::from_str(&name)
62 .map_err(|e| format!("{:?}", e))?;
63 this.send(
64 stream,
65 Auth {
66 mechanism,
67 data: initial,
68 }
69 );
70 return Ok(this);
71 }
72 }
73
74 Err("No supported SASL mechanism available".to_owned())
75 }
76
77 fn send<N: Into<Element>>(&mut self, stream: XMPPStream<S>, nonza: N) {
78 let send = stream.send_stanza(nonza);
79
80 self.state = ClientAuthState::WaitSend(send);
81 }
82}
83
84impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
85 type Item = XMPPStream<S>;
86 type Error = String;
87
88 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
89 let state = replace(&mut self.state, ClientAuthState::Invalid);
90
91 match state {
92 ClientAuthState::WaitSend(mut send) =>
93 match send.poll() {
94 Ok(Async::Ready(stream)) => {
95 self.state = ClientAuthState::WaitRecv(stream);
96 self.poll()
97 },
98 Ok(Async::NotReady) => {
99 self.state = ClientAuthState::WaitSend(send);
100 Ok(Async::NotReady)
101 },
102 Err(e) =>
103 Err(format!("{}", e)),
104 },
105 ClientAuthState::WaitRecv(mut stream) =>
106 match stream.poll() {
107 Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
108 if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
109 let response = self.mechanism.response(&challenge.data)?;
110 self.send(stream, Response { data: response });
111 self.poll()
112 } else if let Ok(_) = Success::try_from(stanza.clone()) {
113 let start = stream.restart();
114 self.state = ClientAuthState::Start(start);
115 self.poll()
116 } else if let Ok(failure) = Failure::try_from(stanza) {
117 let e = format!("{:?}", failure.defined_condition);
118 Err(e)
119 } else {
120 Ok(Async::NotReady)
121 }
122 }
123 Ok(Async::Ready(_event)) => {
124 // println!("ClientAuth ignore {:?}", _event);
125 Ok(Async::NotReady)
126 },
127 Ok(_) => {
128 self.state = ClientAuthState::WaitRecv(stream);
129 Ok(Async::NotReady)
130 },
131 Err(e) =>
132 Err(format!("{}", e)),
133 },
134 ClientAuthState::Start(mut start) =>
135 match start.poll() {
136 Ok(Async::Ready(stream)) =>
137 Ok(Async::Ready(stream)),
138 Ok(Async::NotReady) => {
139 self.state = ClientAuthState::Start(start);
140 Ok(Async::NotReady)
141 },
142 Err(e) =>
143 Err(format!("{}", e)),
144 },
145 ClientAuthState::Invalid =>
146 unreachable!(),
147 }
148 }
149}