auth.rs

  1use anyhow::{Context, Result};
  2use rand::{thread_rng, Rng as _};
  3use rsa::pkcs1::{DecodeRsaPublicKey, EncodeRsaPublicKey};
  4use rsa::traits::PaddingScheme;
  5use rsa::{Oaep, Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey};
  6use sha2::Sha256;
  7use std::convert::TryFrom;
  8
  9fn oaep_sha256_padding() -> impl PaddingScheme {
 10    Oaep::new::<Sha256>()
 11}
 12
 13#[derive(Debug, PartialEq, Eq, Clone, Copy)]
 14pub enum EncryptionFormat {
 15    /// The original encryption format.
 16    ///
 17    /// This is using [`Pkcs1v15Encrypt`], which is vulnerable to side-channel attacks.
 18    /// As such, we're in the process of phasing it out.
 19    ///
 20    /// See [here](https://people.redhat.com/~hkario/marvin/) for more details.
 21    V0,
 22
 23    /// The new encryption key format using Optimal Asymmetric Encryption Padding (OAEP) with a SHA-256 digest.
 24    V1,
 25}
 26
 27pub struct PublicKey(RsaPublicKey);
 28
 29pub struct PrivateKey(RsaPrivateKey);
 30
 31/// Generate a public and private key for asymmetric encryption.
 32pub fn keypair() -> Result<(PublicKey, PrivateKey)> {
 33    let mut rng = thread_rng();
 34    let bits = 2048;
 35    let private_key = RsaPrivateKey::new(&mut rng, bits)?;
 36    let public_key = RsaPublicKey::from(&private_key);
 37    Ok((PublicKey(public_key), PrivateKey(private_key)))
 38}
 39
 40/// Generate a random 64-character base64 string.
 41pub fn random_token() -> String {
 42    let mut rng = thread_rng();
 43    let mut token_bytes = [0; 48];
 44    for byte in token_bytes.iter_mut() {
 45        *byte = rng.gen();
 46    }
 47    base64::encode_config(token_bytes, base64::URL_SAFE)
 48}
 49
 50impl PublicKey {
 51    /// Convert a string to a base64-encoded string that can only be decoded with the corresponding
 52    /// private key.
 53    pub fn encrypt_string(&self, string: &str, format: EncryptionFormat) -> Result<String> {
 54        let mut rng = thread_rng();
 55        let bytes = string.as_bytes();
 56        let encrypted_bytes = match format {
 57            EncryptionFormat::V0 => self.0.encrypt(&mut rng, Pkcs1v15Encrypt, bytes),
 58            EncryptionFormat::V1 => self.0.encrypt(&mut rng, oaep_sha256_padding(), bytes),
 59        }
 60        .context("failed to encrypt string with public key")?;
 61        let encrypted_string = base64::encode_config(&encrypted_bytes, base64::URL_SAFE);
 62        Ok(encrypted_string)
 63    }
 64}
 65
 66impl PrivateKey {
 67    /// Decrypt a base64-encoded string that was encrypted by the corresponding public key.
 68    pub fn decrypt_string(&self, encrypted_string: &str) -> Result<String> {
 69        let encrypted_bytes = base64::decode_config(encrypted_string, base64::URL_SAFE)
 70            .context("failed to base64-decode encrypted string")?;
 71        let bytes = self
 72            .0
 73            .decrypt(oaep_sha256_padding(), &encrypted_bytes)
 74            .or_else(|_err| {
 75                // If we failed to decrypt using the new format, try decrypting with the old
 76                // one to handle mismatches between the client and server.
 77                self.0.decrypt(Pkcs1v15Encrypt, &encrypted_bytes)
 78            })
 79            .context("failed to decrypt string with private key")?;
 80        let string = String::from_utf8(bytes).context("decrypted content was not valid utf8")?;
 81        Ok(string)
 82    }
 83}
 84
 85impl TryFrom<PublicKey> for String {
 86    type Error = anyhow::Error;
 87    fn try_from(key: PublicKey) -> Result<Self> {
 88        let bytes = key
 89            .0
 90            .to_pkcs1_der()
 91            .context("failed to serialize public key")?;
 92        let string = base64::encode_config(&bytes, base64::URL_SAFE);
 93        Ok(string)
 94    }
 95}
 96
 97impl TryFrom<String> for PublicKey {
 98    type Error = anyhow::Error;
 99    fn try_from(value: String) -> Result<Self> {
100        let bytes = base64::decode_config(&value, base64::URL_SAFE)
101            .context("failed to base64-decode public key string")?;
102        let key = Self(RsaPublicKey::from_pkcs1_der(&bytes).context("failed to parse public key")?);
103        Ok(key)
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Client/server token exchange using the current (OAEP) encryption format.
    #[test]
    fn test_generate_encrypt_and_decrypt_token() {
        assert_token_round_trip(EncryptionFormat::V1);
    }

    /// Client/server token exchange using the legacy (PKCS#1 v1.5) encryption format.
    #[test]
    fn test_generate_encrypt_and_decrypt_token_with_v0_encryption_format() {
        assert_token_round_trip(EncryptionFormat::V0);
    }

    #[test]
    fn test_encode_and_decode_base64_public_key() {
        // A fixed, pre-encoded public key. Using a literal (rather than a freshly generated
        // key) checks that encoding and decoding stay compatible across implementations.
        let encoded_public_key = "MIGJAoGBAMPvufou8wOuUIF1Wlkbtn0ZMM9nC55QJ06nTZvgMfZv5esFVU9-cQO_JC1P9ZoEcMDJweFERnQuQLqzsrMDLFbkdgL128ZU43WOLiQraxaICFIZsPUeTtWMKp2D5bPWsNxs-lnCma7vCAry6fpXuj5AKQdk7cTZJNucgvZQ0uUfAgMBAAE=".to_string();

        // Parsing must succeed...
        let public_key = PublicKey::try_from(encoded_public_key.clone()).unwrap();

        // ...and re-encoding must reproduce the exact same string.
        assert_eq!(encoded_public_key, String::try_from(public_key).unwrap());
    }

    #[test]
    fn test_tokens_are_always_url_safe() {
        for _ in 0..5 {
            let token = random_token();
            let (public_key, _) = keypair().unwrap();
            let encrypted_token = public_key
                .encrypt_string(&token, EncryptionFormat::V1)
                .unwrap();
            let public_key_str = String::try_from(public_key).unwrap();

            for text in [&token, &public_key_str, &encrypted_token] {
                assert_printable(text);
            }
        }
    }

    /// Runs the whole token exchange, encrypting the token with `format`.
    fn assert_token_round_trip(format: EncryptionFormat) {
        // CLIENT:
        // * generate a keypair for asymmetric encryption
        // * serialize the public key to send it to the server.
        let (public, private) = keypair().unwrap();
        let public_string = String::try_from(public).unwrap();
        assert_printable(&public_string);

        // SERVER:
        // * parse the public key
        // * generate a random token.
        // * encrypt the token using the public key.
        let public = PublicKey::try_from(public_string).unwrap();
        let token = random_token();
        let encrypted_token = public.encrypt_string(&token, format).unwrap();
        assert_eq!(token.len(), 64);
        assert_ne!(encrypted_token, token);
        assert_printable(&token);
        assert_printable(&encrypted_token);

        // CLIENT:
        // * decrypt the token using the private key.
        let decrypted_token = private.decrypt_string(&encrypted_token).unwrap();
        assert_eq!(decrypted_token, token);
    }

    /// Asserts that every character of `token` is a graphic ASCII character other than
    /// `/` and `&`.
    fn assert_printable(token: &str) {
        token.chars().for_each(|c| {
            assert!(
                c.is_ascii_graphic(),
                "token {:?} has non-printable char {}",
                token,
                c
            );
            assert_ne!(c, '/', "token {:?} is not URL-safe", token);
            assert_ne!(c, '&', "token {:?} is not URL-safe", token);
        });
    }
}