1use crate::{
2 db::{self, UserId},
3 AppState, Error, Result,
4};
5use anyhow::{anyhow, Context};
6use axum::{
7 http::{self, Request, StatusCode},
8 middleware::Next,
9 response::IntoResponse,
10};
11use rand::thread_rng;
12use scrypt::{
13 password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
14 Scrypt,
15};
16use std::sync::Arc;
17
18pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
19 let mut auth_header = req
20 .headers()
21 .get(http::header::AUTHORIZATION)
22 .and_then(|header| header.to_str().ok())
23 .ok_or_else(|| {
24 Error::Http(
25 StatusCode::UNAUTHORIZED,
26 "missing authorization header".to_string(),
27 )
28 })?
29 .split_whitespace();
30
31 let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
32 Error::Http(
33 StatusCode::BAD_REQUEST,
34 "missing user id in authorization header".to_string(),
35 )
36 })?);
37
38 let access_token = auth_header.next().ok_or_else(|| {
39 Error::Http(
40 StatusCode::BAD_REQUEST,
41 "missing access token in authorization header".to_string(),
42 )
43 })?;
44
45 let mut credentials_valid = false;
46 let state = req.extensions().get::<Arc<AppState>>().unwrap();
47 if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
48 if state.config.api_token == admin_token {
49 credentials_valid = true;
50 }
51 } else {
52 for password_hash in state.db.get_access_token_hashes(user_id).await? {
53 if verify_access_token(access_token, &password_hash)? {
54 credentials_valid = true;
55 break;
56 }
57 }
58 }
59
60 if credentials_valid {
61 let user = state
62 .db
63 .get_user_by_id(user_id)
64 .await?
65 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
66 req.extensions_mut().insert(user);
67 Ok::<_, Error>(next.run(req).await)
68 } else {
69 Err(Error::Http(
70 StatusCode::UNAUTHORIZED,
71 "invalid credentials".to_string(),
72 ))
73 }
74}
75
76const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
77
78pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
79 let access_token = rpc::auth::random_token();
80 let access_token_hash =
81 hash_access_token(&access_token).context("failed to hash access token")?;
82 db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
83 .await?;
84 Ok(access_token)
85}
86
87fn hash_access_token(token: &str) -> Result<String> {
88 // Avoid slow hashing in debug mode.
89 let params = if cfg!(debug_assertions) {
90 scrypt::Params::new(1, 1, 1).unwrap()
91 } else {
92 scrypt::Params::recommended()
93 };
94
95 Ok(Scrypt
96 .hash_password(
97 token.as_bytes(),
98 None,
99 params,
100 &SaltString::generate(thread_rng()),
101 )
102 .map_err(anyhow::Error::new)?
103 .to_string())
104}
105
106pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
107 let native_app_public_key =
108 rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
109 let encrypted_access_token = native_app_public_key
110 .encrypt_string(access_token)
111 .context("failed to encrypt access token with public key")?;
112 Ok(encrypted_access_token)
113}
114
115pub fn verify_access_token(token: &str, hash: &str) -> Result<bool> {
116 let hash = PasswordHash::new(hash).map_err(anyhow::Error::new)?;
117 Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
118}