Ensure that base64 token values are URL-safe

Max Brunsfeld created

Change summary

zed-rpc/src/auth.rs | 54 +++++++++++++++++++++++++++++-----------------
1 file changed, 34 insertions(+), 20 deletions(-)

Detailed changes

zed-rpc/src/auth.rs 🔗

@@ -1,8 +1,7 @@
-use std::convert::{TryFrom, TryInto};
-
 use anyhow::{Context, Result};
-use rand::{rngs::OsRng, Rng as _};
+use rand::{thread_rng, Rng as _};
 use rsa::{PublicKey as _, PublicKeyEncoding, RSAPrivateKey, RSAPublicKey};
+use std::convert::TryFrom;
 
 pub struct PublicKey(RSAPublicKey);
 
@@ -10,7 +9,7 @@ pub struct PrivateKey(RSAPrivateKey);
 
 /// Generate a public and private key for asymmetric encryption.
 pub fn keypair() -> Result<(PublicKey, PrivateKey)> {
-    let mut rng = OsRng;
+    let mut rng = thread_rng();
     let bits = 1024;
     let private_key = RSAPrivateKey::new(&mut rng, bits)?;
     let public_key = RSAPublicKey::from(&private_key);
@@ -19,25 +18,25 @@ pub fn keypair() -> Result<(PublicKey, PrivateKey)> {
 
 /// Generate a random 64-character base64 string.
 pub fn random_token() -> String {
-    let mut rng = OsRng;
+    let mut rng = thread_rng();
     let mut token_bytes = [0; 48];
     for byte in token_bytes.iter_mut() {
         *byte = rng.gen();
     }
-    base64::encode(&token_bytes)
+    base64::encode_config(&token_bytes, base64::URL_SAFE)
 }
 
 impl PublicKey {
     /// Convert a string to a base64-encoded string that can only be decoded with the corresponding
     /// private key.
     pub fn encrypt_string(&self, string: &str) -> Result<String> {
-        let mut rng = OsRng;
+        let mut rng = thread_rng();
         let bytes = string.as_bytes();
         let encrypted_bytes = self
             .0
             .encrypt(&mut rng, PADDING_SCHEME, bytes)
             .context("failed to encrypt string with public key")?;
-        let encrypted_string = base64::encode(&encrypted_bytes);
+        let encrypted_string = base64::encode_config(&encrypted_bytes, base64::URL_SAFE);
         Ok(encrypted_string)
     }
 }
@@ -45,8 +44,8 @@ impl PublicKey {
 impl PrivateKey {
     /// Decrypt a base64-encoded string that was encrypted by the correspoding public key.
     pub fn decrypt_string(&self, encrypted_string: &str) -> Result<String> {
-        let encrypted_bytes =
-            base64::decode(encrypted_string).context("failed to base64-decode encrypted string")?;
+        let encrypted_bytes = base64::decode_config(encrypted_string, base64::URL_SAFE)
+            .context("failed to base64-decode encrypted string")?;
         let bytes = self
             .0
             .decrypt(PADDING_SCHEME, &encrypted_bytes)
@@ -56,14 +55,11 @@ impl PrivateKey {
     }
 }
 
-impl TryInto<String> for PublicKey {
+impl TryFrom<PublicKey> for String {
     type Error = anyhow::Error;
-    fn try_into(self) -> Result<String> {
-        let bytes = self
-            .0
-            .to_pkcs1()
-            .context("failed to serialize public key")?;
-        let string = base64::encode(&bytes);
+    fn try_from(key: PublicKey) -> Result<Self> {
+        let bytes = key.0.to_pkcs1().context("failed to serialize public key")?;
+        let string = base64::encode_config(&bytes, base64::URL_SAFE);
         Ok(string)
     }
 }
@@ -71,7 +67,8 @@ impl TryInto<String> for PublicKey {
 impl TryFrom<String> for PublicKey {
     type Error = anyhow::Error;
     fn try_from(value: String) -> Result<Self> {
-        let bytes = base64::decode(&value).context("failed to base64-decode public key string")?;
+        let bytes = base64::decode_config(&value, base64::URL_SAFE)
+            .context("failed to base64-decode public key string")?;
         let key = Self(RSAPublicKey::from_pkcs1(&bytes).context("failed to parse public key")?);
         Ok(key)
     }
@@ -89,13 +86,14 @@ mod tests {
         // * generate a keypair for asymmetric encryption
         // * serialize the public key to send it to the server.
         let (public, private) = keypair().unwrap();
-        let public_string: String = public.try_into().unwrap();
+        let public_string = String::try_from(public).unwrap();
+        assert_printable(&public_string);
 
         // SERVER:
         // * parse the public key
         // * generate a random token.
         // * encrypt the token using the public key.
-        let public: PublicKey = public_string.try_into().unwrap();
+        let public = PublicKey::try_from(public_string).unwrap();
         let token = random_token();
         let encrypted_token = public.encrypt_string(&token).unwrap();
         assert_eq!(token.len(), 64);
@@ -109,6 +107,20 @@ mod tests {
         assert_eq!(decrypted_token, token);
     }
 
+    #[test]
+    fn test_tokens_are_always_url_safe() {
+        for _ in 0..5 {
+            let token = random_token();
+            let (public_key, _) = keypair().unwrap();
+            let encrypted_token = public_key.encrypt_string(&token).unwrap();
+            let public_key_str = String::try_from(public_key).unwrap();
+
+            assert_printable(&token);
+            assert_printable(&public_key_str);
+            assert_printable(&encrypted_token);
+        }
+    }
+
     fn assert_printable(token: &str) {
         for c in token.chars() {
             assert!(
@@ -117,6 +129,8 @@ mod tests {
                 token,
                 c
             );
+            assert_ne!(c, '/', "token {:?} is not URL-safe", token);
+            assert_ne!(c, '&', "token {:?} is not URL-safe", token);
         }
     }
 }