1use anyhow::{Context as _, Result};
2use base64::prelude::*;
3use rand::prelude::*;
4use rsa::pkcs1::{DecodeRsaPublicKey, EncodeRsaPublicKey};
5use rsa::traits::PaddingScheme;
6use rsa::{Oaep, Pkcs1v15Encrypt, RsaPrivateKey, RsaPublicKey};
7use sha2::Sha256;
8use std::convert::TryFrom;
9
/// Build the padding scheme used by [`EncryptionFormat::V1`]: OAEP with a SHA-256 digest.
///
/// Both `PublicKey::encrypt_string` and `PrivateKey::decrypt_string` call this helper, so
/// encryption and decryption always agree on the OAEP parameters. A fresh value is built per call.
fn oaep_sha256_padding() -> impl PaddingScheme {
    Oaep::new::<Sha256>()
}
13
/// Padding scheme applied when a string is encrypted with a [`PublicKey`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EncryptionFormat {
    /// Legacy format: RSA with PKCS#1 v1.5 padding ([`Pkcs1v15Encrypt`]).
    ///
    /// PKCS#1 v1.5 is vulnerable to side-channel attacks, so this format is being
    /// phased out in favour of [`EncryptionFormat::V1`].
    ///
    /// See [here](https://people.redhat.com/~hkario/marvin/) for more details.
    V0,

    /// Current format: RSA with Optimal Asymmetric Encryption Padding (OAEP) using a
    /// SHA-256 digest.
    V1,
}
27
/// RSA public key used to encrypt strings (see `PublicKey::encrypt_string`).
///
/// Converted to and from a URL-safe base64 string of its PKCS#1 DER encoding via the
/// `TryFrom` impls between `PublicKey` and `String`.
pub struct PublicKey(RsaPublicKey);

