auth: clarify + optimize

Astro created

Change summary

src/client/auth.rs | 25 +++++++++++--------------
1 file changed, 11 insertions(+), 14 deletions(-)

Detailed changes

src/client/auth.rs 🔗

@@ -1,5 +1,6 @@
 use std::mem::replace;
 use std::str::FromStr;
+use std::collections::HashSet;
 use futures::{sink, Async, Future, Poll, Stream, future::{ok, err, IntoFuture}};
 use minidom::Element;
 use sasl::client::mechanisms::{Anonymous, Plain, Scram};
@@ -22,15 +23,14 @@ pub struct ClientAuth<S: AsyncRead + AsyncWrite> {
 
 impl<S: AsyncRead + AsyncWrite + 'static> ClientAuth<S> {
     pub fn new(stream: XMPPStream<S>, creds: Credentials) -> Result<Self, Error> {
-        let mechs: Vec<Box<Mechanism>> = vec![
-            // TODO: Box::new(|| …
-            Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap()),
-            Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap()),
-            Box::new(Plain::from_credentials(creds).unwrap()),
-            Box::new(Anonymous::new()),
+        let local_mechs: Vec<Box<Fn() -> Box<Mechanism>>> = vec![
+            Box::new(|| Box::new(Scram::<Sha256>::from_credentials(creds.clone()).unwrap())),
+            Box::new(|| Box::new(Scram::<Sha1>::from_credentials(creds.clone()).unwrap())),
+            Box::new(|| Box::new(Plain::from_credentials(creds.clone()).unwrap())),
+            Box::new(|| Box::new(Anonymous::new())),
         ];
 
-        let mech_names: Vec<String> = stream
+        let remote_mechs: HashSet<String> = stream
             .stream_features
             .get_child("mechanisms", NS_XMPP_SASL)
             .ok_or(AuthError::NoMechanism)?
@@ -38,15 +38,12 @@ impl<S: AsyncRead + AsyncWrite + 'static> ClientAuth<S> {
             .filter(|child| child.is("mechanism", NS_XMPP_SASL))
             .map(|mech_el| mech_el.text())
             .collect();
-        // TODO: iter instead of collect()
-        // println!("SASL mechanisms offered: {:?}", mech_names);
 
-        for mut mechanism in mechs {
-            let name = mechanism.name().to_owned();
-            if mech_names.iter().any(|name1| *name1 == name) {
-                // println!("SASL mechanism selected: {:?}", name);
+        for local_mech in local_mechs {
+            let mut mechanism = local_mech();
+            if remote_mechs.contains(mechanism.name()) {
                 let initial = mechanism.initial().map_err(AuthError::Sasl)?;
-                let mechanism_name = XMPPMechanism::from_str(&name).map_err(ProtocolError::Parsers)?;
+                let mechanism_name = XMPPMechanism::from_str(mechanism.name()).map_err(ProtocolError::Parsers)?;
 
                 let send_initial = Box::new(stream.send_stanza(Auth {
                     mechanism: mechanism_name,