1use std::mem::replace;
2use std::str::FromStr;
3use futures::{sink, Async, Future, Poll, Stream, future::{ok, err, IntoFuture}};
4use minidom::Element;
5use sasl::client::mechanisms::{Anonymous, Plain, Scram};
6use sasl::client::Mechanism;
7use sasl::common::scram::{Sha1, Sha256};
8use sasl::common::Credentials;
9use tokio_io::{AsyncRead, AsyncWrite};
10use try_from::TryFrom;
11use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success};
12
13use crate::xmpp_codec::Packet;
14use crate::xmpp_stream::XMPPStream;
15use crate::{AuthError, Error, ProtocolError};
16
17const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
18
19pub struct ClientAuth<S: AsyncRead + AsyncWrite> {
20 future: Box<Future<Item = XMPPStream<S>, Error = Error>>,
21}
22
23impl<S: AsyncRead + AsyncWrite + 'static> ClientAuth<S> {
24 pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> {
25 let mechs: Vec<Box<Mechanism>> = vec![
26 // TODO: Box::new(|| …
27 Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
28 Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
29 Box::new(Plain::from_credentials(creds).unwrap()),
30 Box::new(Anonymous::new()),
31 ];
32
33 let mech_names: Vec<String> = stream
34 .stream_features
35 .get_child("mechanisms", NS_XMPP_SASL)
36 .ok_or(AuthError::NoMechanism)?
37 .children()
38 .filter(|child| child.is("mechanism", NS_XMPP_SASL))
39 .map(|mech_el| mech_el.text())
40 .collect();
41 // TODO: iter instead of collect()
42 // println!("SASL mechanisms offered: {:?}", mech_names);
43
44 for mut mechanism in mechs {
45 let name = mechanism.name().to_owned();
46 if mech_names.iter().any(|name1| *name1 == name) {
47 // println!("SASL mechanism selected: {:?}", name);
48 let initial = mechanism.initial().map_err(AuthError::Sasl)?;
49 let mechanism_name = XMPPMechanism::from_str(&name).map_err(ProtocolError::Parsers)?;
50
51 let send_initial = Box::new(stream.send_stanza(Auth {
52 mechanism: mechanism_name,
53 data: initial,
54 }))
55 .map_err(Error::Io);
56 let future = Box::new(send_initial.and_then(
57 |stream| Self::handle_challenge(stream, mechanism)
58 ).and_then(
59 |stream| stream.restart()
60 ));
61 return Ok(ClientAuth {
62 future,
63 });
64 }
65 }
66
67 Err(AuthError::NoMechanism)?
68 }
69
70 fn handle_challenge(stream: XMPPStream<S>, mut mechanism: Box<Mechanism>) -> Box<Future<Item = XMPPStream<S>, Error = Error>> {
71 Box::new(
72 stream.into_future()
73 .map_err(|(e, _stream)| e.into())
74 .and_then(|(stanza, stream)| {
75 match stanza {
76 Some(Packet::Stanza(stanza)) => {
77 if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
78 let response = mechanism
79 .response(&challenge.data);
80 Box::new(
81 response
82 .map_err(|e| AuthError::Sasl(e).into())
83 .into_future()
84 .and_then(|response| {
85 // Send response and loop
86 stream.send_stanza(Response { data: response })
87 .map_err(Error::Io)
88 .and_then(|stream| Self::handle_challenge(stream, mechanism))
89 })
90 )
91 } else if let Ok(_) = Success::try_from(stanza.clone()) {
92 Box::new(ok(stream))
93 } else if let Ok(failure) = Failure::try_from(stanza.clone()) {
94 Box::new(err(Error::Auth(AuthError::Fail(failure.defined_condition))))
95 } else {
96 // ignore and loop
97 println!("Ignore: {:?}", stanza);
98 Self::handle_challenge(stream, mechanism)
99 }
100 }
101 Some(_) => {
102 // ignore and loop
103 Self::handle_challenge(stream, mechanism)
104 }
105 None => Box::new(err(Error::Disconnected))
106 }
107 })
108 )
109 }
110}
111
112impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
113 type Item = XMPPStream<S>;
114 type Error = Error;
115
116 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
117 self.future.poll()
118 }
119}