auth.rs

  1use std::mem::replace;
  2use futures::{Future, Poll, Async, sink, Sink, Stream};
  3use tokio_io::{AsyncRead, AsyncWrite};
  4use minidom::Element;
  5use sasl::common::Credentials;
  6use sasl::common::scram::{Sha1, Sha256};
  7use sasl::client::Mechanism;
  8use sasl::client::mechanisms::{Scram, Plain, Anonymous};
  9use serialize::base64::{self, ToBase64, FromBase64};
 10
 11use xmpp_codec::Packet;
 12use xmpp_stream::XMPPStream;
 13use stream_start::StreamStart;
 14
/// XML namespace qualifying the SASL elements this module reads and writes
/// (`mechanisms`, `mechanism`, `auth`, `response`, `challenge`, `success`,
/// `failure`).
const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
 16
 17pub struct ClientAuth<S: AsyncWrite> {
 18    state: ClientAuthState<S>,
 19    mechanism: Box<Mechanism>,
 20}
 21
 22enum ClientAuthState<S: AsyncWrite> {
 23    WaitSend(sink::Send<XMPPStream<S>>),
 24    WaitRecv(XMPPStream<S>),
 25    Start(StreamStart<S>),
 26    Invalid,
 27}
 28
 29impl<S: AsyncWrite> ClientAuth<S> {
 30    pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, String> {
 31        let mechs: Vec<Box<Mechanism>> = vec![
 32            Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
 33            Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
 34            Box::new(Plain::from_credentials(creds).unwrap()),
 35            Box::new(Anonymous::new()),
 36        ];
 37
 38        let mech_names: Vec<String> =
 39            match stream.stream_features.get_child("mechanisms", NS_XMPP_SASL) {
 40                None =>
 41                    return Err("No auth mechanisms".to_owned()),
 42                Some(mechs) =>
 43                    mechs.children()
 44                    .filter(|child| child.is("mechanism", NS_XMPP_SASL))
 45                    .map(|mech_el| mech_el.text())
 46                    .collect(),
 47            };
 48        println!("SASL mechanisms offered: {:?}", mech_names);
 49
 50        for mut mech in mechs {
 51            let name = mech.name().to_owned();
 52            if mech_names.iter().any(|name1| *name1 == name) {
 53                println!("SASL mechanism selected: {:?}", name);
 54                let initial = try!(mech.initial());
 55                let mut this = ClientAuth {
 56                    state: ClientAuthState::Invalid,
 57                    mechanism: mech,
 58                };
 59                this.send(
 60                    stream,
 61                    "auth", &[("mechanism", &name)],
 62                    &initial
 63                );
 64                return Ok(this);
 65            }
 66        }
 67
 68        Err("No supported SASL mechanism available".to_owned())
 69    }
 70
 71    fn send(&mut self, stream: XMPPStream<S>, nonza_name: &str, attrs: &[(&str, &str)], content: &[u8]) {
 72        let nonza = Element::builder(nonza_name)
 73            .ns(NS_XMPP_SASL);
 74        let nonza = attrs.iter()
 75            .fold(nonza, |nonza, &(name, value)| nonza.attr(name, value))
 76            .append(content.to_base64(base64::STANDARD))
 77            .build();
 78
 79        let send = stream.send_stanza(nonza);
 80
 81        self.state = ClientAuthState::WaitSend(send);
 82    }
 83}
 84
 85impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
 86    type Item = XMPPStream<S>;
 87    type Error = String;
 88
 89    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 90        let state = replace(&mut self.state, ClientAuthState::Invalid);
 91
 92        match state {
 93            ClientAuthState::WaitSend(mut send) =>
 94                match send.poll() {
 95                    Ok(Async::Ready(stream)) => {
 96                        self.state = ClientAuthState::WaitRecv(stream);
 97                        self.poll()
 98                    },
 99                    Ok(Async::NotReady) => {
100                        self.state = ClientAuthState::WaitSend(send);
101                        Ok(Async::NotReady)
102                    },
103                    Err(e) =>
104                        Err(format!("{}", e)),
105                },
106            ClientAuthState::WaitRecv(mut stream) =>
107                match stream.poll() {
108                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
109                        if stanza.is("challenge", NS_XMPP_SASL) =>
110                    {
111                        let content = try!(
112                            stanza.text()
113                                .from_base64()
114                                .map_err(|e| format!("{}", e))
115                        );
116                        let response = try!(self.mechanism.response(&content));
117                        self.send(stream, "response", &[], &response);
118                        self.poll()
119                    },
120                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
121                        if stanza.is("success", NS_XMPP_SASL) =>
122                    {
123                        let start = stream.restart();
124                        self.state = ClientAuthState::Start(start);
125                        self.poll()
126                    },
127                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
128                        if stanza.is("failure", NS_XMPP_SASL) =>
129                    {
130                        let e = stanza.children().next()
131                            .map(|child| child.name())
132                            .unwrap_or("Authentication failure");
133                        Err(e.to_owned())
134                    },
135                    Ok(Async::Ready(event)) => {
136                        println!("ClientAuth ignore {:?}", event);
137                        Ok(Async::NotReady)
138                    },
139                    Ok(_) => {
140                        self.state = ClientAuthState::WaitRecv(stream);
141                        Ok(Async::NotReady)
142                    },
143                    Err(e) =>
144                        Err(format!("{}", e)),
145                },
146            ClientAuthState::Start(mut start) =>
147                match start.poll() {
148                    Ok(Async::Ready(stream)) =>
149                        Ok(Async::Ready(stream)),
150                    Ok(Async::NotReady) => {
151                        self.state = ClientAuthState::Start(start);
152                        Ok(Async::NotReady)
153                    },
154                    Err(e) =>
155                        Err(format!("{}", e)),
156                },
157            ClientAuthState::Invalid =>
158                unreachable!(),
159        }
160    }
161}