simplify the API regarding authentication

lumi created

Change summary

examples/client.rs               | 11 +++---
src/client.rs                    | 53 +++++++++++++++++++++++++++------
src/sasl/mechanisms/anonymous.rs | 11 ++++++
src/sasl/mechanisms/plain.rs     | 11 ++++++
src/sasl/mechanisms/scram.rs     | 18 ++++++++++
src/sasl/mod.rs                  | 14 ++++++++
6 files changed, 99 insertions(+), 19 deletions(-)

Detailed changes

examples/client.rs 🔗

@@ -4,19 +4,18 @@ use xmpp::jid::Jid;
 use xmpp::client::ClientBuilder;
 use xmpp::plugins::messaging::{MessagingPlugin, MessageEvent};
 use xmpp::plugins::presence::{PresencePlugin, Show};
-use xmpp::sasl::mechanisms::{Scram, Sha1};
 
 use std::env;
 
 fn main() {
     let jid: Jid = env::var("JID").unwrap().parse().unwrap();
-    let mut client = ClientBuilder::new(jid.clone()).connect().unwrap();
+    let pass = env::var("PASS").unwrap();
+    let mut client = ClientBuilder::new(jid.clone())
+                                   .password(pass)
+                                   .connect()
+                                   .unwrap();
     client.register_plugin(MessagingPlugin::new());
     client.register_plugin(PresencePlugin::new());
-    let pass = env::var("PASS").unwrap();
-    let name = jid.node.clone().expect("JID requires a node");
-    client.connect(&mut Scram::<Sha1>::new(name, pass).unwrap()).unwrap();
-    client.bind().unwrap();
     client.plugin::<PresencePlugin>().set_presence(Show::Available, None).unwrap();
     loop {
         let event = client.next_event().unwrap();

src/client.rs 🔗

@@ -5,7 +5,8 @@ use ns;
 use plugin::{Plugin, PluginProxyBinding};
 use event::AbstractEvent;
 use connection::{Connection, C2S};
-use sasl::SaslMechanism;
+use sasl::{SaslMechanism, SaslCredentials, SaslSecret};
+use sasl::mechanisms::{Plain, Scram, Sha1, Sha256};
 
 use base64;
 
@@ -15,15 +16,18 @@ use xml::reader::XmlEvent as ReaderEvent;
 
 use std::sync::mpsc::{Receiver, channel};
 
+use std::collections::HashSet;
+
 /// Struct that should be moved somewhere else and cleaned up.
 #[derive(Debug)]
 pub struct StreamFeatures {
-    pub sasl_mechanisms: Option<Vec<String>>,
+    pub sasl_mechanisms: Option<HashSet<String>>,
 }
 
 /// A builder for `Client`s.
 pub struct ClientBuilder {
     jid: Jid,
+    credentials: Option<SaslCredentials>,
     host: Option<String>,
     port: u16,
 }
@@ -33,6 +37,7 @@ impl ClientBuilder {
     pub fn new(jid: Jid) -> ClientBuilder {
         ClientBuilder {
             jid: jid,
+            credentials: None,
             host: None,
             port: 5222,
         }
@@ -50,21 +55,35 @@ impl ClientBuilder {
         self
     }
 
+    /// Sets the password to use.
+    pub fn password<P: Into<String>>(mut self, password: P) -> ClientBuilder {
+        self.credentials = Some(SaslCredentials {
+            username: self.jid.node.clone().expect("JID has no node"),
+            secret: SaslSecret::Password(password.into()),
+            channel_binding: None,
+        });
+        self
+    }
+
     /// Connects to the server and returns a `Client` when succesful.
     pub fn connect(self) -> Result<Client, Error> {
         let host = &self.host.unwrap_or(self.jid.domain.clone());
+        // TODO: channel binding
         let mut transport = SslTransport::connect(host, self.port)?;
         C2S::init(&mut transport, &self.jid.domain, "before_sasl")?;
         let (sender_out, sender_in) = channel();
         let (dispatcher_out, dispatcher_in) = channel();
-        Ok(Client {
+        let mut client = Client {
             jid: self.jid,
             transport: transport,
             plugins: Vec::new(),
             binding: PluginProxyBinding::new(sender_out, dispatcher_out),
             sender_in: sender_in,
             dispatcher_in: dispatcher_in,
-        })
+        };
+        client.connect(self.credentials.expect("can't connect without credentials"))?;
+        client.bind()?;
+        Ok(client)
     }
 }
 
@@ -127,9 +146,23 @@ impl Client {
         Ok(())
     }
 
-    /// Connects and authenticates using the specified SASL mechanism.
-    pub fn connect<S: SaslMechanism>(&mut self, mechanism: &mut S) -> Result<(), Error> {
-        self.wait_for_features()?;
+    fn connect(&mut self, credentials: SaslCredentials) -> Result<(), Error> {
+        let features = self.wait_for_features()?;
+        let ms = &features.sasl_mechanisms.ok_or(Error::SaslError(Some("no SASL mechanisms".to_owned())))?;
+        fn wrap_err(err: String) -> Error { Error::SaslError(Some(err)) }
+        // TODO: better way for selecting these, enabling anonymous auth
+        let mut mechanism: Box<SaslMechanism> = if ms.contains("SCRAM-SHA-256") {
+            Box::new(Scram::<Sha256>::from_credentials(credentials).map_err(wrap_err)?)
+        }
+        else if ms.contains("SCRAM-SHA-1") {
+            Box::new(Scram::<Sha1>::from_credentials(credentials).map_err(wrap_err)?)
+        }
+        else if ms.contains("PLAIN") {
+            Box::new(Plain::from_credentials(credentials).map_err(wrap_err)?)
+        }
+        else {
+            return Err(Error::SaslError(Some("can't find a SASL mechanism to use".to_owned())));
+        };
         let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
         let mut elem = Element::builder("auth")
                                .ns(ns::SASL)
@@ -180,7 +213,7 @@ impl Client {
         }
     }
 
-    pub fn bind(&mut self) -> Result<(), Error> {
+    fn bind(&mut self) -> Result<(), Error> {
         let mut elem = Element::builder("iq")
                                .attr("id", "bind")
                                .attr("type", "set")
@@ -223,9 +256,9 @@ impl Client {
                     sasl_mechanisms: None,
                 };
                 if let Some(ms) = n.get_child("mechanisms", ns::SASL) {
-                    let mut res = Vec::new();
+                    let mut res = HashSet::new();
                     for cld in ms.children() {
-                        res.push(cld.text());
+                        res.insert(cld.text());
                     }
                     features.sasl_mechanisms = Some(res);
                 }

src/sasl/mechanisms/anonymous.rs 🔗

@@ -1,6 +1,6 @@
 //! Provides the SASL "ANONYMOUS" mechanism.
 
-use sasl::SaslMechanism;
+use sasl::{SaslMechanism, SaslCredentials, SaslSecret};
 
 pub struct Anonymous;
 
@@ -12,4 +12,13 @@ impl Anonymous {
 
 impl SaslMechanism for Anonymous {
     fn name(&self) -> &str { "ANONYMOUS" }
+
+    fn from_credentials(credentials: SaslCredentials) -> Result<Anonymous, String> {
+        if let SaslSecret::None = credentials.secret {
+            Ok(Anonymous)
+        }
+        else {
+            Err("the anonymous sasl mechanism requires no credentials".to_owned())
+        }
+    }
 }

src/sasl/mechanisms/plain.rs 🔗

@@ -1,6 +1,6 @@
 //! Provides the SASL "PLAIN" mechanism.
 
-use sasl::SaslMechanism;
+use sasl::{SaslMechanism, SaslCredentials, SaslSecret};
 
 pub struct Plain {
     username: String,
@@ -19,6 +19,15 @@ impl Plain {
 impl SaslMechanism for Plain {
     fn name(&self) -> &str { "PLAIN" }
 
+    fn from_credentials(credentials: SaslCredentials) -> Result<Plain, String> {
+        if let SaslSecret::Password(password) = credentials.secret {
+            Ok(Plain::new(credentials.username, password))
+        }
+        else {
+            Err("PLAIN requires a password".to_owned())
+        }
+    }
+
     fn initial(&mut self) -> Result<Vec<u8>, String> {
         let mut auth = Vec::new();
         auth.push(0);

src/sasl/mechanisms/scram.rs 🔗

@@ -2,7 +2,7 @@
 
 use base64;
 
-use sasl::SaslMechanism;
+use sasl::{SaslMechanism, SaslCredentials, SaslSecret};
 
 use error::Error;
 
@@ -172,6 +172,22 @@ impl<S: ScramProvider> SaslMechanism for Scram<S> {
         &self.name
     }
 
+    fn from_credentials(credentials: SaslCredentials) -> Result<Scram<S>, String> {
+        if let SaslSecret::Password(password) = credentials.secret {
+            if let Some(binding) = credentials.channel_binding {
+                Scram::new_with_channel_binding(credentials.username, password, binding)
+                      .map_err(|_| "can't generate nonce".to_owned())
+            }
+            else {
+                Scram::new(credentials.username, password)
+                      .map_err(|_| "can't generate nonce".to_owned())
+            }
+        }
+        else {
+            Err("SCRAM requires a password".to_owned())
+        }
+    }
+
     fn initial(&mut self) -> Result<Vec<u8>, String> {
         let mut gs2_header = Vec::new();
         if let Some(_) = self.channel_binding {

src/sasl/mod.rs 🔗

@@ -1,9 +1,23 @@
 //! Provides the `SaslMechanism` trait and some implementations.
 
+pub struct SaslCredentials {
+    pub username: String,
+    pub secret: SaslSecret,
+    pub channel_binding: Option<Vec<u8>>,
+}
+
+pub enum SaslSecret {
+    None,
+    Password(String),
+}
+
 pub trait SaslMechanism {
     /// The name of the mechanism.
     fn name(&self) -> &str;
 
+    /// Creates this mechanism from `SaslCredentials`.
+    fn from_credentials(credentials: SaslCredentials) -> Result<Self, String> where Self: Sized;
+
     /// Provides initial payload of the SASL mechanism.
     fn initial(&mut self) -> Result<Vec<u8>, String> {
         Ok(Vec::new())