scram.rs

  1use getrandom::{getrandom, Error as RngError};
  2use hmac::{digest::InvalidLength, Hmac, Mac};
  3use pbkdf2::pbkdf2;
  4use sha1::{Digest, Sha1 as Sha1_hash};
  5use sha2::Sha256 as Sha256_hash;
  6
  7use crate::common::Password;
  8
  9use crate::secret;
 10
 11use base64::{engine::general_purpose::STANDARD as Base64, Engine};
 12
 13/// Generate a nonce for SCRAM authentication.
 14pub fn generate_nonce() -> Result<String, RngError> {
 15    let mut data = [0u8; 32];
 16    getrandom(&mut data)?;
 17    Ok(Base64.encode(data))
 18}
 19
 20#[derive(Debug, PartialEq)]
 21pub enum DeriveError {
 22    IncompatibleHashingMethod(String, String),
 23    IncorrectSalt,
 24    InvalidLength,
 25    IncompatibleIterationCount(u32, u32),
 26}
 27
 28impl std::fmt::Display for DeriveError {
 29    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
 30        match self {
 31            DeriveError::IncompatibleHashingMethod(one, two) => {
 32                write!(fmt, "incompatible hashing method, {} is not {}", one, two)
 33            }
 34            DeriveError::IncorrectSalt => write!(fmt, "incorrect salt"),
 35            DeriveError::InvalidLength => write!(fmt, "invalid length"),
 36            DeriveError::IncompatibleIterationCount(one, two) => {
 37                write!(fmt, "incompatible iteration count, {} is not {}", one, two)
 38            }
 39        }
 40    }
 41}
 42
 43impl std::error::Error for DeriveError {}
 44
 45impl From<hmac::digest::InvalidLength> for DeriveError {
 46    fn from(_err: hmac::digest::InvalidLength) -> DeriveError {
 47        DeriveError::InvalidLength
 48    }
 49}
 50
 51/// A trait which defines the needed methods for SCRAM.
 52pub trait ScramProvider {
 53    /// The kind of secret this `ScramProvider` requires.
 54    type Secret: secret::Secret;
 55
 56    /// The name of the hash function.
 57    fn name() -> &'static str;
 58
 59    /// A function which hashes the data using the hash function.
 60    fn hash(data: &[u8]) -> Vec<u8>;
 61
 62    /// A function which performs an HMAC using the hash function.
 63    fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength>;
 64
 65    /// A function which does PBKDF2 key derivation using the hash function.
 66    fn derive(data: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError>;
 67}
 68
 69/// A `ScramProvider` which provides SCRAM-SHA-1 and SCRAM-SHA-1-PLUS
 70pub struct Sha1;
 71
 72impl ScramProvider for Sha1 {
 73    type Secret = secret::Pbkdf2Sha1;
 74
 75    fn name() -> &'static str {
 76        "SHA-1"
 77    }
 78
 79    fn hash(data: &[u8]) -> Vec<u8> {
 80        let hash = Sha1_hash::digest(data);
 81        let mut vec = Vec::with_capacity(Sha1_hash::output_size());
 82        vec.extend_from_slice(hash.as_slice());
 83        vec
 84    }
 85
 86    fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength> {
 87        type HmacSha1 = Hmac<Sha1_hash>;
 88        let mut mac = HmacSha1::new_from_slice(key)?;
 89        mac.update(data);
 90        let result = mac.finalize();
 91        let mut vec = Vec::with_capacity(Sha1_hash::output_size());
 92        vec.extend_from_slice(result.into_bytes().as_slice());
 93        Ok(vec)
 94    }
 95
 96    fn derive(password: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError> {
 97        match *password {
 98            Password::Plain(ref plain) => {
 99                let mut result = vec![0; 20];
100                pbkdf2::<Hmac<Sha1_hash>>(plain.as_bytes(), salt, iterations, &mut result)?;
101                Ok(result)
102            }
103            Password::Pbkdf2 {
104                ref method,
105                salt: ref my_salt,
106                iterations: my_iterations,
107                ref data,
108            } => {
109                if method != Self::name() {
110                    Err(DeriveError::IncompatibleHashingMethod(
111                        method.to_string(),
112                        Self::name().to_string(),
113                    ))
114                } else if my_salt == salt {
115                    Err(DeriveError::IncorrectSalt)
116                } else if my_iterations == iterations {
117                    Err(DeriveError::IncompatibleIterationCount(
118                        my_iterations,
119                        iterations,
120                    ))
121                } else {
122                    Ok(data.to_vec())
123                }
124            }
125        }
126    }
127}
128
129/// A `ScramProvider` which provides SCRAM-SHA-256 and SCRAM-SHA-256-PLUS
130pub struct Sha256;
131
132impl ScramProvider for Sha256 {
133    type Secret = secret::Pbkdf2Sha256;
134
135    fn name() -> &'static str {
136        "SHA-256"
137    }
138
139    fn hash(data: &[u8]) -> Vec<u8> {
140        let hash = Sha256_hash::digest(data);
141        let mut vec = Vec::with_capacity(Sha256_hash::output_size());
142        vec.extend_from_slice(hash.as_slice());
143        vec
144    }
145
146    fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength> {
147        type HmacSha256 = Hmac<Sha256_hash>;
148        let mut mac = HmacSha256::new_from_slice(key)?;
149        mac.update(data);
150        let result = mac.finalize();
151        let mut vec = Vec::with_capacity(Sha256_hash::output_size());
152        vec.extend_from_slice(result.into_bytes().as_slice());
153        Ok(vec)
154    }
155
156    fn derive(password: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError> {
157        match *password {
158            Password::Plain(ref plain) => {
159                let mut result = vec![0; 32];
160                pbkdf2::<Hmac<Sha256_hash>>(plain.as_bytes(), salt, iterations, &mut result)?;
161                Ok(result)
162            }
163            Password::Pbkdf2 {
164                ref method,
165                salt: ref my_salt,
166                iterations: my_iterations,
167                ref data,
168            } => {
169                if method != Self::name() {
170                    Err(DeriveError::IncompatibleHashingMethod(
171                        method.to_string(),
172                        Self::name().to_string(),
173                    ))
174                } else if my_salt == salt {
175                    Err(DeriveError::IncorrectSalt)
176                } else if my_iterations == iterations {
177                    Err(DeriveError::IncompatibleIterationCount(
178                        my_iterations,
179                        iterations,
180                    ))
181                } else {
182                    Ok(data.to_vec())
183                }
184            }
185        }
186    }
187}