auth.rs

  1use std::mem::replace;
  2use futures::*;
  3use futures::sink;
  4use tokio_io::{AsyncRead, AsyncWrite};
  5use xml;
  6use sasl::common::Credentials;
  7use sasl::common::scram::*;
  8use sasl::client::Mechanism;
  9use sasl::client::mechanisms::*;
 10use serialize::base64::{self, ToBase64, FromBase64};
 11
 12use xmpp_codec::*;
 13use xmpp_stream::*;
 14use stream_start::*;
 15
 16const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
 17
 18pub struct ClientAuth<S: AsyncWrite> {
 19    state: ClientAuthState<S>,
 20    mechanism: Box<Mechanism>,
 21}
 22
 23enum ClientAuthState<S: AsyncWrite> {
 24    WaitSend(sink::Send<XMPPStream<S>>),
 25    WaitRecv(XMPPStream<S>),
 26    Start(StreamStart<S>),
 27    Invalid,
 28}
 29
 30impl<S: AsyncWrite> ClientAuth<S> {
 31    pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, String> {
 32        let mechs: Vec<Box<Mechanism>> = vec![
 33            Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
 34            Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
 35            Box::new(Plain::from_credentials(creds).unwrap()),
 36            Box::new(Anonymous::new()),
 37        ];
 38
 39        let mech_names: Vec<String> =
 40            match stream.stream_features.get_child("mechanisms", Some(NS_XMPP_SASL)) {
 41                None =>
 42                    return Err("No auth mechanisms".to_owned()),
 43                Some(mechs) =>
 44                    mechs.get_children("mechanism", Some(NS_XMPP_SASL))
 45                    .map(|mech_el| mech_el.content_str())
 46                    .collect(),
 47            };
 48        println!("SASL mechanisms offered: {:?}", mech_names);
 49
 50        for mut mech in mechs {
 51            let name = mech.name().to_owned();
 52            if mech_names.iter().any(|name1| *name1 == name) {
 53                println!("SASL mechanism selected: {:?}", name);
 54                let initial = try!(mech.initial());
 55                let mut this = ClientAuth {
 56                    state: ClientAuthState::Invalid,
 57                    mechanism: mech,
 58                };
 59                this.send(
 60                    stream,
 61                    "auth", &[("mechanism".to_owned(), name)],
 62                    &initial
 63                );
 64                return Ok(this);
 65            }
 66        }
 67
 68        Err("No supported SASL mechanism available".to_owned())
 69    }
 70
 71    fn send(&mut self, stream: XMPPStream<S>, nonza_name: &str, attrs: &[(String, String)], content: &[u8]) {
 72        let mut nonza = xml::Element::new(
 73            nonza_name.to_owned(),
 74            Some(NS_XMPP_SASL.to_owned()),
 75            attrs.iter()
 76                .map(|&(ref name, ref value)| (name.clone(), None, value.clone()))
 77                .collect()
 78        );
 79        nonza.text(content.to_base64(base64::URL_SAFE));
 80
 81        let send = stream.send(Packet::Stanza(nonza));
 82
 83        self.state = ClientAuthState::WaitSend(send);
 84    }
 85}
 86
 87impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
 88    type Item = XMPPStream<S>;
 89    type Error = String;
 90
 91    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 92        let state = replace(&mut self.state, ClientAuthState::Invalid);
 93
 94        match state {
 95            ClientAuthState::WaitSend(mut send) =>
 96                match send.poll() {
 97                    Ok(Async::Ready(stream)) => {
 98                        self.state = ClientAuthState::WaitRecv(stream);
 99                        self.poll()
100                    },
101                    Ok(Async::NotReady) => {
102                        self.state = ClientAuthState::WaitSend(send);
103                        Ok(Async::NotReady)
104                    },
105                    Err(e) =>
106                        Err(format!("{}", e)),
107                },
108            ClientAuthState::WaitRecv(mut stream) =>
109                match stream.poll() {
110                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
111                        if stanza.name == "challenge"
112                        && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
113                    {
114                        let content = try!(
115                            stanza.content_str()
116                                .from_base64()
117                                .map_err(|e| format!("{}", e))
118                        );
119                        let response = try!(self.mechanism.response(&content));
120                        self.send(stream, "response", &[], &response);
121                        self.poll()
122                    },
123                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
124                        if stanza.name == "success"
125                        && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
126                    {
127                        let start = stream.restart();
128                        self.state = ClientAuthState::Start(start);
129                        self.poll()
130                    },
131                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
132                        if stanza.name == "failure"
133                        && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
134                    {
135                        let mut e = None;
136                        for child in &stanza.children {
137                            match child {
138                                &xml::Xml::ElementNode(ref child) => {
139                                    e = Some(child.name.clone());
140                                    break
141                                },
142                                _ => (),
143                            }
144                        }
145                        let e = e.unwrap_or_else(|| "Authentication failure".to_owned());
146                        Err(e)
147                    },
148                    Ok(Async::Ready(event)) => {
149                        println!("ClientAuth ignore {:?}", event);
150                        Ok(Async::NotReady)
151                    },
152                    Ok(_) => {
153                        self.state = ClientAuthState::WaitRecv(stream);
154                        Ok(Async::NotReady)
155                    },
156                    Err(e) =>
157                        Err(format!("{}", e)),
158                },
159            ClientAuthState::Start(mut start) =>
160                match start.poll() {
161                    Ok(Async::Ready(stream)) =>
162                        Ok(Async::Ready(stream)),
163                    Ok(Async::NotReady) => {
164                        self.state = ClientAuthState::Start(start);
165                        Ok(Async::NotReady)
166                    },
167                    Err(e) =>
168                        Err(format!("{}", e)),
169                },
170            ClientAuthState::Invalid =>
171                unreachable!(),
172        }
173    }
174}