client_auth.rs

  1use std::mem::replace;
  2use futures::*;
  3use futures::sink;
  4use tokio_io::{AsyncRead, AsyncWrite};
  5use xml;
  6use sasl::common::Credentials;
  7use sasl::common::scram::*;
  8use sasl::client::Mechanism;
  9use sasl::client::mechanisms::*;
 10use serialize::base64::{self, ToBase64, FromBase64};
 11
 12use xmpp_codec::*;
 13use xmpp_stream::*;
 14use stream_start::*;
 15
 16const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
 17
 18pub struct ClientAuth<S: AsyncWrite> {
 19    state: ClientAuthState<S>,
 20    mechanism: Box<Mechanism>,
 21}
 22
 23enum ClientAuthState<S: AsyncWrite> {
 24    WaitSend(sink::Send<XMPPStream<S>>),
 25    WaitRecv(XMPPStream<S>),
 26    Start(StreamStart<S>),
 27    Invalid,
 28}
 29
 30impl<S: AsyncWrite> ClientAuth<S> {
 31    pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, String> {
 32        let mechs: Vec<Box<Mechanism>> = vec![
 33            Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
 34            Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
 35            Box::new(Plain::from_credentials(creds).unwrap()),
 36            Box::new(Anonymous::new()),
 37        ];
 38
 39        println!("stream_features: {}", stream.stream_features);
 40        let mech_names: Vec<String> =
 41            match stream.stream_features.get_child("mechanisms", Some(NS_XMPP_SASL)) {
 42                None =>
 43                    return Err("No auth mechanisms".to_owned()),
 44                Some(mechs) =>
 45                    mechs.get_children("mechanism", Some(NS_XMPP_SASL))
 46                    .map(|mech_el| mech_el.content_str())
 47                    .collect(),
 48            };
 49        println!("Offered mechanisms: {:?}", mech_names);
 50
 51        for mut mech in mechs {
 52            let name = mech.name().to_owned();
 53            if mech_names.iter().any(|name1| *name1 == name) {
 54                println!("Selected mechanism: {:?}", name);
 55                let initial = try!(mech.initial());
 56                let mut this = ClientAuth {
 57                    state: ClientAuthState::Invalid,
 58                    mechanism: mech,
 59                };
 60                this.send(
 61                    stream,
 62                    "auth", &[("mechanism".to_owned(), name)],
 63                    &initial
 64                );
 65                return Ok(this);
 66            }
 67        }
 68
 69        Err("No supported SASL mechanism available".to_owned())
 70    }
 71
 72    fn send(&mut self, stream: XMPPStream<S>, nonza_name: &str, attrs: &[(String, String)], content: &[u8]) {
 73        let mut nonza = xml::Element::new(
 74            nonza_name.to_owned(),
 75            Some(NS_XMPP_SASL.to_owned()),
 76            attrs.iter()
 77                .map(|&(ref name, ref value)| (name.clone(), None, value.clone()))
 78                .collect()
 79        );
 80        nonza.text(content.to_base64(base64::URL_SAFE));
 81
 82        println!("send {}", nonza);
 83        let send = stream.send(Packet::Stanza(nonza));
 84
 85        self.state = ClientAuthState::WaitSend(send);
 86    }
 87}
 88
 89impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
 90    type Item = XMPPStream<S>;
 91    type Error = String;
 92
 93    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
 94        let state = replace(&mut self.state, ClientAuthState::Invalid);
 95
 96        match state {
 97            ClientAuthState::WaitSend(mut send) =>
 98                match send.poll() {
 99                    Ok(Async::Ready(stream)) => {
100                        println!("send done");
101                        self.state = ClientAuthState::WaitRecv(stream);
102                        self.poll()
103                    },
104                    Ok(Async::NotReady) => {
105                        self.state = ClientAuthState::WaitSend(send);
106                        Ok(Async::NotReady)
107                    },
108                    Err(e) =>
109                        Err(format!("{}", e)),
110                },
111            ClientAuthState::WaitRecv(mut stream) =>
112                match stream.poll() {
113                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
114                        if stanza.name == "challenge"
115                        && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
116                    {
117                        let content = try!(
118                            stanza.content_str()
119                                .from_base64()
120                                .map_err(|e| format!("{}", e))
121                        );
122                        let response = try!(self.mechanism.response(&content));
123                        self.send(stream, "response", &[], &response);
124                        self.poll()
125                    },
126                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
127                        if stanza.name == "success"
128                        && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
129                    {
130                        let start = stream.restart();
131                        self.state = ClientAuthState::Start(start);
132                        self.poll()
133                    },
134                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
135                        if stanza.name == "failure"
136                        && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
137                    {
138                        let mut e = None;
139                        for child in &stanza.children {
140                            match child {
141                                &xml::Xml::ElementNode(ref child) => {
142                                    e = Some(child.name.clone());
143                                    break
144                                },
145                                _ => (),
146                            }
147                        }
148                        let e = e.unwrap_or_else(|| "Authentication failure".to_owned());
149                        Err(e)
150                    },
151                    Ok(Async::Ready(event)) => {
152                        println!("ClientAuth ignore {:?}", event);
153                        Ok(Async::NotReady)
154                    },
155                    Ok(_) => {
156                        self.state = ClientAuthState::WaitRecv(stream);
157                        Ok(Async::NotReady)
158                    },
159                    Err(e) =>
160                        Err(format!("{}", e)),
161                },
162            ClientAuthState::Start(mut start) =>
163                match start.poll() {
164                    Ok(Async::Ready(stream)) =>
165                        Ok(Async::Ready(stream)),
166                    Ok(Async::NotReady) => {
167                        self.state = ClientAuthState::Start(start);
168                        Ok(Async::NotReady)
169                    },
170                    Err(e) =>
171                        Err(format!("{}", e)),
172                },
173            ClientAuthState::Invalid =>
174                unreachable!(),
175        }
176    }
177}