auth.rs

  1use std::mem::replace;
  2use std::str::FromStr;
  3use futures::{Future, Poll, Async, sink, Stream};
  4use tokio_io::{AsyncRead, AsyncWrite};
  5use sasl::common::Credentials;
  6use sasl::common::scram::{Sha1, Sha256};
  7use sasl::client::Mechanism;
  8use sasl::client::mechanisms::{Scram, Plain, Anonymous};
  9use minidom::Element;
 10use xmpp_parsers::sasl::{Auth, Challenge, Response, Success, Failure, Mechanism as XMPPMechanism};
 11use try_from::TryFrom;
 12
 13use xmpp_codec::Packet;
 14use xmpp_stream::XMPPStream;
 15use stream_start::StreamStart;
 16use {Error, AuthError, ProtocolError};
 17
 18const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
 19
 20
 21pub struct ClientAuth<S: AsyncWrite> {
 22    state: ClientAuthState<S>,
 23    mechanism: Box<Mechanism>,
 24}
 25
 26enum ClientAuthState<S: AsyncWrite> {
 27    WaitSend(sink::Send<XMPPStream<S>>),
 28    WaitRecv(XMPPStream<S>),
 29    Start(StreamStart<S>),
 30    Invalid,
 31}
 32
 33impl<S: AsyncWrite> ClientAuth<S> {
 34    pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> {
 35        let mechs: Vec<Box<Mechanism>> = vec![
 36            Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
 37            Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
 38            Box::new(Plain::from_credentials(creds).unwrap()),
 39            Box::new(Anonymous::new()),
 40        ];
 41
 42        let mech_names: Vec<String> =
 43            match stream.stream_features.get_child("mechanisms", NS_XMPP_SASL) {
 44                None =>
 45                    return Err(AuthError::NoMechanism.into()),
 46                Some(mechs) =>
 47                    mechs.children()
 48                    .filter(|child| child.is("mechanism", NS_XMPP_SASL))
 49                    .map(|mech_el| mech_el.text())
 50                    .collect(),
 51            };
 52        // println!("SASL mechanisms offered: {:?}", mech_names);
 53
 54        for mut mech in mechs {
 55            let name = mech.name().to_owned();
 56            if mech_names.iter().any(|name1| *name1 == name) {
 57                // println!("SASL mechanism selected: {:?}", name);
 58                let initial = match mech.initial() {
 59                    Ok(initial) => initial,
 60                    Err(e) => return Err(AuthError::Sasl(e).into()),
 61                };
 62                let mut this = ClientAuth {
 63                    state: ClientAuthState::Invalid,
 64                    mechanism: mech,
 65                };
 66                let mechanism = match XMPPMechanism::from_str(&name) {
 67                    Ok(mechanism) => mechanism,
 68                    Err(e) => return Err(ProtocolError::Parsers(e).into()),
 69                };
 70                this.send(
 71                    stream,
 72                    Auth {
 73                        mechanism,
 74                        data: initial,
 75                    }
 76                );
 77                return Ok(this);
 78            }
 79        }
 80
 81        Err(AuthError::NoMechanism.into())
 82    }
 83
 84    fn send<N: Into<Element>>(&mut self, stream: XMPPStream<S>, nonza: N) {
 85        let send = stream.send_stanza(nonza);
 86
 87        self.state = ClientAuthState::WaitSend(send);
 88    }
 89}
 90
 91impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
 92    type Item = XMPPStream<S>;
 93    type Error = Error;
 94
 95    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 96        let state = replace(&mut self.state, ClientAuthState::Invalid);
 97
 98        match state {
 99            ClientAuthState::WaitSend(mut send) =>
100                match send.poll() {
101                    Ok(Async::Ready(stream)) => {
102                        self.state = ClientAuthState::WaitRecv(stream);
103                        self.poll()
104                    },
105                    Ok(Async::NotReady) => {
106                        self.state = ClientAuthState::WaitSend(send);
107                        Ok(Async::NotReady)
108                    },
109                    Err(e) =>
110                        Err(e.into()),
111                },
112            ClientAuthState::WaitRecv(mut stream) =>
113                match stream.poll() {
114                    Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
115                        if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
116                            let response = self.mechanism.response(&challenge.data)
117                                .map_err(AuthError::Sasl)?;
118                            self.send(stream, Response { data: response });
119                            self.poll()
120                        } else if let Ok(_) = Success::try_from(stanza.clone()) {
121                            let start = stream.restart();
122                            self.state = ClientAuthState::Start(start);
123                            self.poll()
124                        } else if let Ok(failure) = Failure::try_from(stanza) {
125                            Err(AuthError::Fail(failure.defined_condition).into())
126                        } else {
127                            Ok(Async::NotReady)
128                        }
129                    }
130                    Ok(Async::Ready(_event)) => {
131                        // println!("ClientAuth ignore {:?}", _event);
132                        Ok(Async::NotReady)
133                    },
134                    Ok(_) => {
135                        self.state = ClientAuthState::WaitRecv(stream);
136                        Ok(Async::NotReady)
137                    },
138                    Err(e) =>
139                        Err(ProtocolError::Parser(e).into())
140                },
141            ClientAuthState::Start(mut start) =>
142                match start.poll() {
143                    Ok(Async::Ready(stream)) =>
144                        Ok(Async::Ready(stream)),
145                    Ok(Async::NotReady) => {
146                        self.state = ClientAuthState::Start(start);
147                        Ok(Async::NotReady)
148                    },
149                    Err(e) =>
150                        Err(e.into())
151                },
152            ClientAuthState::Invalid =>
153                unreachable!(),
154        }
155    }
156}