add client_auth using sasl

Astro created

Change summary

Cargo.toml           |   2 
examples/echo_bot.rs |   5 +
src/client_auth.rs   | 160 ++++++++++++++++++++++++++++++++++++++++++++++
src/lib.rs           |   4 +
src/xmpp_stream.rs   |  12 +++
5 files changed, 182 insertions(+), 1 deletion(-)

Detailed changes

Cargo.toml 🔗

@@ -11,3 +11,5 @@ bytes = "*"
 RustyXML = "*"
 rustls = "*"
 tokio-rustls = "*"
+sasl = "*"
+rustc-serialize = "*"

examples/echo_bot.rs 🔗

@@ -33,6 +33,9 @@ fn main() {
         } else {
             panic!("No STARTTLS")
         }
+    }).map_err(|e| format!("{}", e)
+    ).and_then(|stream| {
+        stream.auth("astrobot", "").expect("auth")
     }).and_then(|stream| {
         stream.for_each(|event| {
             match event {
@@ -40,7 +43,7 @@ fn main() {
                 _ => println!("!! {:?}", event),
             }
             Ok(())
-        })
+        }).map_err(|e| format!("{}", e))
     });
     match core.run(client) {
         Ok(_) => (),

src/client_auth.rs 🔗

@@ -0,0 +1,160 @@
+use std::mem::replace;
+use futures::*;
+use futures::sink;
+use tokio_io::{AsyncRead, AsyncWrite};
+use xml;
+use sasl::common::Credentials;
+use sasl::common::scram::*;
+use sasl::client::Mechanism;
+use sasl::client::mechanisms::*;
+use serialize::base64::{self, ToBase64, FromBase64};
+
+use xmpp_codec::*;
+use xmpp_stream::*;
+
+const NS_XMPP_SASL: &str = "urn:ietf:params:xml:ns:xmpp-sasl";
+
+pub struct ClientAuth<S: AsyncWrite> {
+    state: ClientAuthState<S>,
+    mechanism: Box<Mechanism>,
+}
+
+enum ClientAuthState<S: AsyncWrite> {
+    WaitSend(sink::Send<XMPPStream<S>>),
+    WaitRecv(XMPPStream<S>),
+    Invalid,
+}
+
+impl<S: AsyncWrite> ClientAuth<S> {
+    pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, String> {
+        let mechs: Vec<Box<Mechanism>> = vec![
+            Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
+            Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
+            Box::new(Plain::from_credentials(creds).unwrap()),
+            Box::new(Anonymous::new()),
+        ];
+
+        println!("stream_features: {}", stream.stream_features);
+        let mech_names: Vec<String> =
+            match stream.stream_features.get_child("mechanisms", Some(NS_XMPP_SASL)) {
+                None =>
+                    return Err("No auth mechanisms".to_owned()),
+                Some(mechs) =>
+                    mechs.get_children("mechanism", Some(NS_XMPP_SASL))
+                    .map(|mech_el| mech_el.content_str())
+                    .collect(),
+            };
+        println!("Offered mechanisms: {:?}", mech_names);
+
+        for mut mech in mechs {
+            let name = mech.name().to_owned();
+            if mech_names.iter().any(|name1| *name1 == name) {
+                println!("Selected mechanism: {:?}", name);
+                let initial = try!(mech.initial());
+                let mut this = ClientAuth {
+                    state: ClientAuthState::Invalid,
+                    mechanism: mech,
+                };
+                this.send(
+                    stream,
+                    "auth", &[("mechanism".to_owned(), name)],
+                    &initial
+                );
+                return Ok(this);
+            }
+        }
+
+        Err("No supported SASL mechanism available".to_owned())
+    }
+
+    fn send(&mut self, stream: XMPPStream<S>, nonza_name: &str, attrs: &[(String, String)], content: &[u8]) {
+        let mut nonza = xml::Element::new(
+            nonza_name.to_owned(),
+            Some(NS_XMPP_SASL.to_owned()),
+            attrs.iter()
+                .map(|&(ref name, ref value)| (name.clone(), None, value.clone()))
+                .collect()
+        );
+        nonza.text(content.to_base64(base64::URL_SAFE));
+
+        println!("send {}", nonza);
+        let send = stream.send(Packet::Stanza(nonza));
+
+        self.state = ClientAuthState::WaitSend(send);
+    }
+}
+
+impl<S: AsyncRead + AsyncWrite> Future for ClientAuth<S> {
+    type Item = XMPPStream<S>;
+    type Error = String;
+
+    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
+        let state = replace(&mut self.state, ClientAuthState::Invalid);
+
+        match state {
+            ClientAuthState::WaitSend(mut send) =>
+                match send.poll() {
+                    Ok(Async::Ready(stream)) => {
+                        println!("send done");
+                        self.state = ClientAuthState::WaitRecv(stream);
+                        self.poll()
+                    },
+                    Ok(Async::NotReady) => {
+                        self.state = ClientAuthState::WaitSend(send);
+                        Ok(Async::NotReady)
+                    },
+                    Err(e) =>
+                        Err(format!("{}", e)),
+                },
+            ClientAuthState::WaitRecv(mut stream) =>
+                match stream.poll() {
+                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
+                        if stanza.name == "challenge"
+                        && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
+                    {
+                        let content = try!(
+                            stanza.content_str()
+                                .from_base64()
+                                .map_err(|e| format!("{}", e))
+                        );
+                        let response = try!(self.mechanism.response(&content));
+                        self.send(stream, "response", &[], &response);
+                        self.poll()
+                    },
+                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
+                        if stanza.name == "success"
+                        && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
+                        Ok(Async::Ready(stream)),
+                    Ok(Async::Ready(Some(Packet::Stanza(ref stanza))))
+                        if stanza.name == "failure"
+                        && stanza.ns == Some(NS_XMPP_SASL.to_owned()) =>
+                    {
+                        let mut e = None;
+                        for child in &stanza.children {
+                            match child {
+                                &xml::Xml::ElementNode(ref child) => {
+                                    e = Some(child.name.clone());
+                                    break
+                                },
+                                _ => (),
+                            }
+                        }
+                        let e = e.unwrap_or_else(|| "Authentication failure".to_owned());
+                        Err(e)
+                    },
+                    Ok(Async::Ready(event)) => {
+                        println!("ClientAuth ignore {:?}", event);
+                        Ok(Async::NotReady)
+                    },
+                    Ok(_) => {
+                        self.state = ClientAuthState::WaitRecv(stream);
+                        Ok(Async::NotReady)
+                    },
+                    Err(e) =>
+                        Err(format!("{}", e)),
+                },
+            ClientAuthState::Invalid =>
+                unreachable!(),
+        }
+    }
+}

