1use crate::{
2 db::{self, dev_server, AccessTokenId, Database, DevServerId, UserId},
3 rpc::Principal,
4 AppState, Error, Result,
5};
6use anyhow::{anyhow, Context};
7use axum::{
8 http::{self, Request, StatusCode},
9 middleware::Next,
10 response::IntoResponse,
11};
12use prometheus::{exponential_buckets, register_histogram, Histogram};
13use scrypt::{
14 password_hash::{PasswordHash, PasswordVerifier},
15 Scrypt,
16};
17use serde::{Deserialize, Serialize};
18use sha2::Digest;
19use std::sync::OnceLock;
20use std::{sync::Arc, time::Instant};
21use subtle::ConstantTimeEq;
22
23/// Validates the authorization header and adds an Extension<Principal> to the request.
24/// Authorization: <user-id> <token>
25/// <token> can be an access_token attached to that user, or an access token of an admin
26/// or (in development) the string ADMIN:<config.api_token>.
27/// Authorization: "dev-server-token" <token>
28pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
29 let mut auth_header = req
30 .headers()
31 .get(http::header::AUTHORIZATION)
32 .and_then(|header| header.to_str().ok())
33 .ok_or_else(|| {
34 Error::Http(
35 StatusCode::UNAUTHORIZED,
36 "missing authorization header".to_string(),
37 )
38 })?
39 .split_whitespace();
40
41 let state = req.extensions().get::<Arc<AppState>>().unwrap();
42
43 let first = auth_header.next().unwrap_or("");
44 if first == "dev-server-token" {
45 let dev_server_token = auth_header.next().ok_or_else(|| {
46 Error::Http(
47 StatusCode::BAD_REQUEST,
48 "missing dev-server-token token in authorization header".to_string(),
49 )
50 })?;
51 let dev_server = verify_dev_server_token(dev_server_token, &state.db)
52 .await
53 .map_err(|e| Error::Http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
54
55 req.extensions_mut()
56 .insert(Principal::DevServer(dev_server));
57 return Ok::<_, Error>(next.run(req).await);
58 }
59
60 let user_id = UserId(first.parse().map_err(|_| {
61 Error::Http(
62 StatusCode::BAD_REQUEST,
63 "missing user id in authorization header".to_string(),
64 )
65 })?);
66
67 let access_token = auth_header.next().ok_or_else(|| {
68 Error::Http(
69 StatusCode::BAD_REQUEST,
70 "missing access token in authorization header".to_string(),
71 )
72 })?;
73
74 // In development, allow impersonation using the admin API token.
75 // Don't allow this in production because we can't tell who is doing
76 // the impersonating.
77 let validate_result = if let (Some(admin_token), true) = (
78 access_token.strip_prefix("ADMIN_TOKEN:"),
79 state.config.is_development(),
80 ) {
81 Ok(VerifyAccessTokenResult {
82 is_valid: state.config.api_token == admin_token,
83 impersonator_id: None,
84 })
85 } else {
86 verify_access_token(&access_token, user_id, &state.db).await
87 };
88
89 if let Ok(validate_result) = validate_result {
90 if validate_result.is_valid {
91 let user = state
92 .db
93 .get_user_by_id(user_id)
94 .await?
95 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
96
97 if let Some(impersonator_id) = validate_result.impersonator_id {
98 let admin = state
99 .db
100 .get_user_by_id(impersonator_id)
101 .await?
102 .ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
103 req.extensions_mut()
104 .insert(Principal::Impersonated { user, admin });
105 } else {
106 req.extensions_mut().insert(Principal::User(user));
107 };
108 return Ok::<_, Error>(next.run(req).await);
109 }
110 }
111
112 Err(Error::Http(
113 StatusCode::UNAUTHORIZED,
114 "invalid credentials".to_string(),
115 ))
116}
117
/// Limit passed to `Database::create_access_token` whenever a token is stored.
/// Presumably the maximum number of tokens retained per user (TODO confirm in the db layer).
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
119
120#[derive(Serialize, Deserialize)]
121struct AccessTokenJson {
122 version: usize,
123 id: AccessTokenId,
124 token: String,
125}
126
127/// Creates a new access token to identify the given user. before returning it, you should
128/// encrypt it with the user's public key.
129pub async fn create_access_token(
130 db: &db::Database,
131 user_id: UserId,
132 impersonated_user_id: Option<UserId>,
133) -> Result<String> {
134 const VERSION: usize = 1;
135 let access_token = rpc::auth::random_token();
136 let access_token_hash = hash_access_token(&access_token);
137 let id = db
138 .create_access_token(
139 user_id,
140 impersonated_user_id,
141 &access_token_hash,
142 MAX_ACCESS_TOKENS_TO_STORE,
143 )
144 .await?;
145 Ok(serde_json::to_string(&AccessTokenJson {
146 version: VERSION,
147 id,
148 token: access_token,
149 })?)
150}
151
152/// Hashing prevents anyone with access to the database being able to login.
153/// As the token is randomly generated, we don't need to worry about scrypt-style
154/// protection.
155fn hash_access_token(token: &str) -> String {
156 let digest = sha2::Sha256::digest(token);
157 format!(
158 "$sha256${}",
159 base64::encode_config(digest, base64::URL_SAFE)
160 )
161}
162
163/// Encrypts the given access token with the given public key to avoid leaking it on the way
164/// to the client.
165pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
166 let native_app_public_key =
167 rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
168 let encrypted_access_token = native_app_public_key
169 .encrypt_string(access_token)
170 .context("failed to encrypt access token with public key")?;
171 Ok(encrypted_access_token)
172}
173
174pub struct VerifyAccessTokenResult {
175 pub is_valid: bool,
176 pub impersonator_id: Option<UserId>,
177}
178
179/// Checks that the given access token is valid for the given user.
180pub async fn verify_access_token(
181 token: &str,
182 user_id: UserId,
183 db: &Arc<Database>,
184) -> Result<VerifyAccessTokenResult> {
185 static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
186 let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
187 register_histogram!(
188 "access_token_hashing_time",
189 "time spent hashing access tokens",
190 exponential_buckets(10.0, 2.0, 10).unwrap(),
191 )
192 .unwrap()
193 });
194
195 let token: AccessTokenJson = serde_json::from_str(&token)?;
196
197 let db_token = db.get_access_token(token.id).await?;
198 let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
199 if token_user_id != user_id {
200 return Err(anyhow!("no such access token"))?;
201 }
202 let t0 = Instant::now();
203
204 let is_valid = if db_token.hash.starts_with("$scrypt$") {
205 let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
206 Scrypt
207 .verify_password(token.token.as_bytes(), &db_hash)
208 .is_ok()
209 } else {
210 let token_hash = hash_access_token(&token.token);
211 db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
212 };
213
214 let duration = t0.elapsed();
215 log::info!("hashed access token in {:?}", duration);
216 metric_access_token_hashing_time.observe(duration.as_millis() as f64);
217
218 if is_valid && db_token.hash.starts_with("$scrypt$") {
219 let new_hash = hash_access_token(&token.token);
220 db.update_access_token_hash(db_token.id, &new_hash).await?;
221 }
222
223 Ok(VerifyAccessTokenResult {
224 is_valid,
225 impersonator_id: if db_token.impersonated_user_id.is_some() {
226 Some(db_token.user_id)
227 } else {
228 None
229 },
230 })
231}
232
233// a dev_server_token has the format <id>.<base64>. This is to make them
234// relatively easy to copy/paste around.
235pub async fn verify_dev_server_token(
236 dev_server_token: &str,
237 db: &Arc<Database>,
238) -> anyhow::Result<dev_server::Model> {
239 let mut parts = dev_server_token.splitn(2, '.');
240 let id = DevServerId(parts.next().unwrap_or_default().parse()?);
241 let token = parts
242 .next()
243 .ok_or_else(|| anyhow!("invalid dev server token format"))?;
244
245 let token_hash = hash_access_token(&token);
246 let server = db.get_dev_server(id).await?;
247
248 if server
249 .hashed_token
250 .as_bytes()
251 .ct_eq(token_hash.as_ref())
252 .into()
253 {
254 Ok(server)
255 } else {
256 Err(anyhow!("wrong token for dev server"))
257 }
258}
259
#[cfg(test)]
mod test {
    use rand::thread_rng;
    use scrypt::password_hash::{PasswordHasher, SaltString};
    use sea_orm::EntityTrait;

