1use futures::{sink, Async, Future, Poll, Stream};
2use minidom::Element;
3use sasl::client::mechanisms::{Anonymous, Plain, Scram};
4use sasl::client::Mechanism;
5use sasl::common::scram::{Sha1, Sha256};
6use sasl::common::Credentials;
7use std::mem::replace;
8use std::str::FromStr;
9use tokio_io::{AsyncRead, AsyncWrite};
10use try_from::TryFrom;
11use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success};
12
13use crate::stream_start::StreamStart;
14use crate::xmpp_codec::Packet;
15use crate::xmpp_stream::XMPPStream;
16use crate::{AuthError, Error, ProtocolError};
17
/// XML namespace of the SASL negotiation elements; used below to locate the
/// `<mechanisms/>` stream feature and its `<mechanism/>` children.
const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
19
20pub struct ClientAuth<S: AsyncWrite> {
21 state: ClientAuthState<S>,
22 mechanism: Box<Mechanism>,
23}
24
25enum ClientAuthState<S: AsyncWrite> {
26 WaitSend(sink::Send<XMPPStream<S>>),
27 WaitRecv(XMPPStream<S>),
28 Start(StreamStart<S>),
29 Invalid,
30}
31
32impl<S: AsyncWrite> ClientAuth<S> {
33 pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> {
34 let mechs: Vec<Box<Mechanism>> = vec![
35 Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
36 Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
37 Box::new(Plain::from_credentials(creds).unwrap()),
38 Box::new(Anonymous::new()),
39 ];
40
41 let mech_names: Vec<String> = stream
42 .stream_features
43 .get_child("mechanisms", NS_XMPP_SASL)
44 .ok_or(AuthError::NoMechanism)?
45 .children()
46 .filter(|child| child.is("mechanism", NS_XMPP_SASL))
47 .map(|mech_el| mech_el.text())
48 .collect();
49 // println!("SASL mechanisms offered: {:?}", mech_names);
50
51 for mut mech in mechs {
52 let name = mech.name().to_owned();
53 if mech_names.iter().any(|name1| *name1 == name) {
54 // println!("SASL mechanism selected: {:?}", name);
55 let initial = mech.initial().map_err(AuthError::Sasl)?;
56 let mut this = ClientAuth {
57 state: ClientAuthState::Invalid,
58 mechanism: mech,
59 };
60 let mechanism = XMPPMechanism::from_str(&name).map_err(ProtocolError::Parsers)?;
61 this.send(
62 stream,
63 Auth {
64 mechanism,
65 data: initial,
66 },
67 );
68 return Ok(this);
69 }
70 }
71
72 Err(AuthError::NoMechanism)?
73 }
74
75 fn send<N: Into<Element>>(&mut self, stream: XMPPStream<S>, nonza: N) {
76 let send = stream.send_stanza(nonza);
77
78 self.state = ClientAuthState::WaitSend(send);
79 }
80}
81
82impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
83 type Item = XMPPStream<S>;
84 type Error = Error;
85
86 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
87 let state = replace(&mut self.state, ClientAuthState::Invalid);
88
89 match state {
90 ClientAuthState::WaitSend(mut send) => match send.poll() {
91 Ok(Async::Ready(stream)) => {
92 self.state = ClientAuthState::WaitRecv(stream);
93 self.poll()
94 }
95 Ok(Async::NotReady) => {
96 self.state = ClientAuthState::WaitSend(send);
97 Ok(Async::NotReady)
98 }
99 Err(e) => Err(e)?,
100 },
101 ClientAuthState::WaitRecv(mut stream) => match stream.poll() {
102 Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
103 if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
104 let response = self
105 .mechanism
106 .response(&challenge.data)
107 .map_err(AuthError::Sasl)?;
108 self.send(stream, Response { data: response });
109 self.poll()
110 } else if let Ok(_) = Success::try_from(stanza.clone()) {
111 let start = stream.restart();
112 self.state = ClientAuthState::Start(start);
113 self.poll()
114 } else if let Ok(failure) = Failure::try_from(stanza) {
115 Err(AuthError::Fail(failure.defined_condition))?
116 } else {
117 Ok(Async::NotReady)
118 }
119 }
120 Ok(Async::Ready(_event)) => {
121 // println!("ClientAuth ignore {:?}", _event);
122 Ok(Async::NotReady)
123 }
124 Ok(_) => {
125 self.state = ClientAuthState::WaitRecv(stream);
126 Ok(Async::NotReady)
127 }
128 Err(e) => Err(ProtocolError::Parser(e))?,
129 },
130 ClientAuthState::Start(mut start) => match start.poll() {
131 Ok(Async::Ready(stream)) => Ok(Async::Ready(stream)),
132 Ok(Async::NotReady) => {
133 self.state = ClientAuthState::Start(start);
134 Ok(Async::NotReady)
135 }
136 Err(e) => Err(e),
137 },
138 ClientAuthState::Invalid => unreachable!(),
139 }
140 }
141}