1use std::mem::replace;
2use std::str::FromStr;
3use futures::{Future, Poll, Async, sink, Stream};
4use tokio_io::{AsyncRead, AsyncWrite};
5use sasl::common::Credentials;
6use sasl::common::scram::{Sha1, Sha256};
7use sasl::client::Mechanism;
8use sasl::client::mechanisms::{Scram, Plain, Anonymous};
9use minidom::Element;
10use xmpp_parsers::sasl::{Auth, Challenge, Response, Success, Failure, Mechanism as XMPPMechanism};
11use try_from::TryFrom;
12
13use xmpp_codec::Packet;
14use xmpp_stream::XMPPStream;
15use stream_start::StreamStart;
16use {Error, AuthError, ProtocolError};
17
/// XML namespace used to find the server's SASL `<mechanisms/>` stream
/// feature and its `<mechanism/>` children.
const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
19
20
21pub struct ClientAuth<S: AsyncWrite> {
22 state: ClientAuthState<S>,
23 mechanism: Box<Mechanism>,
24}
25
26enum ClientAuthState<S: AsyncWrite> {
27 WaitSend(sink::Send<XMPPStream<S>>),
28 WaitRecv(XMPPStream<S>),
29 Start(StreamStart<S>),
30 Invalid,
31}
32
33impl<S: AsyncWrite> ClientAuth<S> {
34 pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> {
35 let mechs: Vec<Box<Mechanism>> = vec![
36 Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
37 Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
38 Box::new(Plain::from_credentials(creds).unwrap()),
39 Box::new(Anonymous::new()),
40 ];
41
42 let mech_names: Vec<String> =
43 match stream.stream_features.get_child("mechanisms", NS_XMPP_SASL) {
44 None =>
45 return Err(AuthError::NoMechanism.into()),
46 Some(mechs) =>
47 mechs.children()
48 .filter(|child| child.is("mechanism", NS_XMPP_SASL))
49 .map(|mech_el| mech_el.text())
50 .collect(),
51 };
52 // println!("SASL mechanisms offered: {:?}", mech_names);
53
54 for mut mech in mechs {
55 let name = mech.name().to_owned();
56 if mech_names.iter().any(|name1| *name1 == name) {
57 // println!("SASL mechanism selected: {:?}", name);
58 let initial = match mech.initial() {
59 Ok(initial) => initial,
60 Err(e) => return Err(AuthError::Sasl(e).into()),
61 };
62 let mut this = ClientAuth {
63 state: ClientAuthState::Invalid,
64 mechanism: mech,
65 };
66 let mechanism = match XMPPMechanism::from_str(&name) {
67 Ok(mechanism) => mechanism,
68 Err(e) => return Err(ProtocolError::Parsers(e).into()),
69 };
70 this.send(
71 stream,
72 Auth {
73 mechanism,
74 data: initial,
75 }
76 );
77 return Ok(this);
78 }
79 }
80
81 Err(AuthError::NoMechanism.into())
82 }
83
84 fn send<N: Into<Element>>(&mut self, stream: XMPPStream<S>, nonza: N) {
85 let send = stream.send_stanza(nonza);
86
87 self.state = ClientAuthState::WaitSend(send);
88 }
89}
90
91impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
92 type Item = XMPPStream<S>;
93 type Error = Error;
94
95 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
96 let state = replace(&mut self.state, ClientAuthState::Invalid);
97
98 match state {
99 ClientAuthState::WaitSend(mut send) =>
100 match send.poll() {
101 Ok(Async::Ready(stream)) => {
102 self.state = ClientAuthState::WaitRecv(stream);
103 self.poll()
104 },
105 Ok(Async::NotReady) => {
106 self.state = ClientAuthState::WaitSend(send);
107 Ok(Async::NotReady)
108 },
109 Err(e) =>
110 Err(e.into()),
111 },
112 ClientAuthState::WaitRecv(mut stream) =>
113 match stream.poll() {
114 Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
115 if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
116 let response = self.mechanism.response(&challenge.data)
117 .map_err(AuthError::Sasl)?;
118 self.send(stream, Response { data: response });
119 self.poll()
120 } else if let Ok(_) = Success::try_from(stanza.clone()) {
121 let start = stream.restart();
122 self.state = ClientAuthState::Start(start);
123 self.poll()
124 } else if let Ok(failure) = Failure::try_from(stanza) {
125 Err(AuthError::Fail(failure.defined_condition).into())
126 } else {
127 Ok(Async::NotReady)
128 }
129 }
130 Ok(Async::Ready(_event)) => {
131 // println!("ClientAuth ignore {:?}", _event);
132 Ok(Async::NotReady)
133 },
134 Ok(_) => {
135 self.state = ClientAuthState::WaitRecv(stream);
136 Ok(Async::NotReady)
137 },
138 Err(e) =>
139 Err(ProtocolError::Parser(e).into())
140 },
141 ClientAuthState::Start(mut start) =>
142 match start.poll() {
143 Ok(Async::Ready(stream)) =>
144 Ok(Async::Ready(stream)),
145 Ok(Async::NotReady) => {
146 self.state = ClientAuthState::Start(start);
147 Ok(Async::NotReady)
148 },
149 Err(e) =>
150 Err(e.into())
151 },
152 ClientAuthState::Invalid =>
153 unreachable!(),
154 }
155 }
156}