src/lib.rs 🔗

@@ -6,6 +6,8 @@ extern crate bytes;
 extern crate xml;
 extern crate rustls;
 extern crate tokio_rustls;
+extern crate sasl;
+extern crate rustc_serialize as serialize;
 
 
 pub mod xmpp_codec;
@@ -15,6 +17,8 @@ mod tcp;
 pub use tcp::*;
 mod starttls;
 pub use starttls::*;
+mod client_auth;
+pub use client_auth::*;
 
 
 // type FullClient = sasl::Client<StartTLS<TCPConnection>>

src/xmpp_stream.rs 🔗

@@ -1,3 +1,4 @@
+use std::default::Default;
 use std::sync::Arc;
 use std::collections::HashMap;
 use futures::*;
@@ -5,10 +6,12 @@ use tokio_io::{AsyncRead, AsyncWrite};
 use tokio_io::codec::Framed;
 use rustls::ClientConfig;
 use xml;
+use sasl::common::Credentials;
 
 use xmpp_codec::*;
 use stream_start::*;
 use starttls::{NS_XMPP_TLS, StartTlsClient};
+use client_auth::ClientAuth;
 
 pub const NS_XMPP_STREAM: &str = "http://etherx.jabber.org/streams";
 
@@ -37,8 +40,16 @@ impl<S: AsyncRead + AsyncWrite> XMPPStream<S> {
     pub fn starttls(self, arc_config: Arc<ClientConfig>) -> StartTlsClient<S> {
         StartTlsClient::from_stream(self, arc_config)
     }
+
+    pub fn auth(self, username: &str, password: &str) -> Result<ClientAuth<S>, String> {
+        let creds = Credentials::default()
+            .with_username(username)
+            .with_password(password);
+        ClientAuth::new(self, creds)
+    }
 }
 
+/// Proxy to self.stream
 impl<S: AsyncWrite> Sink for XMPPStream<S> {
     type SinkItem = <Framed<S, XMPPCodec> as Sink>::SinkItem;
     type SinkError = <Framed<S, XMPPCodec> as Sink>::SinkError;
@@ -52,6 +63,7 @@ impl<S: AsyncWrite> Sink for XMPPStream<S> {
     }
 }
 
+/// Proxy to self.stream
 impl<S: AsyncRead> Stream for XMPPStream<S> {
     type Item = <Framed<S, XMPPCodec> as Stream>::Item;
     type Error = <Framed<S, XMPPCodec> as Stream>::Error;