auth.rs

  1use futures::{sink, Async, Future, Poll, Stream};
  2use minidom::Element;
  3use sasl::client::mechanisms::{Anonymous, Plain, Scram};
  4use sasl::client::Mechanism;
  5use sasl::common::scram::{Sha1, Sha256};
  6use sasl::common::Credentials;
  7use std::mem::replace;
  8use std::str::FromStr;
  9use tokio_io::{AsyncRead, AsyncWrite};
 10use try_from::TryFrom;
 11use xmpp_parsers::sasl::{Auth, Challenge, Failure, Mechanism as XMPPMechanism, Response, Success};
 12
 13use crate::stream_start::StreamStart;
 14use crate::xmpp_codec::Packet;
 15use crate::xmpp_stream::XMPPStream;
 16use crate::{AuthError, Error, ProtocolError};
 17
/// XML namespace of the SASL `<mechanisms/>` / `<mechanism/>` elements looked up in the
/// stream features when choosing an authentication mechanism.
const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
 19
 20pub struct ClientAuth<S: AsyncWrite> {
 21    state: ClientAuthState<S>,
 22    mechanism: Box<Mechanism>,
 23}
 24
/// Steps of the `ClientAuth` state machine.
enum ClientAuthState<S: AsyncWrite> {
    /// Waiting for an outgoing `<auth/>` or `<response/>` nonza to be sent.
    WaitSend(sink::Send<XMPPStream<S>>),
    /// Waiting for the server's next SASL element (challenge, success or failure).
    WaitRecv(XMPPStream<S>),
    /// Authentication succeeded; waiting for the stream restart (`stream.restart()`) to finish.
    Start(StreamStart<S>),
    /// Placeholder stored while `poll` has moved the real state out, and the initial
    /// value in `new` before the first nonza is queued.
    Invalid,
}
 31
 32impl<S: AsyncWrite> ClientAuth<S> {
 33    pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> {
 34        let mechs: Vec<Box<Mechanism>> = vec![
 35            Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
 36            Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
 37            Box::new(Plain::from_credentials(creds).unwrap()),
 38            Box::new(Anonymous::new()),
 39        ];
 40
 41        let mech_names: Vec<String> = stream
 42            .stream_features
 43            .get_child("mechanisms", NS_XMPP_SASL)
 44            .ok_or(AuthError::NoMechanism)?
 45            .children()
 46            .filter(|child| child.is("mechanism", NS_XMPP_SASL))
 47            .map(|mech_el| mech_el.text())
 48            .collect();
 49        // println!("SASL mechanisms offered: {:?}", mech_names);
 50
 51        for mut mech in mechs {
 52            let name = mech.name().to_owned();
 53            if mech_names.iter().any(|name1| *name1 == name) {
 54                // println!("SASL mechanism selected: {:?}", name);
 55                let initial = mech.initial().map_err(AuthError::Sasl)?;
 56                let mut this = ClientAuth {
 57                    state: ClientAuthState::Invalid,
 58                    mechanism: mech,
 59                };
 60                let mechanism = XMPPMechanism::from_str(&name).map_err(ProtocolError::Parsers)?;
 61                this.send(
 62                    stream,
 63                    Auth {
 64                        mechanism,
 65                        data: initial,
 66                    },
 67                );
 68                return Ok(this);
 69            }
 70        }
 71
 72        Err(AuthError::NoMechanism)?
 73    }
 74
 75    fn send<N: Into<Element>>(&mut self, stream: XMPPStream<S>, nonza: N) {
 76        let send = stream.send_stanza(nonza);
 77
 78        self.state = ClientAuthState::WaitSend(send);
 79    }
 80}
 81
 82impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
 83    type Item = XMPPStream<S>;
 84    type Error = Error;
 85
 86    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 87        let state = replace(&mut self.state, ClientAuthState::Invalid);
 88
 89        match state {
 90            ClientAuthState::WaitSend(mut send) => match send.poll() {
 91                Ok(Async::Ready(stream)) => {
 92                    self.state = ClientAuthState::WaitRecv(stream);
 93                    self.poll()
 94                }
 95                Ok(Async::NotReady) => {
 96                    self.state = ClientAuthState::WaitSend(send);
 97                    Ok(Async::NotReady)
 98                }
 99                Err(e) => Err(e)?,
100            },
101            ClientAuthState::WaitRecv(mut stream) => match stream.poll() {
102                Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
103                    if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
104                        let response = self
105                            .mechanism
106                            .response(&challenge.data)
107                            .map_err(AuthError::Sasl)?;
108                        self.send(stream, Response { data: response });
109                        self.poll()
110                    } else if let Ok(_) = Success::try_from(stanza.clone()) {
111                        let start = stream.restart();
112                        self.state = ClientAuthState::Start(start);
113                        self.poll()
114                    } else if let Ok(failure) = Failure::try_from(stanza) {
115                        Err(AuthError::Fail(failure.defined_condition))?
116                    } else {
117                        Ok(Async::NotReady)
118                    }
119                }
120                Ok(Async::Ready(_event)) => {
121                    // println!("ClientAuth ignore {:?}", _event);
122                    Ok(Async::NotReady)
123                }
124                Ok(_) => {
125                    self.state = ClientAuthState::WaitRecv(stream);
126                    Ok(Async::NotReady)
127                }
128                Err(e) => Err(ProtocolError::Parser(e))?,
129            },
130            ClientAuthState::Start(mut start) => match start.poll() {
131                Ok(Async::Ready(stream)) => Ok(Async::Ready(stream)),
132                Ok(Async::NotReady) => {
133                    self.state = ClientAuthState::Start(start);
134                    Ok(Async::NotReady)
135                }
136                Err(e) => Err(e),
137            },
138            ClientAuthState::Invalid => unreachable!(),
139        }
140    }
141}