auth.rs

  1use crate::{
  2    db::{self, AccessTokenId, Database, UserId},
  3    AppState, Error, Result,
  4};
  5use anyhow::{anyhow, Context};
  6use axum::{
  7    http::{self, Request, StatusCode},
  8    middleware::Next,
  9    response::IntoResponse,
 10};
 11use prometheus::{exponential_buckets, register_histogram, Histogram};
 12use scrypt::{
 13    password_hash::{PasswordHash, PasswordVerifier},
 14    Scrypt,
 15};
 16use serde::{Deserialize, Serialize};
 17use sha2::Digest;
 18use std::sync::OnceLock;
 19use std::{sync::Arc, time::Instant};
 20use subtle::ConstantTimeEq;
 21
 22#[derive(Clone, Debug, Default, PartialEq, Eq)]
 23pub struct Impersonator(pub Option<db::User>);
 24
 25/// Validates the authorization header. This has two mechanisms, one for the ADMIN_TOKEN
 26/// and one for the access tokens that we issue.
 27pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
 28    let mut auth_header = req
 29        .headers()
 30        .get(http::header::AUTHORIZATION)
 31        .and_then(|header| header.to_str().ok())
 32        .ok_or_else(|| {
 33            Error::Http(
 34                StatusCode::UNAUTHORIZED,
 35                "missing authorization header".to_string(),
 36            )
 37        })?
 38        .split_whitespace();
 39
 40    let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
 41        Error::Http(
 42            StatusCode::BAD_REQUEST,
 43            "missing user id in authorization header".to_string(),
 44        )
 45    })?);
 46
 47    let access_token = auth_header.next().ok_or_else(|| {
 48        Error::Http(
 49            StatusCode::BAD_REQUEST,
 50            "missing access token in authorization header".to_string(),
 51        )
 52    })?;
 53
 54    let state = req.extensions().get::<Arc<AppState>>().unwrap();
 55
 56    // In development, allow impersonation using the admin API token.
 57    // Don't allow this in production because we can't tell who is doing
 58    // the impersonating.
 59    let validate_result = if let (Some(admin_token), true) = (
 60        access_token.strip_prefix("ADMIN_TOKEN:"),
 61        state.config.is_development(),
 62    ) {
 63        Ok(VerifyAccessTokenResult {
 64            is_valid: state.config.api_token == admin_token,
 65            impersonator_id: None,
 66        })
 67    } else {
 68        verify_access_token(&access_token, user_id, &state.db).await
 69    };
 70
 71    if let Ok(validate_result) = validate_result {
 72        if validate_result.is_valid {
 73            let user = state
 74                .db
 75                .get_user_by_id(user_id)
 76                .await?
 77                .ok_or_else(|| anyhow!("user {} not found", user_id))?;
 78
 79            let impersonator = if let Some(impersonator_id) = validate_result.impersonator_id {
 80                let impersonator = state
 81                    .db
 82                    .get_user_by_id(impersonator_id)
 83                    .await?
 84                    .ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
 85                Some(impersonator)
 86            } else {
 87                None
 88            };
 89            req.extensions_mut().insert(user);
 90            req.extensions_mut().insert(Impersonator(impersonator));
 91            return Ok::<_, Error>(next.run(req).await);
 92        }
 93    }
 94
 95    Err(Error::Http(
 96        StatusCode::UNAUTHORIZED,
 97        "invalid credentials".to_string(),
 98    ))
 99}
100
/// Storage limit handed to `Database::create_access_token` whenever a token is issued.
/// Presumably the number of access tokens retained per user — confirm in `db`.
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
102
103#[derive(Serialize, Deserialize)]
104struct AccessTokenJson {
105    version: usize,
106    id: AccessTokenId,
107    token: String,
108}
109
110/// Creates a new access token to identify the given user. before returning it, you should
111/// encrypt it with the user's public key.
112pub async fn create_access_token(
113    db: &db::Database,
114    user_id: UserId,
115    impersonated_user_id: Option<UserId>,
116) -> Result<String> {
117    const VERSION: usize = 1;
118    let access_token = rpc::auth::random_token();
119    let access_token_hash = hash_access_token(&access_token);
120    let id = db
121        .create_access_token(
122            user_id,
123            impersonated_user_id,
124            &access_token_hash,
125            MAX_ACCESS_TOKENS_TO_STORE,
126        )
127        .await?;
128    Ok(serde_json::to_string(&AccessTokenJson {
129        version: VERSION,
130        id,
131        token: access_token,
132    })?)
133}
134
135/// Hashing prevents anyone with access to the database being able to login.
136/// As the token is randomly generated, we don't need to worry about scrypt-style
137/// protection.
138fn hash_access_token(token: &str) -> String {
139    let digest = sha2::Sha256::digest(token);
140    format!(
141        "$sha256${}",
142        base64::encode_config(digest, base64::URL_SAFE)
143    )
144}
145
146/// Encrypts the given access token with the given public key to avoid leaking it on the way
147/// to the client.
148pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
149    let native_app_public_key =
150        rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
151    let encrypted_access_token = native_app_public_key
152        .encrypt_string(access_token)
153        .context("failed to encrypt access token with public key")?;
154    Ok(encrypted_access_token)
155}
156
157pub struct VerifyAccessTokenResult {
158    pub is_valid: bool,
159    pub impersonator_id: Option<UserId>,
160}
161
162/// Checks that the given access token is valid for the given user.
163pub async fn verify_access_token(
164    token: &str,
165    user_id: UserId,
166    db: &Arc<Database>,
167) -> Result<VerifyAccessTokenResult> {
168    static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
169    let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
170        register_histogram!(
171            "access_token_hashing_time",
172            "time spent hashing access tokens",
173            exponential_buckets(10.0, 2.0, 10).unwrap(),
174        )
175        .unwrap()
176    });
177
178    let token: AccessTokenJson = serde_json::from_str(&token)?;
179
180    let db_token = db.get_access_token(token.id).await?;
181    let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
182    if token_user_id != user_id {
183        return Err(anyhow!("no such access token"))?;
184    }
185    let t0 = Instant::now();
186
187    let is_valid = if db_token.hash.starts_with("$scrypt$") {
188        let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
189        Scrypt
190            .verify_password(token.token.as_bytes(), &db_hash)
191            .is_ok()
192    } else {
193        let token_hash = hash_access_token(&token.token);
194        db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
195    };
196
197    let duration = t0.elapsed();
198    log::info!("hashed access token in {:?}", duration);
199    metric_access_token_hashing_time.observe(duration.as_millis() as f64);
200
201    if is_valid && db_token.hash.starts_with("$scrypt$") {
202        let new_hash = hash_access_token(&token.token);
203        db.update_access_token_hash(db_token.id, &new_hash).await?;
204    }
205
206    Ok(VerifyAccessTokenResult {
207        is_valid,
208        impersonator_id: if db_token.impersonated_user_id.is_some() {
209            Some(db_token.user_id)
210        } else {
211            None
212        },
213    })
214}
215
/// Tests for access-token issuance, verification, and legacy scrypt-to-SHA-256 migration.
#[cfg(test)]
mod test {
    use rand::thread_rng;
    use scrypt::password_hash::{PasswordHasher, SaltString};
    use sea_orm::EntityTrait;

