@@ -1,5 +1,5 @@
use crate::{
- db::{self, UserId},
+ db::{self, AccessTokenId, Database, UserId},
AppState, Error, Result,
};
use anyhow::{anyhow, Context};
@@ -8,12 +8,24 @@ use axum::{
middleware::Next,
response::IntoResponse,
};
+use lazy_static::lazy_static;
+use prometheus::{exponential_buckets, register_histogram, Histogram};
use rand::thread_rng;
use scrypt::{
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Scrypt,
};
-use std::sync::Arc;
+use serde::{Deserialize, Serialize};
+use std::{sync::Arc, time::Instant};
+
+lazy_static! {
+ static ref METRIC_ACCESS_TOKEN_HASHING_TIME: Histogram = register_histogram!(
+ "access_token_hashing_time",
+ "time spent hashing access tokens",
+ exponential_buckets(10.0, 2.0, 10).unwrap(),
+ )
+ .unwrap();
+}
pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
let mut auth_header = req
@@ -42,20 +54,14 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
)
})?;
- let mut credentials_valid = false;
let state = req.extensions().get::<Arc<AppState>>().unwrap();
- if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
- if state.config.api_token == admin_token {
- credentials_valid = true;
- }
+ let credentials_valid = if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
+ state.config.api_token == admin_token
} else {
- for password_hash in state.db.get_access_token_hashes(user_id).await? {
- if verify_access_token(access_token, &password_hash)? {
- credentials_valid = true;
- break;
- }
- }
- }
+ verify_access_token(&access_token, user_id, &state.db)
+ .await
+ .unwrap_or(false)
+ };
if credentials_valid {
let user = state
@@ -75,13 +81,26 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
+#[derive(Serialize, Deserialize)]
+struct AccessTokenJson {
+ version: usize,
+ id: AccessTokenId,
+ token: String,
+}
+
pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
+ const VERSION: usize = 1;
let access_token = rpc::auth::random_token();
let access_token_hash =
hash_access_token(&access_token).context("failed to hash access token")?;
- db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
+ let id = db
+ .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
.await?;
- Ok(access_token)
+ Ok(serde_json::to_string(&AccessTokenJson {
+ version: VERSION,
+ id,
+ token: access_token,
+ })?)
}
fn hash_access_token(token: &str) -> Result<String> {
@@ -89,7 +108,7 @@ fn hash_access_token(token: &str) -> Result<String> {
let params = if cfg!(debug_assertions) {
scrypt::Params::new(1, 1, 1).unwrap()
} else {
- scrypt::Params::recommended()
+ scrypt::Params::new(14, 8, 1).unwrap()
};
Ok(Scrypt
@@ -112,7 +131,21 @@ pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<St
Ok(encrypted_access_token)
}
-pub fn verify_access_token(token: &str, hash: &str) -> Result<bool> {
- let hash = PasswordHash::new(hash).map_err(anyhow::Error::new)?;
- Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
+pub async fn verify_access_token(token: &str, user_id: UserId, db: &Arc<Database>) -> Result<bool> {
+ let token: AccessTokenJson = serde_json::from_str(&token)?;
+
+ let db_token = db.get_access_token(token.id).await?;
+ if db_token.user_id != user_id {
+ return Err(anyhow!("no such access token"))?;
+ }
+
+ let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
+ let t0 = Instant::now();
+ let is_valid = Scrypt
+ .verify_password(token.token.as_bytes(), &db_hash)
+ .is_ok();
+ let duration = t0.elapsed();
+ log::info!("hashed access token in {:?}", duration);
+ METRIC_ACCESS_TOKEN_HASHING_TIME.observe(duration.as_millis() as f64);
+ Ok(is_valid)
}
@@ -2746,16 +2746,16 @@ impl Database {
// access tokens
- pub async fn create_access_token_hash(
+ pub async fn create_access_token(
&self,
user_id: UserId,
access_token_hash: &str,
max_access_token_count: usize,
- ) -> Result<()> {
+ ) -> Result<AccessTokenId> {
self.transaction(|tx| async {
let tx = tx;
- access_token::ActiveModel {
+ let token = access_token::ActiveModel {
user_id: ActiveValue::set(user_id),
hash: ActiveValue::set(access_token_hash.into()),
..Default::default()
@@ -2778,26 +2778,20 @@ impl Database {
)
.exec(&*tx)
.await?;
- Ok(())
+ Ok(token.id)
})
.await
}
- pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
- #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
- enum QueryAs {
- Hash,
- }
-
+ pub async fn get_access_token(
+ &self,
+ access_token_id: AccessTokenId,
+ ) -> Result<access_token::Model> {
self.transaction(|tx| async move {
- Ok(access_token::Entity::find()
- .select_only()
- .column(access_token::Column::Hash)
- .filter(access_token::Column::UserId.eq(user_id))
- .order_by_desc(access_token::Column::Id)
- .into_values::<_, QueryAs>()
- .all(&*tx)
- .await?)
+ Ok(access_token::Entity::find_by_id(access_token_id)
+ .one(&*tx)
+ .await?
+ .ok_or_else(|| anyhow!("no such access token"))?)
})
.await
}
@@ -177,30 +177,63 @@ test_both_dbs!(
.unwrap()
.user_id;
- db.create_access_token_hash(user, "h1", 3).await.unwrap();
- db.create_access_token_hash(user, "h2", 3).await.unwrap();
+ let token_1 = db.create_access_token(user, "h1", 2).await.unwrap();
+ let token_2 = db.create_access_token(user, "h2", 2).await.unwrap();
assert_eq!(
- db.get_access_token_hashes(user).await.unwrap(),
- &["h2".to_string(), "h1".to_string()]
+ db.get_access_token(token_1).await.unwrap(),
+ access_token::Model {
+ id: token_1,
+ user_id: user,
+ hash: "h1".into(),
+ }
);
-
- db.create_access_token_hash(user, "h3", 3).await.unwrap();
assert_eq!(
- db.get_access_token_hashes(user).await.unwrap(),
- &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
+ db.get_access_token(token_2).await.unwrap(),
+ access_token::Model {
+ id: token_2,
+ user_id: user,
+ hash: "h2".into()
+ }
);
- db.create_access_token_hash(user, "h4", 3).await.unwrap();
+ let token_3 = db.create_access_token(user, "h3", 2).await.unwrap();
assert_eq!(
- db.get_access_token_hashes(user).await.unwrap(),
- &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
+ db.get_access_token(token_3).await.unwrap(),
+ access_token::Model {
+ id: token_3,
+ user_id: user,
+ hash: "h3".into()
+ }
);
+ assert_eq!(
+ db.get_access_token(token_2).await.unwrap(),
+ access_token::Model {
+ id: token_2,
+ user_id: user,
+ hash: "h2".into()
+ }
+ );
+ assert!(db.get_access_token(token_1).await.is_err());
- db.create_access_token_hash(user, "h5", 3).await.unwrap();
+ let token_4 = db.create_access_token(user, "h4", 2).await.unwrap();
assert_eq!(
- db.get_access_token_hashes(user).await.unwrap(),
- &["h5".to_string(), "h4".to_string(), "h3".to_string()]
+ db.get_access_token(token_4).await.unwrap(),
+ access_token::Model {
+ id: token_4,
+ user_id: user,
+ hash: "h4".into()
+ }
+ );
+ assert_eq!(
+ db.get_access_token(token_3).await.unwrap(),
+ access_token::Model {
+ id: token_3,
+ user_id: user,
+ hash: "h3".into()
+ }
);
+ assert!(db.get_access_token(token_2).await.is_err());
+ assert!(db.get_access_token(token_1).await.is_err());
}
);