    use super::*;
    use crate::db::{access_token, NewUserParams};

    /// Asserts that verification succeeded for a token that is not an impersonation token.
    fn assert_valid_non_impersonated(result: VerifyAccessTokenResult) {
        assert!(matches!(
            result,
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));
    }

    /// Reads the stored hash of the access token with the given id directly from the table.
    async fn stored_hash(db: &Database, token_id: AccessTokenId) -> String {
        db.transaction(|tx| async move {
            Ok(access_token::Entity::find_by_id(token_id)
                .one(&*tx)
                .await?)
        })
        .await
        .unwrap()
        .unwrap()
        .hash
    }

    #[gpui::test]
    async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
        let test_db = crate::db::TestDb::postgres(cx.executor().clone());
        let db = test_db.db();

        let user = db
            .create_user(
                "example@example.com",
                false,
                NewUserParams {
                    github_login: "example".into(),
                    github_user_id: 1,
                },
            )
            .await
            .unwrap();
        let user_id = user.user_id;

        // A freshly created token verifies.
        let new_token = create_access_token(&db, user_id, None).await.unwrap();
        assert_valid_non_impersonated(verify_access_token(&new_token, user_id, &db).await.unwrap());

        // A token created with the legacy scheme is stored as an scrypt hash...
        let legacy_token = create_previous_access_token(user_id, None, &db)
            .await
            .unwrap();
        let legacy_token_id = serde_json::from_str::<AccessTokenJson>(&legacy_token)
            .unwrap()
            .id;
        assert!(stored_hash(&db, legacy_token_id).await.starts_with("$scrypt$"));

