1use crate::{
2 db::{self, AccessTokenId, UserId},
3 AppState, Error, Result,
4};
5use anyhow::{anyhow, Context};
6use axum::{
7 http::{self, Request, StatusCode},
8 middleware::Next,
9 response::IntoResponse,
10};
11use rand::thread_rng;
12use scrypt::{
13 password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
14 Scrypt,
15};
16use serde::{Deserialize, Serialize};
17use std::sync::Arc;
18
19pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
20 let mut auth_header = req
21 .headers()
22 .get(http::header::AUTHORIZATION)
23 .and_then(|header| header.to_str().ok())
24 .ok_or_else(|| {
25 Error::Http(
26 StatusCode::UNAUTHORIZED,
27 "missing authorization header".to_string(),
28 )
29 })?
30 .split_whitespace();
31
32 let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
33 Error::Http(
34 StatusCode::BAD_REQUEST,
35 "missing user id in authorization header".to_string(),
36 )
37 })?);
38
39 let access_token = auth_header.next().ok_or_else(|| {
40 Error::Http(
41 StatusCode::BAD_REQUEST,
42 "missing access token in authorization header".to_string(),
43 )
44 })?;
45
46 let state = req.extensions().get::<Arc<AppState>>().unwrap();
47 let credentials_valid = if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
48 state.config.api_token == admin_token
49 } else {
50 let access_token: AccessTokenJson = serde_json::from_str(&access_token)?;
51
52 let token = state.db.get_access_token(access_token.id).await?;
53 if token.user_id != user_id {
54 return Err(anyhow!("no such access token"))?;
55 }
56
57 verify_access_token(&access_token.token, &token.hash)?
58 };
59
60 if credentials_valid {
61 let user = state
62 .db
63 .get_user_by_id(user_id)
64 .await?
65 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
66 req.extensions_mut().insert(user);
67 Ok::<_, Error>(next.run(req).await)
68 } else {
69 Err(Error::Http(
70 StatusCode::UNAUTHORIZED,
71 "invalid credentials".to_string(),
72 ))
73 }
74}
75
/// Limit passed to `db.create_access_token` whenever a new token is issued;
/// presumably caps how many token hashes are retained per user (enforced in
/// the db layer, not here).
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
77
78#[derive(Serialize, Deserialize)]
79struct AccessTokenJson {
80 version: usize,
81 id: AccessTokenId,
82 token: String,
83}
84
85pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result<String> {
86 const VERSION: usize = 1;
87 let access_token = rpc::auth::random_token();
88 let access_token_hash =
89 hash_access_token(&access_token).context("failed to hash access token")?;
90 let id = db
91 .create_access_token(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
92 .await?;
93 Ok(serde_json::to_string(&AccessTokenJson {
94 version: VERSION,
95 id,
96 token: access_token,
97 })?)
98}
99
100fn hash_access_token(token: &str) -> Result<String> {
101 // Avoid slow hashing in debug mode.
102 let params = if cfg!(debug_assertions) {
103 scrypt::Params::new(1, 1, 1).unwrap()
104 } else {
105 scrypt::Params::recommended()
106 };
107
108 Ok(Scrypt
109 .hash_password(
110 token.as_bytes(),
111 None,
112 params,
113 &SaltString::generate(thread_rng()),
114 )
115 .map_err(anyhow::Error::new)?
116 .to_string())
117}
118
119pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
120 let native_app_public_key =
121 rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
122 let encrypted_access_token = native_app_public_key
123 .encrypt_string(access_token)
124 .context("failed to encrypt access token with public key")?;
125 Ok(encrypted_access_token)
126}
127
128pub fn verify_access_token(token: &str, hash: &str) -> Result<bool> {
129 let hash = PasswordHash::new(hash).map_err(anyhow::Error::new)?;
130 Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
131}