auth.rs

  1use std::mem::replace;
  2use std::str::FromStr;
  3use futures::{Future, Poll, Async, sink, Stream};
  4use tokio_io::{AsyncRead, AsyncWrite};
  5use sasl::common::Credentials;
  6use sasl::common::scram::{Sha1, Sha256};
  7use sasl::client::Mechanism;
  8use sasl::client::mechanisms::{Scram, Plain, Anonymous};
  9use minidom::Element;
 10use xmpp_parsers::sasl::{Auth, Challenge, Response, Success, Failure, Mechanism as XMPPMechanism};
 11use try_from::TryFrom;
 12
 13use xmpp_codec::Packet;
 14use xmpp_stream::XMPPStream;
 15use stream_start::StreamStart;
 16
 /// XML namespace of the SASL negotiation elements (`<mechanisms/>`, `<auth/>`, ...).
const NS_XMPP_SASL: &'static str = "urn:ietf:params:xml:ns:xmpp-sasl";
 18
 19pub struct ClientAuth<S: AsyncWrite> {
 20    state: ClientAuthState<S>,
 21    mechanism: Box<Mechanism>,
 22}
 23
 24enum ClientAuthState<S: AsyncWrite> {
 25    WaitSend(sink::Send<XMPPStream<S>>),
 26    WaitRecv(XMPPStream<S>),
 27    Start(StreamStart<S>),
 28    Invalid,
 29}
 30
 31impl<S: AsyncWrite> ClientAuth<S> {
 32    pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, String> {
 33        let mechs: Vec<Box<Mechanism>> = vec![
 34            Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
 35            Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
 36            Box::new(Plain::from_credentials(creds).unwrap()),
 37            Box::new(Anonymous::new()),
 38        ];
 39
 40        let mech_names: Vec<String> =
 41            match stream.stream_features.get_child("mechanisms", NS_XMPP_SASL) {
 42                None =>
 43                    return Err("No auth mechanisms".to_owned()),
 44                Some(mechs) =>
 45                    mechs.children()
 46                    .filter(|child| child.is("mechanism", NS_XMPP_SASL))
 47                    .map(|mech_el| mech_el.text())
 48                    .collect(),
 49            };
 50        // println!("SASL mechanisms offered: {:?}", mech_names);
 51
 52        for mut mech in mechs {
 53            let name = mech.name().to_owned();
 54            if mech_names.iter().any(|name1| *name1 == name) {
 55                // println!("SASL mechanism selected: {:?}", name);
 56                let initial = mech.initial()?;
 57                let mut this = ClientAuth {
 58                    state: ClientAuthState::Invalid,
 59                    mechanism: mech,
 60                };
 61                let mechanism = XMPPMechanism::from_str(&name)
 62                    .map_err(|e| format!("{:?}", e))?;
 63                this.send(
 64                    stream,
 65                    Auth {
 66                        mechanism,
 67                        data: initial,
 68                    }
 69                );
 70                return Ok(this);
 71            }
 72        }
 73
 74        Err("No supported SASL mechanism available".to_owned())
 75    }
 76
 77    fn send<N: Into<Element>>(&mut self, stream: XMPPStream<S>, nonza: N) {
 78        let send = stream.send_stanza(nonza);
 79
 80        self.state = ClientAuthState::WaitSend(send);
 81    }
 82}
 83
 84impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
 85    type Item = XMPPStream<S>;
 86    type Error = String;
 87
 88    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 89        let state = replace(&mut self.state, ClientAuthState::Invalid);
 90
 91        match state {
 92            ClientAuthState::WaitSend(mut send) =>
 93                match send.poll() {
 94                    Ok(Async::Ready(stream)) => {
 95                        self.state = ClientAuthState::WaitRecv(stream);
 96                        self.poll()
 97                    },
 98                    Ok(Async::NotReady) => {
 99                        self.state = ClientAuthState::WaitSend(send);
100                        Ok(Async::NotReady)
101                    },
102                    Err(e) =>
103                        Err(format!("{}", e)),
104                },
105            ClientAuthState::WaitRecv(mut stream) =>
106                match stream.poll() {
107                    Ok(Async::Ready(Some(Packet::Stanza(stanza)))) => {
108                        if let Ok(challenge) = Challenge::try_from(stanza.clone()) {
109                            let response = self.mechanism.response(&challenge.data)?;
110                            self.send(stream, Response { data: response });
111                            self.poll()
112                        } else if let Ok(_) = Success::try_from(stanza.clone()) {
113                            let start = stream.restart();
114                            self.state = ClientAuthState::Start(start);
115                            self.poll()
116                        } else if let Ok(failure) = Failure::try_from(stanza) {
117                            let e = format!("{:?}", failure.defined_condition);
118                            Err(e)
119                        } else {
120                            Ok(Async::NotReady)
121                        }
122                    }
123                    Ok(Async::Ready(_event)) => {
124                        // println!("ClientAuth ignore {:?}", _event);
125                        Ok(Async::NotReady)
126                    },
127                    Ok(_) => {
128                        self.state = ClientAuthState::WaitRecv(stream);
129                        Ok(Async::NotReady)
130                    },
131                    Err(e) =>
132                        Err(format!("{}", e)),
133                },
134            ClientAuthState::Start(mut start) =>
135                match start.poll() {
136                    Ok(Async::Ready(stream)) =>
137                        Ok(Async::Ready(stream)),
138                    Ok(Async::NotReady) => {
139                        self.state = ClientAuthState::Start(start);
140                        Ok(Async::NotReady)
141                    },
142                    Err(e) =>
143                        Err(format!("{}", e)),
144                },
145            ClientAuthState::Invalid =>
146                unreachable!(),
147        }
148    }
149}