    use super::*;
    use crate::db::{access_token, NewUserParams};

    /// Exercises both stored-hash formats end to end:
    /// 1. A freshly issued (`$sha256$`) token verifies.
    /// 2. A legacy `$scrypt$` token verifies and has its stored hash rewritten to
    ///    `$sha256$`.
    /// 3. Both tokens still verify after that rewrite.
    ///
    /// The statement order matters: the stored hash is inspected before and after the
    /// first verification of the legacy token.
    #[gpui::test]
    async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
        let test_db = crate::db::TestDb::postgres(cx.executor().clone());
        let db = test_db.db();

        let user = db
            .create_user(
                "example@example.com",
                false,
                NewUserParams {
                    github_login: "example".into(),
                    github_user_id: 1,
                },
            )
            .await
            .unwrap();

        // A token issued by the current code path verifies for its owner.
        let token = create_access_token(&db, user.user_id, None).await.unwrap();
        assert!(matches!(
            verify_access_token(&token, user.user_id, &db)
                .await
                .unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));

        // Store a token hashed with scrypt, as tokens were before the switch to SHA-256.
        let old_token = create_previous_access_token(user.user_id, None, &db)
            .await
            .unwrap();

        let old_token_id = serde_json::from_str::<AccessTokenJson>(&old_token)
            .unwrap()
            .id;

        // Before verification, the stored hash is in the legacy scrypt format.
        let hash = db
            .transaction(|tx| async move {
                Ok(access_token::Entity::find_by_id(old_token_id)
                    .one(&*tx)
                    .await?)
            })
            .await
            .unwrap()
            .unwrap()
            .hash;
        assert!(hash.starts_with("$scrypt$"));

        // Verifying the legacy token succeeds...
        assert!(matches!(
            verify_access_token(&old_token, user.user_id, &db)
                .await
                .unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));

        // ...and, as a side effect, rewrites its stored hash to the SHA-256 format.
        let hash = db
            .transaction(|tx| async move {
                Ok(access_token::Entity::find_by_id(old_token_id)
                    .one(&*tx)
                    .await?)
            })
            .await
            .unwrap()
            .unwrap()
            .hash;
        assert!(hash.starts_with("$sha256$"));

        // The migrated token still verifies against its new hash.
        assert!(matches!(
            verify_access_token(&old_token, user.user_id, &db)
                .await
                .unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));

        // The freshly issued token is still valid as well.
        assert!(matches!(
            verify_access_token(&token, user.user_id, &db)
                .await
                .unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));
    }

    /// Mimics the previous, scrypt-based token issuance: stores an scrypt hash of a
    /// random secret and returns the JSON-serialized token (version 1).
    async fn create_previous_access_token(
        user_id: UserId,
        impersonated_user_id: Option<UserId>,
        db: &Database,
    ) -> Result<String> {
        let access_token = rpc::auth::random_token();
        let access_token_hash = previous_hash_access_token(&access_token)?;
        let id = db
            .create_access_token(
                user_id,
                impersonated_user_id,
                &access_token_hash,
                MAX_ACCESS_TOKENS_TO_STORE,
            )
            .await?;
        Ok(serde_json::to_string(&AccessTokenJson {
            version: 1,
            id,
            token: access_token,
        })?)
    }

    /// Hashes `token` with scrypt and a freshly generated salt, returning the PHC-style
    /// string (the test above checks it begins with `$scrypt$`).
    fn previous_hash_access_token(token: &str) -> Result<String> {
        // Avoid slow hashing in debug mode.
        let params = if cfg!(debug_assertions) {
            scrypt::Params::new(1, 1, 1).unwrap()
        } else {
            scrypt::Params::new(14, 8, 1).unwrap()
        };

        Ok(Scrypt
            .hash_password(
                token.as_bytes(),
                None,
                params,
                &SaltString::generate(thread_rng()),
            )
            .map_err(anyhow::Error::new)?
            .to_string())
    }
}