/// RSA private key used to decrypt strings produced by the matching [`PublicKey`]
/// (see `PrivateKey::decrypt_string`).
pub struct PrivateKey(RsaPrivateKey);
31
32/// Generate a public and private key for asymmetric encryption.
33pub fn keypair() -> Result<(PublicKey, PrivateKey)> {
34 let mut rng = RsaRngCompat::new();
35 let bits = 2048;
36 let private_key = RsaPrivateKey::new(&mut rng, bits)?;
37 let public_key = RsaPublicKey::from(&private_key);
38 Ok((PublicKey(public_key), PrivateKey(private_key)))
39}
40
41/// Generate a random 64-character base64 string.
42pub fn random_token() -> String {
43 let mut rng = rand::rng();
44 let mut token_bytes = [0; 48];
45 for byte in token_bytes.iter_mut() {
46 *byte = rng.random();
47 }
48 BASE64_URL_SAFE.encode(token_bytes)
49}
50
51impl PublicKey {
52 /// Convert a string to a base64-encoded string that can only be decoded with the corresponding
53 /// private key.
54 pub fn encrypt_string(&self, string: &str, format: EncryptionFormat) -> Result<String> {
55 let mut rng = RsaRngCompat::new();
56 let bytes = string.as_bytes();
57 let encrypted_bytes = match format {
58 EncryptionFormat::V0 => self.0.encrypt(&mut rng, Pkcs1v15Encrypt, bytes),
59 EncryptionFormat::V1 => self.0.encrypt(&mut rng, oaep_sha256_padding(), bytes),
60 }
61 .context("failed to encrypt string with public key")?;
62 let encrypted_string = BASE64_URL_SAFE.encode(&encrypted_bytes);
63 Ok(encrypted_string)
64 }
65}
66
67impl PrivateKey {
68 /// Decrypt a base64-encoded string that was encrypted by the corresponding public key.
69 pub fn decrypt_string(&self, encrypted_string: &str) -> Result<String> {
70 let encrypted_bytes = BASE64_URL_SAFE
71 .decode(encrypted_string)
72 .context("failed to base64-decode encrypted string")?;
73 let bytes = self
74 .0
75 .decrypt(oaep_sha256_padding(), &encrypted_bytes)
76 .or_else(|_err| {
77 // If we failed to decrypt using the new format, try decrypting with the old
78 // one to handle mismatches between the client and server.
79 self.0.decrypt(Pkcs1v15Encrypt, &encrypted_bytes)
80 })
81 .context("failed to decrypt string with private key")?;
82 let string = String::from_utf8(bytes).context("decrypted content was not valid utf8")?;
83 Ok(string)
84 }
85}
86
87impl TryFrom<PublicKey> for String {
88 type Error = anyhow::Error;
89 fn try_from(key: PublicKey) -> Result<Self> {
90 let bytes = key
91 .0
92 .to_pkcs1_der()
93 .context("failed to serialize public key")?;
94 let string = BASE64_URL_SAFE.encode(&bytes);
95 Ok(string)
96 }
97}
98
99impl TryFrom<String> for PublicKey {
100 type Error = anyhow::Error;
101 fn try_from(value: String) -> Result<Self> {
102 let bytes = BASE64_URL_SAFE
103 .decode(&value)
104 .context("failed to base64-decode public key string")?;
105 let key = Self(RsaPublicKey::from_pkcs1_der(&bytes).context("failed to parse public key")?);
106 Ok(key)
107 }
108}
109
110// TODO: remove once we rsa v0.10 is released.
111struct RsaRngCompat(rand::rngs::ThreadRng);
112
113impl RsaRngCompat {
114 fn new() -> Self {
115 Self(rand::rng())
116 }
117}
118
119impl rsa::signature::rand_core::RngCore for RsaRngCompat {
120 fn next_u32(&mut self) -> u32 {
121 self.0.next_u32()
122 }
123
124 fn next_u64(&mut self) -> u64 {
125 self.0.next_u64()
126 }
127
128 fn fill_bytes(&mut self, dest: &mut [u8]) {
129 self.0.fill_bytes(dest);
130 }
131
132 fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rsa::signature::rand_core::Error> {
133 self.fill_bytes(dest);
134 Ok(())
135 }
136}
137
138impl rsa::signature::rand_core::CryptoRng for RsaRngCompat {}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Run the full client/server token exchange with `format` and check that the token
    /// survives the round trip unchanged and every transmitted string is URL-safe.
    fn assert_token_round_trip(format: EncryptionFormat) {
        // CLIENT:
        // * generate a keypair for asymmetric encryption
        // * serialize the public key to send it to the server.
        let (public, private) = keypair().unwrap();
        let public_string = String::try_from(public).unwrap();
        assert_printable(&public_string);

        // SERVER:
        // * parse the public key
        // * generate a random token.
        // * encrypt the token using the public key.
        let public = PublicKey::try_from(public_string).unwrap();
        let token = random_token();
        let encrypted_token = public.encrypt_string(&token, format).unwrap();
        assert_eq!(token.len(), 64);
        assert_ne!(encrypted_token, token);
        assert_printable(&token);
        assert_printable(&encrypted_token);

        // CLIENT:
        // * decrypt the token using the private key.
        let decrypted_token = private.decrypt_string(&encrypted_token).unwrap();
        assert_eq!(decrypted_token, token);
    }

    #[test]
    fn test_generate_encrypt_and_decrypt_token() {
        assert_token_round_trip(EncryptionFormat::V1);
    }

    #[test]
    fn test_generate_encrypt_and_decrypt_token_with_v0_encryption_format() {
        assert_token_round_trip(EncryptionFormat::V0);
    }

    #[test]
    fn test_encode_and_decode_base64_public_key() {
        // A base64-encoded public key.
        //
        // We're using a literal string to ensure that encoding and decoding works across differences in implementations.
        let encoded_public_key = "MIGJAoGBAMPvufou8wOuUIF1Wlkbtn0ZMM9nC55QJ06nTZvgMfZv5esFVU9-cQO_JC1P9ZoEcMDJweFERnQuQLqzsrMDLFbkdgL128ZU43WOLiQraxaICFIZsPUeTtWMKp2D5bPWsNxs-lnCma7vCAry6fpXuj5AKQdk7cTZJNucgvZQ0uUfAgMBAAE=".to_string();

        // Make sure we can parse the public key.
        let public_key = PublicKey::try_from(encoded_public_key.clone()).unwrap();

        // Make sure we re-encode to the same format.
        assert_eq!(encoded_public_key, String::try_from(public_key).unwrap());
    }

    #[test]
    fn test_invalid_public_key_strings_are_rejected() {
        // Not valid base64 at all.
        assert!(PublicKey::try_from("not base64!".to_string()).is_err());
        // Valid base64, but the decoded bytes are not a PKCS#1 DER public key.
        assert!(PublicKey::try_from(BASE64_URL_SAFE.encode(b"not a key")).is_err());
    }

    #[test]
    fn test_tokens_are_always_url_safe() {
        for _ in 0..5 {
            let token = random_token();
            let (public_key, _) = keypair().unwrap();
            let encrypted_token = public_key
                .encrypt_string(&token, EncryptionFormat::V1)
                .unwrap();
            let public_key_str = String::try_from(public_key).unwrap();

            assert_printable(&token);
            assert_printable(&public_key_str);
            assert_printable(&encrypted_token);
        }
    }

    /// Assert that `token` contains only printable ASCII and none of the characters that
    /// need escaping in a URL. The URL-safe base64 alphabet never produces `/`, `+` or `&`.
    fn assert_printable(token: &str) {
        for c in token.chars() {
            assert!(
                c.is_ascii_graphic(),
                "token {:?} has non-printable char {}",
                token,
                c
            );
            assert_ne!(c, '/', "token {:?} is not URL-safe", token);
            assert_ne!(c, '+', "token {:?} is not URL-safe", token);
            assert_ne!(c, '&', "token {:?} is not URL-safe", token);
        }
    }
}