some work towards channel binding support (SCRAM-SHA-{1,256}-PLUS)

lumi created

Change summary

src/client.rs                    |   2 
src/sasl/mechanisms/anonymous.rs |   2 
src/sasl/mechanisms/plain.rs     |  10 +-
src/sasl/mechanisms/scram.rs     | 134 ++++++++++++++++++---------------
src/sasl/mod.rs                  |   2 
5 files changed, 82 insertions(+), 68 deletions(-)

Detailed changes

src/client.rs πŸ”—

@@ -133,7 +133,7 @@ impl Client {
         let auth = mechanism.initial().map_err(|x| Error::SaslError(Some(x)))?;
         let mut elem = Element::builder("auth")
                                .ns(ns::SASL)
-                               .attr("mechanism", S::name())
+                               .attr("mechanism", mechanism.name())
                                .build();
         if !auth.is_empty() {
             elem.append_text_node(base64::encode(&auth));

src/sasl/mechanisms/anonymous.rs πŸ”—

@@ -11,5 +11,5 @@ impl Anonymous {
 }
 
 impl SaslMechanism for Anonymous {
-    fn name() -> &'static str { "ANONYMOUS" }
+    fn name(&self) -> &str { "ANONYMOUS" }
 }

src/sasl/mechanisms/plain.rs πŸ”—

@@ -3,26 +3,26 @@
 use sasl::SaslMechanism;
 
 pub struct Plain {
-    name: String,
+    username: String,
     password: String,
 }
 
 impl Plain {
-    pub fn new<N: Into<String>, P: Into<String>>(name: N, password: P) -> Plain {
+    pub fn new<N: Into<String>, P: Into<String>>(username: N, password: P) -> Plain {
         Plain {
-            name: name.into(),
+            username: username.into(),
             password: password.into(),
         }
     }
 }
 
 impl SaslMechanism for Plain {
-    fn name() -> &'static str { "PLAIN" }
+    fn name(&self) -> &str { "PLAIN" }
 
     fn initial(&mut self) -> Result<Vec<u8>, String> {
         let mut auth = Vec::new();
         auth.push(0);
-        auth.extend(self.name.bytes());
+        auth.extend(self.username.bytes());
         auth.push(0);
         auth.extend(self.password.bytes());
         Ok(auth)

src/sasl/mechanisms/scram.rs πŸ”—

@@ -16,6 +16,10 @@ use openssl::error::ErrorStack;
 
 use std::marker::PhantomData;
 
+use std::collections::HashMap;
+
+use std::string::FromUtf8Error;
+
 #[cfg(test)]
 #[test]
 fn xor_works() {
@@ -33,6 +37,23 @@ fn xor(a: &[u8], b: &[u8]) -> Vec<u8> {
     ret
 }
 
+fn parse_frame(frame: &[u8]) -> Result<HashMap<String, String>, FromUtf8Error> {
+    let inner = String::from_utf8(frame.to_owned())?;
+    let mut ret = HashMap::new();
+    for s in inner.split(',') {
+        let mut tmp = s.splitn(2, '=');
+        let key = tmp.next();
+        let val = tmp.next();
+        match (key, val) {
+            (Some(k), Some(v)) => {
+                ret.insert(k.to_owned(), v.to_owned());
+            },
+            _ =>(),
+        }
+    }
+    Ok(ret)
+}
+
 fn generate_nonce() -> Result<String, ErrorStack> {
     let mut data = vec![0; 32];
     rand_bytes(&mut data)?;
@@ -49,7 +70,7 @@ pub trait ScramProvider {
 pub struct Sha1;
 
 impl ScramProvider for Sha1 { // TODO: look at all these unwraps
-    fn name() -> &'static str { "SCRAM-SHA-1" }
+    fn name() -> &'static str { "SHA-1" }
 
     fn hash(data: &[u8]) -> Vec<u8> {
         hash(MessageDigest::sha1(), data).unwrap()
@@ -72,7 +93,7 @@ impl ScramProvider for Sha1 { // TODO: look at all these unwraps
 pub struct Sha256;
 
 impl ScramProvider for Sha256 { // TODO: look at all these unwraps
-    fn name() -> &'static str { "SCRAM-SHA-256" }
+    fn name() -> &'static str { "SHA-256" }
 
     fn hash(data: &[u8]) -> Vec<u8> {
         hash(MessageDigest::sha256(), data).unwrap()
@@ -94,55 +115,80 @@ impl ScramProvider for Sha256 { // TODO: look at all these unwraps
 
 enum ScramState {
     Init,
-    SentInitialMessage { initial_message: Vec<u8> },
+    SentInitialMessage { initial_message: Vec<u8>, gs2_header: Vec<u8>},
     GotServerData { server_signature: Vec<u8> },
 }
 
 pub struct Scram<S: ScramProvider> {
     name: String,
+    username: String,
     password: String,
     client_nonce: String,
     state: ScramState,
+    channel_binding: Option<Vec<u8>>,
     _marker: PhantomData<S>,
 }
 
 impl<S: ScramProvider> Scram<S> {
-    pub fn new<N: Into<String>, P: Into<String>>(name: N, password: P) -> Result<Scram<S>, Error> {
+    pub fn new<N: Into<String>, P: Into<String>>(username: N, password: P) -> Result<Scram<S>, Error> {
         Ok(Scram {
-            name: name.into(),
+            name: format!("SCRAM-{}", S::name()),
+            username: username.into(),
             password: password.into(),
             client_nonce: generate_nonce()?,
             state: ScramState::Init,
+            channel_binding: None,
             _marker: PhantomData,
         })
     }
 
-    pub fn new_with_nonce<N: Into<String>, P: Into<String>>(name: N, password: P, nonce: String) -> Scram<S> {
+    pub fn new_with_nonce<N: Into<String>, P: Into<String>>(username: N, password: P, nonce: String) -> Scram<S> {
         Scram {
-            name: name.into(),
+            name: format!("SCRAM-{}", S::name()),
+            username: username.into(),
             password: password.into(),
             client_nonce: nonce,
             state: ScramState::Init,
+            channel_binding: None,
             _marker: PhantomData,
         }
     }
+
+    pub fn new_with_channel_binding<N: Into<String>, P: Into<String>>(username: N, password: P, channel_binding: Vec<u8>) -> Result<Scram<S>, Error> {
+        Ok(Scram {
+            name: format!("SCRAM-{}-PLUS", S::name()),
+            username: username.into(),
+            password: password.into(),
+            client_nonce: generate_nonce()?,
+            state: ScramState::Init,
+            channel_binding: Some(channel_binding),
+            _marker: PhantomData,
+        })
+    }
 }
 
 impl<S: ScramProvider> SaslMechanism for Scram<S> {
-    fn name() -> &'static str {
-        S::name()
+    fn name(&self) -> &str { // TODO: this is quite the workaround…
+        &self.name
     }
 
     fn initial(&mut self) -> Result<Vec<u8>, String> {
+        let mut gs2_header = Vec::new();
+        if let Some(_) = self.channel_binding {
+            gs2_header.extend(b"p=tls-unique,,");
+        }
+        else {
+            gs2_header.extend(b"n,,");
+        }
         let mut bare = Vec::new();
         bare.extend(b"n=");
-        bare.extend(self.name.bytes());
+        bare.extend(self.username.bytes());
         bare.extend(b",r=");
         bare.extend(self.client_nonce.bytes());
-        self.state = ScramState::SentInitialMessage { initial_message: bare.clone() };
         let mut data = Vec::new();
-        data.extend(b"n,,");
-        data.extend(bare);
+        data.extend(&gs2_header);
+        data.extend(bare.clone());
+        self.state = ScramState::SentInitialMessage { initial_message: bare, gs2_header: gs2_header };
         Ok(data)
     }
 
@@ -150,41 +196,24 @@ impl<S: ScramProvider> SaslMechanism for Scram<S> {
         let next_state;
         let ret;
         match self.state {
-            ScramState::SentInitialMessage { ref initial_message } => {
-                let chal = String::from_utf8(challenge.to_owned()).map_err(|_| "can't decode challenge".to_owned())?;
-                let mut server_nonce: Option<String> = None;
-                let mut salt: Option<Vec<u8>> = None;
-                let mut iterations: Option<usize> = None;
-                for s in chal.split(',') {
-                    let mut tmp = s.splitn(2, '=');
-                    let key = tmp.next();
-                    if let Some(val) = tmp.next() {
-                        match key {
-                            Some("r") => {
-                                if val.starts_with(&self.client_nonce) {
-                                    server_nonce = Some(val.to_owned());
-                                }
-                            },
-                            Some("s") => {
-                                if let Ok(s) = base64::decode(val) {
-                                    salt = Some(s);
-                                }
-                            },
-                            Some("i") => {
-                                if let Ok(iters) = val.parse() {
-                                    iterations = Some(iters);
-                                }
-                            },
-                            _ => (),
-                        }
-                    }
-                }
+            ScramState::SentInitialMessage { ref initial_message, ref gs2_header } => {
+                let frame = parse_frame(challenge).map_err(|_| "can't decode challenge".to_owned())?;
+                let server_nonce = frame.get("r");
+                let salt = frame.get("s").and_then(|v| base64::decode(v).ok());
+                let iterations = frame.get("i").and_then(|v| v.parse().ok());
                 let server_nonce = server_nonce.ok_or_else(|| "no server nonce".to_owned())?;
                 let salt = salt.ok_or_else(|| "no server salt".to_owned())?;
                 let iterations = iterations.ok_or_else(|| "no server iterations".to_owned())?;
                 // TODO: SASLprep
                 let mut client_final_message_bare = Vec::new();
-                client_final_message_bare.extend(b"c=biws,r=");
+                client_final_message_bare.extend(b"c=");
+                let mut cb_data: Vec<u8> = Vec::new();
+                cb_data.extend(gs2_header);
+                if let Some(ref cb) = self.channel_binding {
+                    cb_data.extend(cb);
+                }
+                client_final_message_bare.extend(base64::encode(gs2_header).bytes());
+                client_final_message_bare.extend(b",r=");
                 client_final_message_bare.extend(server_nonce.bytes());
                 let salted_password = S::derive(self.password.as_bytes(), &salt, iterations);
                 let client_key = S::hmac(b"Client Key", &salted_password);
@@ -215,25 +244,10 @@ impl<S: ScramProvider> SaslMechanism for Scram<S> {
     }
 
     fn success(&mut self, data: &[u8]) -> Result<(), String> {
-        let data = String::from_utf8(data.to_owned()).map_err(|_| "can't decode success message".to_owned())?;
-        let mut received_signature = None;
+        let frame = parse_frame(data).map_err(|_| "can't decode success response".to_owned())?;
         match self.state {
             ScramState::GotServerData { ref server_signature } => {
-                for s in data.split(',') {
-                    let mut tmp = s.splitn(2, '=');
-                    let key = tmp.next();
-                    if let Some(val) = tmp.next() {
-                        match key {
-                            Some("v") => {
-                                if let Ok(v) = base64::decode(val) {
-                                    received_signature = Some(v);
-                                }
-                            },
-                            _ => (),
-                        }
-                    }
-                }
-                if let Some(sig) = received_signature {
+                if let Some(sig) = frame.get("v").and_then(|v| base64::decode(&v).ok()) {
                     if sig == *server_signature {
                         Ok(())
                     }

src/sasl/mod.rs πŸ”—

@@ -2,7 +2,7 @@
 
 pub trait SaslMechanism {
     /// The name of the mechanism.
-    fn name() -> &'static str;
+    fn name(&self) -> &str;
 
     /// Provides initial payload of the SASL mechanism.
     fn initial(&mut self) -> Result<Vec<u8>, String> {