1use anyhow::{Context, Result};
2use rand::{thread_rng, Rng as _};
3use rsa::{PublicKey as _, PublicKeyEncoding, RSAPrivateKey, RSAPublicKey};
4use std::convert::TryFrom;
5
6pub struct PublicKey(RSAPublicKey);
7
8pub struct PrivateKey(RSAPrivateKey);
9
10/// Generate a public and private key for asymmetric encryption.
11pub fn keypair() -> Result<(PublicKey, PrivateKey)> {
12 let mut rng = thread_rng();
13 let bits = 1024;
14 let private_key = RSAPrivateKey::new(&mut rng, bits)?;
15 let public_key = RSAPublicKey::from(&private_key);
16 Ok((PublicKey(public_key), PrivateKey(private_key)))
17}
18
19/// Generate a random 64-character base64 string.
20pub fn random_token() -> String {
21 let mut rng = thread_rng();
22 let mut token_bytes = [0; 48];
23 for byte in token_bytes.iter_mut() {
24 *byte = rng.gen();
25 }
26 base64::encode_config(token_bytes, base64::URL_SAFE)
27}
28
29impl PublicKey {
30 /// Convert a string to a base64-encoded string that can only be decoded with the corresponding
31 /// private key.
32 pub fn encrypt_string(&self, string: &str) -> Result<String> {
33 let mut rng = thread_rng();
34 let bytes = string.as_bytes();
35 let encrypted_bytes = self
36 .0
37 .encrypt(&mut rng, PADDING_SCHEME, bytes)
38 .context("failed to encrypt string with public key")?;
39 let encrypted_string = base64::encode_config(&encrypted_bytes, base64::URL_SAFE);
40 Ok(encrypted_string)
41 }
42}
43
44impl PrivateKey {
45 /// Decrypt a base64-encoded string that was encrypted by the corresponding public key.
46 pub fn decrypt_string(&self, encrypted_string: &str) -> Result<String> {
47 let encrypted_bytes = base64::decode_config(encrypted_string, base64::URL_SAFE)
48 .context("failed to base64-decode encrypted string")?;
49 let bytes = self
50 .0
51 .decrypt(PADDING_SCHEME, &encrypted_bytes)
52 .context("failed to decrypt string with private key")?;
53 let string = String::from_utf8(bytes).context("decrypted content was not valid utf8")?;
54 Ok(string)
55 }
56}
57
58impl TryFrom<PublicKey> for String {
59 type Error = anyhow::Error;
60 fn try_from(key: PublicKey) -> Result<Self> {
61 let bytes = key.0.to_pkcs1().context("failed to serialize public key")?;
62 let string = base64::encode_config(&bytes, base64::URL_SAFE);
63 Ok(string)
64 }
65}
66
67impl TryFrom<String> for PublicKey {
68 type Error = anyhow::Error;
69 fn try_from(value: String) -> Result<Self> {
70 let bytes = base64::decode_config(&value, base64::URL_SAFE)
71 .context("failed to base64-decode public key string")?;
72 let key = Self(RSAPublicKey::from_pkcs1(&bytes).context("failed to parse public key")?);
73 Ok(key)
74 }
75}
76
77const PADDING_SCHEME: rsa::PaddingScheme = rsa::PaddingScheme::PKCS1v15Encrypt;
78
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_encrypt_and_decrypt_token() {
        // CLIENT:
        // * generate a keypair for asymmetric encryption
        // * serialize the public key to send it to the server.
        let (public, private) = keypair().unwrap();
        let public_string = String::try_from(public).unwrap();
        assert_printable(&public_string);

        // SERVER:
        // * parse the public key
        // * generate a random token.
        // * encrypt the token using the public key.
        let public = PublicKey::try_from(public_string).unwrap();
        let token = random_token();
        let encrypted_token = public.encrypt_string(&token).unwrap();
        assert_eq!(token.len(), 64);
        assert_ne!(encrypted_token, token);
        assert_printable(&token);
        assert_printable(&encrypted_token);

        // CLIENT:
        // * decrypt the token using the private key.
        let decrypted_token = private.decrypt_string(&encrypted_token).unwrap();
        assert_eq!(decrypted_token, token);
    }

    #[test]
    fn test_tokens_are_always_url_safe() {
        for _ in 0..5 {
            let token = random_token();
            let (public_key, _) = keypair().unwrap();
            let encrypted_token = public_key.encrypt_string(&token).unwrap();
            let public_key_str = String::try_from(public_key).unwrap();

            assert_printable(&token);
            assert_printable(&public_key_str);
            assert_printable(&encrypted_token);
        }
    }

    /// Assert that `token` is printable ASCII and contains none of the characters that
    /// would be misinterpreted when the token is embedded in a URL.
    fn assert_printable(token: &str) {
        for c in token.chars() {
            assert!(
                c.is_ascii_graphic(),
                "token {:?} has non-printable char {}",
                token,
                c
            );
            assert_ne!(c, '/', "token {:?} is not URL-safe", token);
            assert_ne!(c, '&', "token {:?} is not URL-safe", token);
            // '+' and '/' are the characters where standard base64 differs from the URL-safe
            // alphabet ('-' and '_'); in form-encoded query strings '+' decodes as a space.
            assert_ne!(c, '+', "token {:?} is not URL-safe", token);
        }
    }
}