support arbitrary SASL mechanisms

lumi created

Change summary

examples/client.rs |  6 ++++--
src/client.rs      | 38 +++++++++++++++++++++++++++-----------
src/error.rs       |  9 +++++++++
src/sasl.rs        |  2 +-
4 files changed, 41 insertions(+), 14 deletions(-)

Detailed changes

examples/client.rs 🔗

@@ -4,15 +4,17 @@ use xmpp::jid::Jid;
 use xmpp::client::ClientBuilder;
 use xmpp::plugins::messaging::{MessagingPlugin, MessageEvent};
 use xmpp::plugins::presence::{PresencePlugin, Show};
+use xmpp::sasl::mechanisms::Plain;
 
 use std::env;
 
 fn main() {
     let jid: Jid = env::var("JID").unwrap().parse().unwrap();
-    let mut client = ClientBuilder::new(jid).connect().unwrap();
+    let mut client = ClientBuilder::new(jid.clone()).connect().unwrap();
     client.register_plugin(MessagingPlugin::new());
     client.register_plugin(PresencePlugin::new());
-    client.connect_plain(&env::var("PASS").unwrap()).unwrap();
+    let pass = env::var("PASS").unwrap();
+    client.connect(&mut Plain::new(jid.node.clone().expect("JID requires a node"), pass)).unwrap();
     client.plugin::<PresencePlugin>().set_presence(Show::Available, None).unwrap();
     loop {
         let event = client.next_event().unwrap();

src/client.rs 🔗

@@ -6,7 +6,6 @@ use plugin::{Plugin, PluginProxyBinding};
 use event::AbstractEvent;
 use connection::{Connection, C2S};
 use sasl::SaslMechanism;
-use sasl::mechanisms::Plain as SaslPlain;
 
 use base64;
 
@@ -122,8 +121,8 @@ impl Client {
         Ok(())
     }
 
-    /// Connects using SASL plain authentication.
-    pub fn connect_plain(&mut self, password: &str) -> Result<(), Error> {
+    /// Connects using the specified SASL mechanism.
+    pub fn connect<S: SaslMechanism>(&mut self, mechanism: &mut S) -> Result<(), Error> {
         // TODO: this is very ugly
         loop {
             let e = self.transport.read_event().unwrap();
@@ -150,19 +149,36 @@ impl Client {
                     self.transport.write_element(&elem)?;
                 }
                 else {
-                    let name = self.jid.node.clone().expect("JID has no node");
-                    let mut plain = SaslPlain::new(name, password.to_owned());
-                    let auth = plain.initial();
-                    let elem = Element::builder("auth")
-                                       .text(base64::encode(&auth))
+                    let auth = mechanism.initial();
+                    let mut elem = Element::builder("auth")
+                                           .ns(ns::SASL)
+                                           .attr("mechanism", "PLAIN")
+                                           .build();
+                    if !auth.is_empty() {
+                        elem.append_text_node(base64::encode(&auth));
+                    }
+                    self.transport.write_element(&elem)?;
+                }
+            }
+            else if n.is("challenge", ns::SASL) {
+                let text = n.text();
+                let challenge = if text == "" {
+                    Vec::new()
+                }
+                else {
+                    base64::decode(&text)?
+                };
+                let response = mechanism.response(&challenge);
+                let mut elem = Element::builder("response")
                                        .ns(ns::SASL)
-                                       .attr("mechanism", "PLAIN")
                                        .build();
-                    self.transport.write_element(&elem)?;
-                    did_sasl = true;
+                if !response.is_empty() {
+                    elem.append_text_node(base64::encode(&response));
                 }
+                self.transport.write_element(&elem)?;
             }
             else if n.is("success", ns::SASL) {
+                did_sasl = true;
                 self.transport.reset_stream();
                 C2S::init(&mut self.transport, &self.jid.domain, "after_sasl")?;
                 loop {

src/error.rs 🔗

@@ -12,6 +12,8 @@ use xml::writer::Error as EmitterError;
 
 use minidom::Error as MinidomError;
 
+use base64::Base64Error;
+
 /// An error which wraps a bunch of errors from different crates and the stdlib.
 #[derive(Debug)]
 pub enum Error {
@@ -21,6 +23,7 @@ pub enum Error {
     HandshakeError(HandshakeError<TcpStream>),
     OpenSslErrorStack(ErrorStack),
     MinidomError(MinidomError),
+    Base64Error(Base64Error),
     StreamError,
     EndOfDocument,
 }
@@ -60,3 +63,9 @@ impl From<MinidomError> for Error {
         Error::MinidomError(err)
     }
 }
+
+impl From<Base64Error> for Error {
+    fn from(err: Base64Error) -> Error {
+        Error::Base64Error(err)
+    }
+}

src/sasl.rs 🔗

@@ -10,7 +10,7 @@ pub trait SaslMechanism {
     }
 
     /// Creates a response to the SASL challenge.
-    fn respond(&mut self, _challenge: &[u8]) -> Vec<u8> {
+    fn response(&mut self, _challenge: &[u8]) -> Vec<u8> {
         Vec::new()
     }
 }