1use crate::{
2 db::{self, AccessTokenId, Database, UserId},
3 AppState, Error, Result,
4};
5use anyhow::{anyhow, Context};
6use axum::{
7 http::{self, Request, StatusCode},
8 middleware::Next,
9 response::IntoResponse,
10};
11use lazy_static::lazy_static;
12use prometheus::{exponential_buckets, register_histogram, Histogram};
13use rand::thread_rng;
14use scrypt::{
15 password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
16 Scrypt,
17};
18use serde::{Deserialize, Serialize};
19use std::{sync::Arc, time::Instant};
20
21lazy_static! {
22 static ref METRIC_ACCESS_TOKEN_HASHING_TIME: Histogram = register_histogram!(
23 "access_token_hashing_time",
24 "time spent hashing access tokens",
25 exponential_buckets(10.0, 2.0, 10).unwrap(),
26 )
27 .unwrap();
28}
29
30#[derive(Clone, Debug, Default, PartialEq, Eq)]
31pub struct Impersonator(pub Option<db::User>);
32
33/// Validates the authorization header. This has two mechanisms, one for the ADMIN_TOKEN
34/// and one for the access tokens that we issue.
35pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
36 let mut auth_header = req
37 .headers()
38 .get(http::header::AUTHORIZATION)
39 .and_then(|header| header.to_str().ok())
40 .ok_or_else(|| {
41 Error::Http(
42 StatusCode::UNAUTHORIZED,
43 "missing authorization header".to_string(),
44 )
45 })?
46 .split_whitespace();
47
48 let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
49 Error::Http(
50 StatusCode::BAD_REQUEST,
51 "missing user id in authorization header".to_string(),
52 )
53 })?);
54
55 let access_token = auth_header.next().ok_or_else(|| {
56 Error::Http(
57 StatusCode::BAD_REQUEST,
58 "missing access token in authorization header".to_string(),
59 )
60 })?;
61
62 let state = req.extensions().get::<Arc<AppState>>().unwrap();
63
64 // In development, allow impersonation using the admin API token.
65 // Don't allow this in production because we can't tell who is doing
66 // the impersonating.
67 let validate_result = if let (Some(admin_token), true) = (
68 access_token.strip_prefix("ADMIN_TOKEN:"),
69 state.config.is_development(),
70 ) {
71 Ok(VerifyAccessTokenResult {
72 is_valid: state.config.api_token == admin_token,
73 impersonator_id: None,
74 })
75 } else {
76 verify_access_token(&access_token, user_id, &state.db).await
77 };
78
79 if let Ok(validate_result) = validate_result {
80 if validate_result.is_valid {
81 let user = state
82 .db
83 .get_user_by_id(user_id)
84 .await?
85 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
86
87 let impersonator = if let Some(impersonator_id) = validate_result.impersonator_id {
88 let impersonator = state
89 .db
90 .get_user_by_id(impersonator_id)
91 .await?
92 .ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
93 Some(impersonator)
94 } else {
95 None
96 };
97 req.extensions_mut().insert(user);
98 req.extensions_mut().insert(Impersonator(impersonator));
99 return Ok::<_, Error>(next.run(req).await);
100 }
101 }
102
103 Err(Error::Http(
104 StatusCode::UNAUTHORIZED,
105 "invalid credentials".to_string(),
106 ))
107}
108
109const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
110
111#[derive(Serialize, Deserialize)]
112struct AccessTokenJson {
113 version: usize,
114 id: AccessTokenId,
115 token: String,
116}
117
118/// Creates a new access token to identify the given user. before returning it, you should
119/// encrypt it with the user's public key.
120pub async fn create_access_token(
121 db: &db::Database,
122 user_id: UserId,
123 impersonated_user_id: Option<UserId>,
124) -> Result<String> {
125 const VERSION: usize = 1;
126 let access_token = rpc::auth::random_token();
127 let access_token_hash =
128 hash_access_token(&access_token).context("failed to hash access token")?;
129 let id = db
130 .create_access_token(
131 user_id,
132 impersonated_user_id,
133 &access_token_hash,
134 MAX_ACCESS_TOKENS_TO_STORE,
135 )
136 .await?;
137 Ok(serde_json::to_string(&AccessTokenJson {
138 version: VERSION,
139 id,
140 token: access_token,
141 })?)
142}
143
144fn hash_access_token(token: &str) -> Result<String> {
145 // Avoid slow hashing in debug mode.
146 let params = if cfg!(debug_assertions) {
147 scrypt::Params::new(1, 1, 1).unwrap()
148 } else {
149 scrypt::Params::new(14, 8, 1).unwrap()
150 };
151
152 Ok(Scrypt
153 .hash_password(
154 token.as_bytes(),
155 None,
156 params,
157 &SaltString::generate(thread_rng()),
158 )
159 .map_err(anyhow::Error::new)?
160 .to_string())
161}
162
163/// Encrypts the given access token with the given public key to avoid leaking it on the way
164/// to the client.
165pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
166 let native_app_public_key =
167 rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
168 let encrypted_access_token = native_app_public_key
169 .encrypt_string(access_token)
170 .context("failed to encrypt access token with public key")?;
171 Ok(encrypted_access_token)
172}
173
174pub struct VerifyAccessTokenResult {
175 pub is_valid: bool,
176 pub impersonator_id: Option<UserId>,
177}
178
179/// Checks that the given access token is valid for the given user.
180pub async fn verify_access_token(
181 token: &str,
182 user_id: UserId,
183 db: &Arc<Database>,
184) -> Result<VerifyAccessTokenResult> {
185 let token: AccessTokenJson = serde_json::from_str(&token)?;
186
187 let db_token = db.get_access_token(token.id).await?;
188 let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
189 if token_user_id != user_id {
190 return Err(anyhow!("no such access token"))?;
191 }
192
193 let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
194 let t0 = Instant::now();
195 let is_valid = Scrypt
196 .verify_password(token.token.as_bytes(), &db_hash)
197 .is_ok();
198 let duration = t0.elapsed();
199 log::info!("hashed access token in {:?}", duration);
200 METRIC_ACCESS_TOKEN_HASHING_TIME.observe(duration.as_millis() as f64);
201 Ok(VerifyAccessTokenResult {
202 is_valid,
203 impersonator_id: if db_token.impersonated_user_id.is_some() {
204 Some(db_token.user_id)
205 } else {
206 None
207 },
208 })
209}