auth.rs

  1use std::str::FromStr;
  2use std::collections::HashSet;
  3use futures::{Future, Poll, Stream, future::{ok, err, IntoFuture}};
  4use sasl::client::mechanisms::{Anonymous, Plain, Scram};
  5use sasl::client::Mechanism;
  6use sasl::common::scram::{Sha1, Sha256};
  7use sasl::common::Credentials;
  8use tokio_io::{AsyncRead, AsyncWrite};
  9use xmpp_parsers::TryFrom;
 10use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success};
 11
 12use crate::xmpp_codec::Packet;
 13use crate::xmpp_stream::XMPPStream;
 14use crate::{AuthError, Error, ProtocolError};
 15
 16const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
 17
 18pub struct ClientAuth<S: AsyncRead + AsyncWrite> {
 19    future: Box<dyn Future<Item = XMPPStream<S>, Error = Error>>,
 20}
 21
 22impl<S: AsyncRead + AsyncWrite + 'static> ClientAuth<S> {
 23    pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> {
 24        let local_mechs: Vec<Box<dyn Fn() -> Box<dyn Mechanism>>> = vec![
 25            Box::new(|| Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap())),
 26            Box::new(|| Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap())),
 27            Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())),
 28            Box::new(|| Box::new(Anonymous::new())),
 29        ];
 30
 31        let remote_mechs: HashSet<String> = stream
 32            .stream_features
 33            .get_child("mechanisms", NS_XMPP_SASL)
 34            .ok_or(AuthError::NoMechanism)?
 35            .children()
 36            .filter(|child| child.is("mechanism", NS_XMPP_SASL))
 37            .map(|mech_el| mech_el.text())
 38            .collect();
 39
 40        for local_mech in local_mechs {
 41            let mut mechanism = local_mech();
 42            if remote_mechs.contains(mechanism.name()) {
 43                let initial = mechanism.initial().map_err(AuthError::Sasl)?;
 44                let mechanism_name = XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?;
 45
 46                let send_initial = Box::new(stream.send_stanza(Auth {
 47                    mechanism: mechanism_name,
 48                    data: initial,
 49                }))
 50                    .map_err(Error::Io);
 51                let future = Box::new(send_initial.and_then(
 52                    |stream| Self::handle_challenge(stream, mechanism)
 53                ).and_then(
 54                    |stream| stream.restart()
 55                ));
 56                return Ok(ClientAuth {
 57                    future,
 58                });
 59            }
 60        }
 61
 62        Err(AuthError::NoMechanism)?
 63    }
 64
 65    fn handle_challenge(stream: XMPPStream<S>, mut mechanism: Box<dyn Mechanism>) -> Box<dyn Future<Item = XMPPStream<S>, Error = Error>> {
 66        Box::new(
 67            stream.into_future()
 68            .map_err(|(e, _stream)| e.into())
 69            .and_then(|(stanza, stream)| {
 70                match stanza {
 71                    Some(Packet::Stanza(stanza)) => {
 72                        if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
 73                            let response = mechanism
 74                                .response(&challenge.data);
 75                            Box::new(
 76                                response
 77                                    .map_err(|e| AuthError::Sasl(e).into())
 78                                    .into_future()
 79                                    .and_then(|response| {
 80                                        // Send response and loop
 81                                        stream.send_stanza(Response { data: response })
 82                                            .map_err(Error::Io)
 83                                            .and_then(|stream| Self::handle_challenge(stream, mechanism))
 84                                    })
 85                            )
 86                        } else if let Ok(_) = Success::try_from(stanza.clone()) {
 87                            Box::new(ok(stream))
 88                        } else if let Ok(failure) = Failure::try_from(stanza.clone()) {
 89                            Box::new(err(Error::Auth(AuthError::Fail(failure.defined_condition))))
 90                        } else if stanza.name() == "failure" {
 91                            // Workaround for https://gitlab.com/xmpp-rs/xmpp-parsers/merge_requests/1
 92                            Box::new(err(Error::Auth(AuthError::Sasl("failure".to_string()))))
 93                        } else {
 94                            // ignore and loop
 95                            Self::handle_challenge(stream, mechanism)
 96                        }
 97                    }
 98                    Some(_) => {
 99                        // ignore and loop
100                        Self::handle_challenge(stream, mechanism)
101                    }
102                    None => Box::new(err(Error::Disconnected))
103                }
104            })
105        )
106    }
107}
108
109impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
110    type Item = XMPPStream<S>;
111    type Error = Error;
112
113    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
114        self.future.poll()
115    }
116}