1use crate::{
2 db::{self, AccessTokenId, Database, UserId},
3 AppState, Error, Result,
4};
5use anyhow::{anyhow, Context};
6use axum::{
7 http::{self, Request, StatusCode},
8 middleware::Next,
9 response::IntoResponse,
10};
11use prometheus::{exponential_buckets, register_histogram, Histogram};
12use scrypt::{
13 password_hash::{PasswordHash, PasswordVerifier},
14 Scrypt,
15};
16use serde::{Deserialize, Serialize};
17use sha2::Digest;
18use std::sync::OnceLock;
19use std::{sync::Arc, time::Instant};
20use subtle::ConstantTimeEq;
21
22#[derive(Clone, Debug, Default, PartialEq, Eq)]
23pub struct Impersonator(pub Option<db::User>);
24
25/// Validates the authorization header. This has two mechanisms, one for the ADMIN_TOKEN
26/// and one for the access tokens that we issue.
27pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
28 let mut auth_header = req
29 .headers()
30 .get(http::header::AUTHORIZATION)
31 .and_then(|header| header.to_str().ok())
32 .ok_or_else(|| {
33 Error::Http(
34 StatusCode::UNAUTHORIZED,
35 "missing authorization header".to_string(),
36 )
37 })?
38 .split_whitespace();
39
40 let user_id = UserId(auth_header.next().unwrap_or("").parse().map_err(|_| {
41 Error::Http(
42 StatusCode::BAD_REQUEST,
43 "missing user id in authorization header".to_string(),
44 )
45 })?);
46
47 let access_token = auth_header.next().ok_or_else(|| {
48 Error::Http(
49 StatusCode::BAD_REQUEST,
50 "missing access token in authorization header".to_string(),
51 )
52 })?;
53
54 let state = req.extensions().get::<Arc<AppState>>().unwrap();
55
56 // In development, allow impersonation using the admin API token.
57 // Don't allow this in production because we can't tell who is doing
58 // the impersonating.
59 let validate_result = if let (Some(admin_token), true) = (
60 access_token.strip_prefix("ADMIN_TOKEN:"),
61 state.config.is_development(),
62 ) {
63 Ok(VerifyAccessTokenResult {
64 is_valid: state.config.api_token == admin_token,
65 impersonator_id: None,
66 })
67 } else {
68 verify_access_token(&access_token, user_id, &state.db).await
69 };
70
71 if let Ok(validate_result) = validate_result {
72 if validate_result.is_valid {
73 let user = state
74 .db
75 .get_user_by_id(user_id)
76 .await?
77 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
78
79 let impersonator = if let Some(impersonator_id) = validate_result.impersonator_id {
80 let impersonator = state
81 .db
82 .get_user_by_id(impersonator_id)
83 .await?
84 .ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
85 Some(impersonator)
86 } else {
87 None
88 };
89 req.extensions_mut().insert(user);
90 req.extensions_mut().insert(Impersonator(impersonator));
91 return Ok::<_, Error>(next.run(req).await);
92 }
93 }
94
95 Err(Error::Http(
96 StatusCode::UNAUTHORIZED,
97 "invalid credentials".to_string(),
98 ))
99}
100
/// Limit handed to `Database::create_access_token` whenever a token is issued.
///
/// NOTE(review): presumably the database keeps at most this many tokens per
/// user and discards older ones; confirm against the db layer.
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
102
103#[derive(Serialize, Deserialize)]
104struct AccessTokenJson {
105 version: usize,
106 id: AccessTokenId,
107 token: String,
108}
109
110/// Creates a new access token to identify the given user. before returning it, you should
111/// encrypt it with the user's public key.
112pub async fn create_access_token(
113 db: &db::Database,
114 user_id: UserId,
115 impersonated_user_id: Option<UserId>,
116) -> Result<String> {
117 const VERSION: usize = 1;
118 let access_token = rpc::auth::random_token();
119 let access_token_hash = hash_access_token(&access_token);
120 let id = db
121 .create_access_token(
122 user_id,
123 impersonated_user_id,
124 &access_token_hash,
125 MAX_ACCESS_TOKENS_TO_STORE,
126 )
127 .await?;
128 Ok(serde_json::to_string(&AccessTokenJson {
129 version: VERSION,
130 id,
131 token: access_token,
132 })?)
133}
134
135/// Hashing prevents anyone with access to the database being able to login.
136/// As the token is randomly generated, we don't need to worry about scrypt-style
137/// protection.
138fn hash_access_token(token: &str) -> String {
139 let digest = sha2::Sha256::digest(token);
140 format!(
141 "$sha256${}",
142 base64::encode_config(digest, base64::URL_SAFE)
143 )
144}
145
146/// Encrypts the given access token with the given public key to avoid leaking it on the way
147/// to the client.
148pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
149 let native_app_public_key =
150 rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
151 let encrypted_access_token = native_app_public_key
152 .encrypt_string(access_token)
153 .context("failed to encrypt access token with public key")?;
154 Ok(encrypted_access_token)
155}
156
157pub struct VerifyAccessTokenResult {
158 pub is_valid: bool,
159 pub impersonator_id: Option<UserId>,
160}
161
162/// Checks that the given access token is valid for the given user.
163pub async fn verify_access_token(
164 token: &str,
165 user_id: UserId,
166 db: &Arc<Database>,
167) -> Result<VerifyAccessTokenResult> {
168 static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
169 let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
170 register_histogram!(
171 "access_token_hashing_time",
172 "time spent hashing access tokens",
173 exponential_buckets(10.0, 2.0, 10).unwrap(),
174 )
175 .unwrap()
176 });
177
178 let token: AccessTokenJson = serde_json::from_str(&token)?;
179
180 let db_token = db.get_access_token(token.id).await?;
181 let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
182 if token_user_id != user_id {
183 return Err(anyhow!("no such access token"))?;
184 }
185 let t0 = Instant::now();
186
187 let is_valid = if db_token.hash.starts_with("$scrypt$") {
188 let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
189 Scrypt
190 .verify_password(token.token.as_bytes(), &db_hash)
191 .is_ok()
192 } else {
193 let token_hash = hash_access_token(&token.token);
194 db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
195 };
196
197 let duration = t0.elapsed();
198 log::info!("hashed access token in {:?}", duration);
199 metric_access_token_hashing_time.observe(duration.as_millis() as f64);
200
201 if is_valid && db_token.hash.starts_with("$scrypt$") {
202 let new_hash = hash_access_token(&token.token);
203 db.update_access_token_hash(db_token.id, &new_hash).await?;
204 }
205
206 Ok(VerifyAccessTokenResult {
207 is_valid,
208 impersonator_id: if db_token.impersonated_user_id.is_some() {
209 Some(db_token.user_id)
210 } else {
211 None
212 },
213 })
214}
215
#[cfg(test)]
mod test {
    use rand::thread_rng;
    use scrypt::password_hash::{PasswordHasher, SaltString};
    use sea_orm::EntityTrait;

