1use std::str::FromStr;
2use std::collections::HashSet;
3use futures::{Future, Poll, Stream, future::{ok, err, IntoFuture}};
4use sasl::client::mechanisms::{Anonymous, Plain, Scram};
5use sasl::client::Mechanism;
6use sasl::common::scram::{Sha1, Sha256};
7use sasl::common::Credentials;
8use tokio_io::{AsyncRead, AsyncWrite};
9use xmpp_parsers::TryFrom;
10use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success};
11
12use crate::xmpp_codec::Packet;
13use crate::xmpp_stream::XMPPStream;
14use crate::{AuthError, Error, ProtocolError};
15
/// XML namespace of the SASL negotiation elements (`<mechanisms/>`,
/// `<mechanism/>`) looked up in the server's stream features.
const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
17
18pub struct ClientAuth<S: AsyncRead + AsyncWrite> {
19 future: Box<dyn Future<Item = XMPPStream<S>, Error = Error>>,
20}
21
22impl<S: AsyncRead + AsyncWrite + 'static> ClientAuth<S> {
23 pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> {
24 let local_mechs: Vec<Box<dyn Fn() -> Box<dyn Mechanism>>> = vec![
25 Box::new(|| Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap())),
26 Box::new(|| Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap())),
27 Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())),
28 Box::new(|| Box::new(Anonymous::new())),
29 ];
30
31 let remote_mechs: HashSet<String> = stream
32 .stream_features
33 .get_child("mechanisms", NS_XMPP_SASL)
34 .ok_or(AuthError::NoMechanism)?
35 .children()
36 .filter(|child| child.is("mechanism", NS_XMPP_SASL))
37 .map(|mech_el| mech_el.text())
38 .collect();
39
40 for local_mech in local_mechs {
41 let mut mechanism = local_mech();
42 if remote_mechs.contains(mechanism.name()) {
43 let initial = mechanism.initial().map_err(AuthError::Sasl)?;
44 let mechanism_name = XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?;
45
46 let send_initial = Box::new(stream.send_stanza(Auth {
47 mechanism: mechanism_name,
48 data: initial,
49 }))
50 .map_err(Error::Io);
51 let future = Box::new(send_initial.and_then(
52 |stream| Self::handle_challenge(stream, mechanism)
53 ).and_then(
54 |stream| stream.restart()
55 ));
56 return Ok(ClientAuth {
57 future,
58 });
59 }
60 }
61
62 Err(AuthError::NoMechanism)?
63 }
64
65 fn handle_challenge(stream: XMPPStream<S>, mut mechanism: Box<dyn Mechanism>) -> Box<dyn Future<Item = XMPPStream<S>, Error = Error>> {
66 Box::new(
67 stream.into_future()
68 .map_err(|(e, _stream)| e.into())
69 .and_then(|(stanza, stream)| {
70 match stanza {
71 Some(Packet::Stanza(stanza)) => {
72 if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
73 let response = mechanism
74 .response(&challenge.data);
75 Box::new(
76 response
77 .map_err(|e| AuthError::Sasl(e).into())
78 .into_future()
79 .and_then(|response| {
80 // Send response and loop
81 stream.send_stanza(Response { data: response })
82 .map_err(Error::Io)
83 .and_then(|stream| Self::handle_challenge(stream, mechanism))
84 })
85 )
86 } else if let Ok(_) = Success::try_from(stanza.clone()) {
87 Box::new(ok(stream))
88 } else if let Ok(failure) = Failure::try_from(stanza.clone()) {
89 Box::new(err(Error::Auth(AuthError::Fail(failure.defined_condition))))
90 } else if stanza.name() == "failure" {
91 // Workaround for https://gitlab.com/xmpp-rs/xmpp-parsers/merge_requests/1
92 Box::new(err(Error::Auth(AuthError::Sasl("failure".to_string()))))
93 } else {
94 // ignore and loop
95 Self::handle_challenge(stream, mechanism)
96 }
97 }
98 Some(_) => {
99 // ignore and loop
100 Self::handle_challenge(stream, mechanism)
101 }
102 None => Box::new(err(Error::Disconnected))
103 }
104 })
105 )
106 }
107}
108
109impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
110 type Item = XMPPStream<S>;
111 type Error = Error;
112
113 fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
114 self.future.poll()
115 }
116}