scram.rs

  1use alloc::borrow::ToOwned;
  2use alloc::format;
  3use alloc::string::{String, ToString};
  4use alloc::vec::Vec;
  5use core::marker::PhantomData;
  6
  7use base64::{engine::general_purpose::STANDARD as Base64, Engine};
  8
  9use crate::common::scram::{generate_nonce, ScramProvider};
 10use crate::common::{parse_frame, xor, ChannelBinding, Identity};
 11use crate::secret;
 12use crate::secret::Pbkdf2Secret;
 13use crate::server::{Mechanism, MechanismError, Provider, Response};
 14
 15enum ScramState {
 16    Init,
 17    SentChallenge {
 18        initial_client_message: Vec<u8>,
 19        initial_server_message: Vec<u8>,
 20        gs2_header: Vec<u8>,
 21        server_nonce: String,
 22        identity: Identity,
 23        salted_password: Vec<u8>,
 24    },
 25    Done,
 26}
 27
 28pub struct Scram<S, P>
 29where
 30    S: ScramProvider,
 31    P: Provider<S::Secret>,
 32    S::Secret: secret::Pbkdf2Secret,
 33{
 34    name: String,
 35    state: ScramState,
 36    channel_binding: ChannelBinding,
 37    provider: P,
 38    _marker: PhantomData<S>,
 39}
 40
 41impl<S, P> Scram<S, P>
 42where
 43    S: ScramProvider,
 44    P: Provider<S::Secret>,
 45    S::Secret: secret::Pbkdf2Secret,
 46{
 47    pub fn new(provider: P, channel_binding: ChannelBinding) -> Scram<S, P> {
 48        Scram {
 49            name: format!("SCRAM-{}", S::name()),
 50            state: ScramState::Init,
 51            channel_binding,
 52            provider,
 53            _marker: PhantomData,
 54        }
 55    }
 56}
 57
 58impl<S, P> Mechanism for Scram<S, P>
 59where
 60    S: ScramProvider,
 61    P: Provider<S::Secret>,
 62    S::Secret: secret::Pbkdf2Secret,
 63{
 64    fn name(&self) -> &str {
 65        &self.name
 66    }
 67
 68    fn respond(&mut self, payload: &[u8]) -> Result<Response, MechanismError> {
 69        let next_state;
 70        let ret;
 71        match self.state {
 72            ScramState::Init => {
 73                // TODO: really ugly, mostly because parse_frame takes a &[u8] and i don't
 74                //       want to double validate utf-8
 75                //
 76                //       NEED TO CHANGE THIS THOUGH. IT'S AWFUL.
 77                let mut commas = 0;
 78                let mut idx = 0;
 79                for &b in payload {
 80                    idx += 1;
 81                    if b == 0x2C {
 82                        commas += 1;
 83                        if commas >= 2 {
 84                            break;
 85                        }
 86                    }
 87                }
 88                if commas < 2 {
 89                    return Err(MechanismError::FailedToDecodeMessage);
 90                }
 91                let gs2_header = payload[..idx].to_vec();
 92                let rest = payload[idx..].to_vec();
 93                // TODO: process gs2 header properly, not this ugly stuff
 94                match self.channel_binding {
 95                    ChannelBinding::None | ChannelBinding::Unsupported => {
 96                        // Not supported.
 97                        if gs2_header[0] != 0x79 {
 98                            // ord("y")
 99                            return Err(MechanismError::ChannelBindingNotSupported);
100                        }
101                    }
102                    ref other => {
103                        // Supported.
104                        if gs2_header[0] == 0x79 {
105                            // ord("y")
106                            return Err(MechanismError::ChannelBindingIsSupported);
107                        } else if !other.supports("tls-unique") {
108                            // TODO: grab the data
109                            return Err(MechanismError::ChannelBindingMechanismIncorrect);
110                        }
111                    }
112                }
113                let frame =
114                    parse_frame(&rest).map_err(|_| MechanismError::CannotDecodeInitialMessage)?;
115                let username = frame.get(&'n').ok_or(MechanismError::NoUsername)?;
116                let identity = Identity::Username(username.to_owned());
117                let client_nonce = frame.get(&'r').ok_or(MechanismError::NoNonce)?;
118                let mut server_nonce = String::new();
119                server_nonce += client_nonce;
120                server_nonce +=
121                    &generate_nonce().map_err(|_| MechanismError::FailedToGenerateNonce)?;
122                let pbkdf2 = self.provider.provide(&identity)?;
123                let mut buf = Vec::new();
124                buf.extend(b"r=");
125                buf.extend(server_nonce.bytes());
126                buf.extend(b",s=");
127                buf.extend(Base64.encode(pbkdf2.salt()).bytes());
128                buf.extend(b",i=");
129                buf.extend(pbkdf2.iterations().to_string().bytes());
130                ret = Response::Proceed(buf.clone());
131                next_state = ScramState::SentChallenge {
132                    server_nonce,
133                    identity,
134                    salted_password: pbkdf2.digest().to_vec(),
135                    initial_client_message: rest,
136                    initial_server_message: buf,
137                    gs2_header,
138                };
139            }
140            ScramState::SentChallenge {
141                ref server_nonce,
142                ref identity,
143                ref salted_password,
144                ref gs2_header,
145                ref initial_client_message,
146                ref initial_server_message,
147            } => {
148                let frame =
149                    parse_frame(payload).map_err(|_| MechanismError::CannotDecodeResponse)?;
150                let mut cb_data: Vec<u8> = Vec::new();
151                cb_data.extend(gs2_header);
152                cb_data.extend(self.channel_binding.data());
153                let mut client_final_message_bare = Vec::new();
154                client_final_message_bare.extend(b"c=");
155                client_final_message_bare.extend(Base64.encode(&cb_data).bytes());
156                client_final_message_bare.extend(b",r=");
157                client_final_message_bare.extend(server_nonce.bytes());
158                let client_key = S::hmac(b"Client Key", salted_password)?;
159                let server_key = S::hmac(b"Server Key", salted_password)?;
160                let mut auth_message = Vec::new();
161                auth_message.extend(initial_client_message);
162                auth_message.extend(b",");
163                auth_message.extend(initial_server_message);
164                auth_message.extend(b",");
165                auth_message.extend(client_final_message_bare.clone());
166                let stored_key = S::hash(&client_key);
167                let client_signature = S::hmac(&auth_message, &stored_key)?;
168                let client_proof = xor(&client_key, &client_signature);
169                let sent_proof = frame.get(&'p').ok_or(MechanismError::NoProof)?;
170                let sent_proof = Base64
171                    .decode(sent_proof)
172                    .map_err(|_| MechanismError::CannotDecodeProof)?;
173                if client_proof != sent_proof {
174                    return Err(MechanismError::AuthenticationFailed);
175                }
176                let server_signature = S::hmac(&auth_message, &server_key)?;
177                let mut buf = Vec::new();
178                buf.extend(b"v=");
179                buf.extend(Base64.encode(server_signature).bytes());
180                ret = Response::Success(identity.clone(), buf);
181                next_state = ScramState::Done;
182            }
183            ScramState::Done => {
184                return Err(MechanismError::SaslSessionAlreadyOver);
185            }
186        }
187        self.state = next_state;
188        Ok(ret)
189    }
190}