1use anyhow::{Context, Result};
2use rand::{thread_rng, Rng as _};
3use rsa::pkcs1::{DecodeRsaPublicKey, EncodeRsaPublicKey};
4use rsa::traits::PaddingScheme;
5use rsa::{Oaep, Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey};
6use sha2::Sha256;
7use std::convert::TryFrom;
8
/// Returns the OAEP padding scheme instantiated with SHA-256 as its digest.
///
/// This is the padding used for [`EncryptionFormat::V1`] encryption and for the
/// first decryption attempt in [`PrivateKey::decrypt_string`].
fn oaep_sha256_padding() -> impl PaddingScheme {
    Oaep::new::<Sha256>()
}
12
/// Selects which RSA padding scheme is used when encrypting a string.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum EncryptionFormat {
    /// The legacy format, based on [`Pkcs1v15Encrypt`].
    ///
    /// PKCS#1 v1.5 encryption padding is vulnerable to side-channel attacks, so
    /// this format is being phased out.
    ///
    /// More details are available [here](https://people.redhat.com/~hkario/marvin/).
    V0,

    /// The current format: Optimal Asymmetric Encryption Padding (OAEP) with a
    /// SHA-256 digest.
    V1,
}
26
/// Newtype wrapper around an [`RsaPublicKey`].
///
/// Used to encrypt strings (see `encrypt_string`) and convertible to and from a
/// base64 string via the `TryFrom` impls in this file.
pub struct PublicKey(RsaPublicKey);
28
/// Newtype wrapper around an [`RsaPrivateKey`].
///
/// Used to decrypt strings that were encrypted with the matching [`PublicKey`]
/// (see `decrypt_string`).
pub struct PrivateKey(RsaPrivateKey);
30
31/// Generate a public and private key for asymmetric encryption.
32pub fn keypair() -> Result<(PublicKey, PrivateKey)> {
33 let mut rng = thread_rng();
34 let bits = 2048;
35 let private_key = RsaPrivateKey::new(&mut rng, bits)?;
36 let public_key = RsaPublicKey::from(&private_key);
37 Ok((PublicKey(public_key), PrivateKey(private_key)))
38}
39
40/// Generate a random 64-character base64 string.
41pub fn random_token() -> String {
42 let mut rng = thread_rng();
43 let mut token_bytes = [0; 48];
44 for byte in token_bytes.iter_mut() {
45 *byte = rng.gen();
46 }
47 base64::encode_config(token_bytes, base64::URL_SAFE)
48}
49
50impl PublicKey {
51 /// Convert a string to a base64-encoded string that can only be decoded with the corresponding
52 /// private key.
53 pub fn encrypt_string(&self, string: &str, format: EncryptionFormat) -> Result<String> {
54 let mut rng = thread_rng();
55 let bytes = string.as_bytes();
56 let encrypted_bytes = match format {
57 EncryptionFormat::V0 => self.0.encrypt(&mut rng, Pkcs1v15Encrypt, bytes),
58 EncryptionFormat::V1 => self.0.encrypt(&mut rng, oaep_sha256_padding(), bytes),
59 }
60 .context("failed to encrypt string with public key")?;
61 let encrypted_string = base64::encode_config(&encrypted_bytes, base64::URL_SAFE);
62 Ok(encrypted_string)
63 }
64}
65
66impl PrivateKey {
67 /// Decrypt a base64-encoded string that was encrypted by the corresponding public key.
68 pub fn decrypt_string(&self, encrypted_string: &str) -> Result<String> {
69 let encrypted_bytes = base64::decode_config(encrypted_string, base64::URL_SAFE)
70 .context("failed to base64-decode encrypted string")?;
71 let bytes = self
72 .0
73 .decrypt(oaep_sha256_padding(), &encrypted_bytes)
74 .or_else(|_err| {
75 // If we failed to decrypt using the new format, try decrypting with the old
76 // one to handle mismatches between the client and server.
77 self.0.decrypt(Pkcs1v15Encrypt, &encrypted_bytes)
78 })
79 .context("failed to decrypt string with private key")?;
80 let string = String::from_utf8(bytes).context("decrypted content was not valid utf8")?;
81 Ok(string)
82 }
83}
84
85impl TryFrom<PublicKey> for String {
86 type Error = anyhow::Error;
87 fn try_from(key: PublicKey) -> Result<Self> {
88 let bytes = key
89 .0
90 .to_pkcs1_der()
91 .context("failed to serialize public key")?;
92 let string = base64::encode_config(&bytes, base64::URL_SAFE);
93 Ok(string)
94 }
95}
96
97impl TryFrom<String> for PublicKey {
98 type Error = anyhow::Error;
99 fn try_from(value: String) -> Result<Self> {
100 let bytes = base64::decode_config(&value, base64::URL_SAFE)
101 .context("failed to base64-decode public key string")?;
102 let key = Self(RsaPublicKey::from_pkcs1_der(&bytes).context("failed to parse public key")?);
103 Ok(key)
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a token through the current (OAEP) encryption format.
    #[test]
    fn test_generate_encrypt_and_decrypt_token() {
        assert_token_round_trip(EncryptionFormat::V1);
    }

    /// Round-trips a token through the legacy (PKCS#1 v1.5) encryption format,
    /// which `decrypt_string` must still accept via its fallback path.
    #[test]
    fn test_generate_encrypt_and_decrypt_token_with_v0_encryption_format() {
        assert_token_round_trip(EncryptionFormat::V0);
    }

    #[test]
    fn test_encode_and_decode_base64_public_key() {
        // A base64-encoded public key.
        //
        // We're using a literal string to ensure that encoding and decoding works across
        // differences in implementations.
        let encoded_public_key = "MIGJAoGBAMPvufou8wOuUIF1Wlkbtn0ZMM9nC55QJ06nTZvgMfZv5esFVU9-cQO_JC1P9ZoEcMDJweFERnQuQLqzsrMDLFbkdgL128ZU43WOLiQraxaICFIZsPUeTtWMKp2D5bPWsNxs-lnCma7vCAry6fpXuj5AKQdk7cTZJNucgvZQ0uUfAgMBAAE=".to_string();

        // Make sure we can parse the public key.
        let public_key = PublicKey::try_from(encoded_public_key.clone()).unwrap();

        // Make sure we re-encode to the same format.
        assert_eq!(encoded_public_key, String::try_from(public_key).unwrap());
    }

    #[test]
    fn test_tokens_are_always_url_safe() {
        for _ in 0..5 {
            let token = random_token();
            let (public_key, _) = keypair().unwrap();
            let encrypted_token = public_key
                .encrypt_string(&token, EncryptionFormat::V1)
                .unwrap();
            let public_key_str = String::try_from(public_key).unwrap();

            assert_printable(&token);
            assert_printable(&public_key_str);
            assert_printable(&encrypted_token);
        }
    }

    /// Exercises the full client/server exchange for the given encryption
    /// `format` and asserts that the decrypted token matches the original.
    fn assert_token_round_trip(format: EncryptionFormat) {
        // CLIENT:
        // * generate a keypair for asymmetric encryption
        // * serialize the public key to send it to the server.
        let (public, private) = keypair().unwrap();
        let public_string = String::try_from(public).unwrap();
        assert_printable(&public_string);

        // SERVER:
        // * parse the public key
        // * generate a random token.
        // * encrypt the token using the public key.
        let public = PublicKey::try_from(public_string).unwrap();
        let token = random_token();
        let encrypted_token = public.encrypt_string(&token, format).unwrap();
        assert_eq!(token.len(), 64);
        assert_ne!(encrypted_token, token);
        assert_printable(&token);
        assert_printable(&encrypted_token);

        // CLIENT:
        // * decrypt the token using the private key.
        let decrypted_token = private.decrypt_string(&encrypted_token).unwrap();
        assert_eq!(decrypted_token, token);
    }

    /// Asserts that every character of `token` is a printable ASCII character
    /// and is not one of the characters that would make it unsafe in a URL.
    fn assert_printable(token: &str) {
        for c in token.chars() {
            assert!(
                c.is_ascii_graphic(),
                "token {:?} has non-printable char {}",
                token,
                c
            );
            // '/' and '+' are the two characters of the standard base64 alphabet
            // that the URL-safe alphabet replaces; '&' separates query parameters.
            assert!(
                !matches!(c, '/' | '+' | '&'),
                "token {:?} is not URL-safe",
                token
            );
        }
    }
}