    use super::*;
    use crate::db::{access_token, NewUserParams};

    /// Exercises `verify_access_token` against both a current (`$sha256$`) token
    /// and a legacy (`$scrypt$`) token, and checks that verifying the legacy
    /// token rewrites its stored hash in the current format.
    #[gpui::test]
    async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
        let test_db = crate::db::TestDb::postgres(cx.executor().clone());
        let db = test_db.db();

        let new_user = NewUserParams {
            github_login: "example".into(),
            github_user_id: 1,
        };
        let user = db
            .create_user("example@example.com", false, new_user)
            .await
            .unwrap();
        let user_id = user.user_id;

        // A freshly issued token verifies.
        let token = create_access_token(&db, user_id, None).await.unwrap();
        let verified = verify_access_token(&token, user_id, &db).await.unwrap();
        assert!(verified.is_valid);
        assert!(verified.impersonator_id.is_none());

        // A token stored in the previous format starts out as scrypt...
        let old_token = create_previous_access_token(user_id, None, &db)
            .await
            .unwrap();
        let old_token_id = serde_json::from_str::<AccessTokenJson>(&old_token)
            .unwrap()
            .id;
        let hash = stored_hash(&db, old_token_id).await;
        assert!(hash.starts_with("$scrypt$"));

        // ...verifies successfully...
        let verified = verify_access_token(&old_token, user_id, &db)
            .await
            .unwrap();
        assert!(verified.is_valid);
        assert!(verified.impersonator_id.is_none());

        // ...and has been upgraded in the database as a side effect.
        let hash = stored_hash(&db, old_token_id).await;
        assert!(hash.starts_with("$sha256$"));

        // Both tokens keep verifying after the upgrade.
        let verified = verify_access_token(&old_token, user_id, &db)
            .await
            .unwrap();
        assert!(verified.is_valid);
        assert!(verified.impersonator_id.is_none());

        let verified = verify_access_token(&token, user_id, &db).await.unwrap();
        assert!(verified.is_valid);
        assert!(verified.impersonator_id.is_none());
    }

    /// Reads the hash currently stored for the access token `token_id`,
    /// panicking if the lookup fails or the record does not exist.
    async fn stored_hash(db: &Database, token_id: AccessTokenId) -> String {
        db.transaction(|tx| async move {
            Ok(access_token::Entity::find_by_id(token_id)
                .one(&*tx)
                .await?)
        })
        .await
        .unwrap()
        .unwrap()
        .hash
    }

    /// Like `create_access_token`, but stores the hash in the previous (scrypt)
    /// format so the upgrade path can be tested.
    async fn create_previous_access_token(
        user_id: UserId,
        impersonated_user_id: Option<UserId>,
        db: &Database,
    ) -> Result<String> {
        let token = rpc::auth::random_token();
        let token_hash = previous_hash_access_token(&token)?;
        let id = db
            .create_access_token(
                user_id,
                impersonated_user_id,
                &token_hash,
                MAX_ACCESS_TOKENS_TO_STORE,
            )
            .await?;

        let payload = AccessTokenJson {
            version: 1,
            id,
            token,
        };
        Ok(serde_json::to_string(&payload)?)
    }

    /// Hashes `token` with scrypt, the way tokens were stored previously.
    fn previous_hash_access_token(token: &str) -> Result<String> {
        // Avoid slow hashing in debug mode.
        let (log_n, r, p) = if cfg!(debug_assertions) {
            (1, 1, 1)
        } else {
            (14, 8, 1)
        };
        let params = scrypt::Params::new(log_n, r, p).unwrap();

        let salt = SaltString::generate(thread_rng());
        let hash = Scrypt
            .hash_password(token.as_bytes(), None, params, &salt)
            .map_err(anyhow::Error::new)?;
        Ok(hash.to_string())
    }
}