1use getrandom::{getrandom, Error as RngError};
2use hmac::{digest::InvalidLength, Hmac, Mac};
3use pbkdf2::pbkdf2;
4use sha1::{Digest, Sha1 as Sha1_hash};
5use sha2::Sha256 as Sha256_hash;
6
7use crate::common::Password;
8
9use crate::secret;
10
11use base64::{engine::general_purpose::STANDARD as Base64, Engine};
12
13/// Generate a nonce for SCRAM authentication.
14pub fn generate_nonce() -> Result<String, RngError> {
15 let mut data = [0u8; 32];
16 getrandom(&mut data)?;
17 Ok(Base64.encode(data))
18}
19
/// Errors returned by `ScramProvider::derive` when a salted password cannot
/// be produced.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeriveError {
    /// A pre-hashed password was produced with a different hash method than
    /// the provider's. Holds `(stored_method, provider_method)`.
    IncompatibleHashingMethod(String, String),
    /// The salt supplied for derivation is not the salt a pre-hashed
    /// password was produced with.
    IncorrectSalt,
    /// A key of invalid length was rejected by the HMAC/PBKDF2 primitives
    /// (produced via the `From<hmac::digest::InvalidLength>` conversion).
    InvalidLength,
    /// A pre-hashed password was produced with a different iteration count
    /// than the one requested. Holds `(stored_iterations, requested_iterations)`.
    IncompatibleIterationCount(u32, u32),
}

impl std::fmt::Display for DeriveError {
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            Self::IncompatibleHashingMethod(stored, wanted) => write!(
                fmt,
                "incompatible hashing method, {} is not {}",
                stored, wanted
            ),
            Self::IncorrectSalt => fmt.write_str("incorrect salt"),
            Self::InvalidLength => fmt.write_str("invalid length"),
            Self::IncompatibleIterationCount(stored, wanted) => write!(
                fmt,
                "incompatible iteration count, {} is not {}",
                stored, wanted
            ),
        }
    }
}

impl std::error::Error for DeriveError {}
44
45impl From<hmac::digest::InvalidLength> for DeriveError {
46 fn from(_err: hmac::digest::InvalidLength) -> DeriveError {
47 DeriveError::InvalidLength
48 }
49}
50
51/// A trait which defines the needed methods for SCRAM.
52pub trait ScramProvider {
53 /// The kind of secret this `ScramProvider` requires.
54 type Secret: secret::Secret;
55
56 /// The name of the hash function.
57 fn name() -> &'static str;
58
59 /// A function which hashes the data using the hash function.
60 fn hash(data: &[u8]) -> Vec<u8>;
61
62 /// A function which performs an HMAC using the hash function.
63 fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength>;
64
65 /// A function which does PBKDF2 key derivation using the hash function.
66 fn derive(data: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError>;
67}
68
69/// A `ScramProvider` which provides SCRAM-SHA-1 and SCRAM-SHA-1-PLUS
70pub struct Sha1;
71
72impl ScramProvider for Sha1 {
73 type Secret = secret::Pbkdf2Sha1;
74
75 fn name() -> &'static str {
76 "SHA-1"
77 }
78
79 fn hash(data: &[u8]) -> Vec<u8> {
80 let hash = Sha1_hash::digest(data);
81 let mut vec = Vec::with_capacity(Sha1_hash::output_size());
82 vec.extend_from_slice(hash.as_slice());
83 vec
84 }
85
86 fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength> {
87 type HmacSha1 = Hmac<Sha1_hash>;
88 let mut mac = HmacSha1::new_from_slice(key)?;
89 mac.update(data);
90 let result = mac.finalize();
91 let mut vec = Vec::with_capacity(Sha1_hash::output_size());
92 vec.extend_from_slice(result.into_bytes().as_slice());
93 Ok(vec)
94 }
95
96 fn derive(password: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError> {
97 match *password {
98 Password::Plain(ref plain) => {
99 let mut result = vec![0; 20];
100 pbkdf2::<Hmac<Sha1_hash>>(plain.as_bytes(), salt, iterations, &mut result)?;
101 Ok(result)
102 }
103 Password::Pbkdf2 {
104 ref method,
105 salt: ref my_salt,
106 iterations: my_iterations,
107 ref data,
108 } => {
109 if method != Self::name() {
110 Err(DeriveError::IncompatibleHashingMethod(
111 method.to_string(),
112 Self::name().to_string(),
113 ))
114 } else if my_salt == salt {
115 Err(DeriveError::IncorrectSalt)
116 } else if my_iterations == iterations {
117 Err(DeriveError::IncompatibleIterationCount(
118 my_iterations,
119 iterations,
120 ))
121 } else {
122 Ok(data.to_vec())
123 }
124 }
125 }
126 }
127}
128
129/// A `ScramProvider` which provides SCRAM-SHA-256 and SCRAM-SHA-256-PLUS
130pub struct Sha256;
131
132impl ScramProvider for Sha256 {
133 type Secret = secret::Pbkdf2Sha256;
134
135 fn name() -> &'static str {
136 "SHA-256"
137 }
138
139 fn hash(data: &[u8]) -> Vec<u8> {
140 let hash = Sha256_hash::digest(data);
141 let mut vec = Vec::with_capacity(Sha256_hash::output_size());
142 vec.extend_from_slice(hash.as_slice());
143 vec
144 }
145
146 fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength> {
147 type HmacSha256 = Hmac<Sha256_hash>;
148 let mut mac = HmacSha256::new_from_slice(key)?;
149 mac.update(data);
150 let result = mac.finalize();
151 let mut vec = Vec::with_capacity(Sha256_hash::output_size());
152 vec.extend_from_slice(result.into_bytes().as_slice());
153 Ok(vec)
154 }
155
156 fn derive(password: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError> {
157 match *password {
158 Password::Plain(ref plain) => {
159 let mut result = vec![0; 32];
160 pbkdf2::<Hmac<Sha256_hash>>(plain.as_bytes(), salt, iterations, &mut result)?;
161 Ok(result)
162 }
163 Password::Pbkdf2 {
164 ref method,
165 salt: ref my_salt,
166 iterations: my_iterations,
167 ref data,
168 } => {
169 if method != Self::name() {
170 Err(DeriveError::IncompatibleHashingMethod(
171 method.to_string(),
172 Self::name().to_string(),
173 ))
174 } else if my_salt == salt {
175 Err(DeriveError::IncorrectSalt)
176 } else if my_iterations == iterations {
177 Err(DeriveError::IncompatibleIterationCount(
178 my_iterations,
179 iterations,
180 ))
181 } else {
182 Ok(data.to_vec())
183 }
184 }
185 }
186 }
187}