scram.rs

  1use alloc::string::{String, ToString};
  2use alloc::vec;
  3use alloc::vec::Vec;
  4use core::fmt;
  5use hmac::{digest::InvalidLength, Hmac, Mac};
  6use pbkdf2::pbkdf2;
  7use sha1::{Digest, Sha1 as Sha1_hash};
  8use sha2::Sha256 as Sha256_hash;
  9
 10use crate::common::Password;
 11
 12use crate::secret;
 13
 14use base64::{engine::general_purpose::STANDARD as Base64, Engine};
 15
 16/// Generate a nonce for SCRAM authentication.
 17pub fn generate_nonce() -> Result<String, getrandom::Error> {
 18    let mut data = [0u8; 32];
 19    getrandom::fill(&mut data)?;
 20    Ok(Base64.encode(data))
 21}
 22
 23#[derive(Debug, PartialEq)]
 24pub enum DeriveError {
 25    IncompatibleHashingMethod(String, String),
 26    IncorrectSalt,
 27    InvalidLength,
 28    IncompatibleIterationCount(u32, u32),
 29}
 30
 31impl fmt::Display for DeriveError {
 32    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
 33        match self {
 34            DeriveError::IncompatibleHashingMethod(one, two) => {
 35                write!(fmt, "incompatible hashing method, {} is not {}", one, two)
 36            }
 37            DeriveError::IncorrectSalt => write!(fmt, "incorrect salt"),
 38            DeriveError::InvalidLength => write!(fmt, "invalid length"),
 39            DeriveError::IncompatibleIterationCount(one, two) => {
 40                write!(fmt, "incompatible iteration count, {} is not {}", one, two)
 41            }
 42        }
 43    }
 44}
 45
 46impl core::error::Error for DeriveError {}
 47
 48impl From<hmac::digest::InvalidLength> for DeriveError {
 49    fn from(_err: hmac::digest::InvalidLength) -> DeriveError {
 50        DeriveError::InvalidLength
 51    }
 52}
 53
 54/// A trait which defines the needed methods for SCRAM.
 55pub trait ScramProvider {
 56    /// The kind of secret this `ScramProvider` requires.
 57    type Secret: secret::Secret;
 58
 59    /// The name of the hash function.
 60    fn name() -> &'static str;
 61
 62    /// A function which hashes the data using the hash function.
 63    fn hash(data: &[u8]) -> Vec<u8>;
 64
 65    /// A function which performs an HMAC using the hash function.
 66    fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength>;
 67
 68    /// A function which does PBKDF2 key derivation using the hash function.
 69    fn derive(data: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError>;
 70}
 71
 72/// A `ScramProvider` which provides SCRAM-SHA-1 and SCRAM-SHA-1-PLUS
 73pub struct Sha1;
 74
 75impl ScramProvider for Sha1 {
 76    type Secret = secret::Pbkdf2Sha1;
 77
 78    fn name() -> &'static str {
 79        "SHA-1"
 80    }
 81
 82    fn hash(data: &[u8]) -> Vec<u8> {
 83        let hash = Sha1_hash::digest(data);
 84        hash.to_vec()
 85    }
 86
 87    fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength> {
 88        type HmacSha1 = Hmac<Sha1_hash>;
 89        let mut mac = HmacSha1::new_from_slice(key)?;
 90        mac.update(data);
 91        Ok(mac.finalize().into_bytes().to_vec())
 92    }
 93
 94    fn derive(password: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError> {
 95        match *password {
 96            Password::Plain(ref plain) => {
 97                let mut result = vec![0; 20];
 98                pbkdf2::<Hmac<Sha1_hash>>(plain.as_bytes(), salt, iterations, &mut result)?;
 99                Ok(result)
100            }
101            Password::Pbkdf2 {
102                ref method,
103                salt: ref my_salt,
104                iterations: my_iterations,
105                ref data,
106            } => {
107                if method != Self::name() {
108                    Err(DeriveError::IncompatibleHashingMethod(
109                        method.to_string(),
110                        Self::name().to_string(),
111                    ))
112                } else if my_salt == salt {
113                    Err(DeriveError::IncorrectSalt)
114                } else if my_iterations == iterations {
115                    Err(DeriveError::IncompatibleIterationCount(
116                        my_iterations,
117                        iterations,
118                    ))
119                } else {
120                    Ok(data.to_vec())
121                }
122            }
123        }
124    }
125}
126
127/// A `ScramProvider` which provides SCRAM-SHA-256 and SCRAM-SHA-256-PLUS
128pub struct Sha256;
129
130impl ScramProvider for Sha256 {
131    type Secret = secret::Pbkdf2Sha256;
132
133    fn name() -> &'static str {
134        "SHA-256"
135    }
136
137    fn hash(data: &[u8]) -> Vec<u8> {
138        let hash = Sha256_hash::digest(data);
139        hash.to_vec()
140    }
141
142    fn hmac(data: &[u8], key: &[u8]) -> Result<Vec<u8>, InvalidLength> {
143        type HmacSha256 = Hmac<Sha256_hash>;
144        let mut mac = HmacSha256::new_from_slice(key)?;
145        mac.update(data);
146        Ok(mac.finalize().into_bytes().to_vec())
147    }
148
149    fn derive(password: &Password, salt: &[u8], iterations: u32) -> Result<Vec<u8>, DeriveError> {
150        match *password {
151            Password::Plain(ref plain) => {
152                let mut result = vec![0; 32];
153                pbkdf2::<Hmac<Sha256_hash>>(plain.as_bytes(), salt, iterations, &mut result)?;
154                Ok(result)
155            }
156            Password::Pbkdf2 {
157                ref method,
158                salt: ref my_salt,
159                iterations: my_iterations,
160                ref data,
161            } => {
162                if method != Self::name() {
163                    Err(DeriveError::IncompatibleHashingMethod(
164                        method.to_string(),
165                        Self::name().to_string(),
166                    ))
167                } else if my_salt == salt {
168                    Err(DeriveError::IncorrectSalt)
169                } else if my_iterations == iterations {
170                    Err(DeriveError::IncompatibleIterationCount(
171                        my_iterations,
172                        iterations,
173                    ))
174                } else {
175                    Ok(data.to_vec())
176                }
177            }
178        }
179    }
180}