        // ...still verifies, and verifying it rewrites the stored hash to SHA-256.
        assert_valid_non_impersonated(
            verify_access_token(&legacy_token, user_id, &db).await.unwrap(),
        );
        assert!(stored_hash(&db, legacy_token_id).await.starts_with("$sha256$"));

        // Both the migrated token and the new token keep working.
        assert_valid_non_impersonated(
            verify_access_token(&legacy_token, user_id, &db).await.unwrap(),
        );
        assert_valid_non_impersonated(verify_access_token(&new_token, user_id, &db).await.unwrap());
    }

    /// Creates an access token whose stored hash uses the legacy scrypt scheme that
    /// `verify_access_token` still accepts.
    async fn create_previous_access_token(
        user_id: UserId,
        impersonated_user_id: Option<UserId>,
        db: &Database,
    ) -> Result<String> {
        let secret = rpc::auth::random_token();
        let legacy_hash = previous_hash_access_token(&secret)?;
        let id = db
            .create_access_token(
                user_id,
                impersonated_user_id,
                &legacy_hash,
                MAX_ACCESS_TOKENS_TO_STORE,
            )
            .await?;
        let json = serde_json::to_string(&AccessTokenJson {
            version: 1,
            id,
            token: secret,
        })?;
        Ok(json)
    }

    /// Hashes `token` with scrypt and a random salt, producing a `$scrypt$` hash string.
    fn previous_hash_access_token(token: &str) -> Result<String> {
        // Avoid slow hashing in debug mode.
        let params = if cfg!(debug_assertions) {
            scrypt::Params::new(1, 1, 1).unwrap()
        } else {
            scrypt::Params::new(14, 8, 1).unwrap()
        };

        let salt = SaltString::generate(thread_rng());
        let hash = Scrypt
            .hash_password(token.as_bytes(), None, params, &salt)
            .map_err(anyhow::Error::new)?;
        Ok(hash.to_string())
    }
}