1use crate::{
2 db::{self, dev_server, AccessTokenId, Database, DevServerId, UserId},
3 rpc::Principal,
4 AppState, Error, Result,
5};
6use anyhow::{anyhow, Context};
7use axum::{
8 http::{self, Request, StatusCode},
9 middleware::Next,
10 response::IntoResponse,
11};
12use base64::prelude::*;
13use prometheus::{exponential_buckets, register_histogram, Histogram};
14pub use rpc::auth::random_token;
15use scrypt::{
16 password_hash::{PasswordHash, PasswordVerifier},
17 Scrypt,
18};
19use serde::{Deserialize, Serialize};
20use sha2::Digest;
21use std::sync::OnceLock;
22use std::{sync::Arc, time::Instant};
23use subtle::ConstantTimeEq;
24
25/// Validates the authorization header and adds an Extension<Principal> to the request.
26/// Authorization: <user-id> <token>
27/// <token> can be an access_token attached to that user, or an access token of an admin
28/// or (in development) the string ADMIN:<config.api_token>.
29/// Authorization: "dev-server-token" <token>
30pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
31 let mut auth_header = req
32 .headers()
33 .get(http::header::AUTHORIZATION)
34 .and_then(|header| header.to_str().ok())
35 .ok_or_else(|| {
36 Error::Http(
37 StatusCode::UNAUTHORIZED,
38 "missing authorization header".to_string(),
39 )
40 })?
41 .split_whitespace();
42
43 let state = req.extensions().get::<Arc<AppState>>().unwrap();
44
45 let first = auth_header.next().unwrap_or("");
46 if first == "dev-server-token" {
47 let dev_server_token = auth_header.next().ok_or_else(|| {
48 Error::Http(
49 StatusCode::BAD_REQUEST,
50 "missing dev-server-token token in authorization header".to_string(),
51 )
52 })?;
53 let dev_server = verify_dev_server_token(dev_server_token, &state.db)
54 .await
55 .map_err(|e| Error::Http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
56
57 req.extensions_mut()
58 .insert(Principal::DevServer(dev_server));
59 return Ok::<_, Error>(next.run(req).await);
60 }
61
62 let user_id = UserId(first.parse().map_err(|_| {
63 Error::Http(
64 StatusCode::BAD_REQUEST,
65 "missing user id in authorization header".to_string(),
66 )
67 })?);
68
69 let access_token = auth_header.next().ok_or_else(|| {
70 Error::Http(
71 StatusCode::BAD_REQUEST,
72 "missing access token in authorization header".to_string(),
73 )
74 })?;
75
76 // In development, allow impersonation using the admin API token.
77 // Don't allow this in production because we can't tell who is doing
78 // the impersonating.
79 let validate_result = if let (Some(admin_token), true) = (
80 access_token.strip_prefix("ADMIN_TOKEN:"),
81 state.config.is_development(),
82 ) {
83 Ok(VerifyAccessTokenResult {
84 is_valid: state.config.api_token == admin_token,
85 impersonator_id: None,
86 })
87 } else {
88 verify_access_token(&access_token, user_id, &state.db).await
89 };
90
91 if let Ok(validate_result) = validate_result {
92 if validate_result.is_valid {
93 let user = state
94 .db
95 .get_user_by_id(user_id)
96 .await?
97 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
98
99 if let Some(impersonator_id) = validate_result.impersonator_id {
100 let admin = state
101 .db
102 .get_user_by_id(impersonator_id)
103 .await?
104 .ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
105 req.extensions_mut()
106 .insert(Principal::Impersonated { user, admin });
107 } else {
108 req.extensions_mut().insert(Principal::User(user));
109 };
110 return Ok::<_, Error>(next.run(req).await);
111 }
112 }
113
114 Err(Error::Http(
115 StatusCode::UNAUTHORIZED,
116 "invalid credentials".to_string(),
117 ))
118}
119
/// Limit passed to `Database::create_access_token` whenever a new token is stored.
/// NOTE(review): presumably the number of access tokens retained per user, with older
/// ones pruned by the db layer — confirm against `create_access_token` in `db`.
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
121
/// JSON payload handed to clients as their access token.
///
/// Serialized by [`create_access_token`] and parsed back by [`verify_access_token`].
/// Field names and order form the serialized wire format, so keep them stable.
#[derive(Serialize, Deserialize)]
struct AccessTokenJson {
    /// Payload format version; [`create_access_token`] writes `1`. It is not currently
    /// checked by [`verify_access_token`].
    version: usize,
    /// Database id of the stored token row, used to look up its hash.
    id: AccessTokenId,
    /// The secret itself; only its hash is written to the database.
    token: String,
}
128
129/// Creates a new access token to identify the given user. before returning it, you should
130/// encrypt it with the user's public key.
131pub async fn create_access_token(
132 db: &db::Database,
133 user_id: UserId,
134 impersonated_user_id: Option<UserId>,
135) -> Result<String> {
136 const VERSION: usize = 1;
137 let access_token = rpc::auth::random_token();
138 let access_token_hash = hash_access_token(&access_token);
139 let id = db
140 .create_access_token(
141 user_id,
142 impersonated_user_id,
143 &access_token_hash,
144 MAX_ACCESS_TOKENS_TO_STORE,
145 )
146 .await?;
147 Ok(serde_json::to_string(&AccessTokenJson {
148 version: VERSION,
149 id,
150 token: access_token,
151 })?)
152}
153
154/// Hashing prevents anyone with access to the database being able to login.
155/// As the token is randomly generated, we don't need to worry about scrypt-style
156/// protection.
157pub fn hash_access_token(token: &str) -> String {
158 let digest = sha2::Sha256::digest(token);
159 format!("$sha256${}", BASE64_URL_SAFE.encode(digest))
160}
161
162/// Encrypts the given access token with the given public key to avoid leaking it on the way
163/// to the client.
164pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
165 use rpc::auth::EncryptionFormat;
166
167 /// The encryption format to use for the access token.
168 ///
169 /// Currently we're using the original encryption format to avoid
170 /// breaking compatibility with older clients.
171 ///
172 /// Once enough clients are capable of decrypting the newer encryption
173 /// format we can start encrypting with `EncryptionFormat::V1`.
174 const ENCRYPTION_FORMAT: EncryptionFormat = EncryptionFormat::V0;
175
176 let native_app_public_key =
177 rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
178 let encrypted_access_token = native_app_public_key
179 .encrypt_string(access_token, ENCRYPTION_FORMAT)
180 .context("failed to encrypt access token with public key")?;
181 Ok(encrypted_access_token)
182}
183
/// Outcome of checking an access token that belongs to the claimed user.
pub struct VerifyAccessTokenResult {
    /// Whether the presented secret matched the stored hash.
    pub is_valid: bool,
    /// For an impersonation token, the id of the token's owning user (the impersonator).
    /// `None` for ordinary tokens.
    pub impersonator_id: Option<UserId>,
}
188
189/// Checks that the given access token is valid for the given user.
190pub async fn verify_access_token(
191 token: &str,
192 user_id: UserId,
193 db: &Arc<Database>,
194) -> Result<VerifyAccessTokenResult> {
195 static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
196 let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
197 register_histogram!(
198 "access_token_hashing_time",
199 "time spent hashing access tokens",
200 exponential_buckets(10.0, 2.0, 10).unwrap(),
201 )
202 .unwrap()
203 });
204
205 let token: AccessTokenJson = serde_json::from_str(&token)?;
206
207 let db_token = db.get_access_token(token.id).await?;
208 let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
209 if token_user_id != user_id {
210 return Err(anyhow!("no such access token"))?;
211 }
212 let t0 = Instant::now();
213
214 let is_valid = if db_token.hash.starts_with("$scrypt$") {
215 let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
216 Scrypt
217 .verify_password(token.token.as_bytes(), &db_hash)
218 .is_ok()
219 } else {
220 let token_hash = hash_access_token(&token.token);
221 db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
222 };
223
224 let duration = t0.elapsed();
225 log::info!("hashed access token in {:?}", duration);
226 metric_access_token_hashing_time.observe(duration.as_millis() as f64);
227
228 if is_valid && db_token.hash.starts_with("$scrypt$") {
229 let new_hash = hash_access_token(&token.token);
230 db.update_access_token_hash(db_token.id, &new_hash).await?;
231 }
232
233 Ok(VerifyAccessTokenResult {
234 is_valid,
235 impersonator_id: if db_token.impersonated_user_id.is_some() {
236 Some(db_token.user_id)
237 } else {
238 None
239 },
240 })
241}
242
/// Builds a dev server token in the `<id>.<access_token>` format that
/// [`split_dev_server_token`] parses.
pub fn generate_dev_server_token(id: usize, access_token: String) -> String {
    let mut token = id.to_string();
    token.push('.');
    token.push_str(&access_token);
    token
}
246
247pub async fn verify_dev_server_token(
248 dev_server_token: &str,
249 db: &Arc<Database>,
250) -> anyhow::Result<dev_server::Model> {
251 let (id, token) = split_dev_server_token(dev_server_token)?;
252 let token_hash = hash_access_token(&token);
253 let server = db.get_dev_server(id).await?;
254
255 if server
256 .hashed_token
257 .as_bytes()
258 .ct_eq(token_hash.as_ref())
259 .into()
260 {
261 Ok(server)
262 } else {
263 Err(anyhow!("wrong token for dev server"))
264 }
265}
266
267// a dev_server_token has the format <id>.<base64>. This is to make them
268// relatively easy to copy/paste around.
269pub fn split_dev_server_token(dev_server_token: &str) -> anyhow::Result<(DevServerId, &str)> {
270 let mut parts = dev_server_token.splitn(2, '.');
271 let id = DevServerId(parts.next().unwrap_or_default().parse()?);
272 let token = parts
273 .next()
274 .ok_or_else(|| anyhow!("invalid dev server token format"))?;
275 Ok((id, token))
276}
277
#[cfg(test)]
mod test {
    use rand::thread_rng;
    use scrypt::password_hash::{PasswordHasher, SaltString};
    use sea_orm::EntityTrait;

    use super::*;
    use crate::db::{access_token, NewUserParams};

    #[gpui::test]
    async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
        let test_db = crate::db::TestDb::sqlite(cx.executor().clone());
        let db = test_db.db();

        let user_id = db
            .create_user(
                "example@example.com",
                false,
                NewUserParams {
                    github_login: "example".into(),
                    github_user_id: 1,
                },
            )
            .await
            .unwrap()
            .user_id;

        // A token minted with the current sha256 scheme verifies.
        let token = create_access_token(&db, user_id, None).await.unwrap();
        assert_token_verifies(&token, user_id, &db).await;

        // A token stored with the legacy scrypt scheme...
        let old_token = create_previous_access_token(user_id, None, &db)
            .await
            .unwrap();
        let old_token_id = serde_json::from_str::<AccessTokenJson>(&old_token)
            .unwrap()
            .id;
        assert!(stored_token_hash(&db, old_token_id)
            .await
            .starts_with("$scrypt$"));

        // ...still verifies, and verifying it rewrites the stored hash as sha256.
        assert_token_verifies(&old_token, user_id, &db).await;
        assert!(stored_token_hash(&db, old_token_id)
            .await
            .starts_with("$sha256$"));

        // Both tokens keep verifying after the migration.
        assert_token_verifies(&old_token, user_id, &db).await;
        assert_token_verifies(&token, user_id, &db).await;
    }

    /// Asserts that `token` verifies for `user_id` as a valid, non-impersonated token.
    async fn assert_token_verifies(token: &str, user_id: UserId, db: &Arc<Database>) {
        let result = verify_access_token(token, user_id, db).await.unwrap();
        assert!(matches!(
            result,
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));
    }

    /// Reads the hash currently stored for the access token with id `token_id`.
    async fn stored_token_hash(db: &Database, token_id: AccessTokenId) -> String {
        db.transaction(|tx| async move {
            Ok(access_token::Entity::find_by_id(token_id)
                .one(&*tx)
                .await?)
        })
        .await
        .unwrap()
        .unwrap()
        .hash
    }

    /// Stores a token hashed with the legacy scrypt scheme and returns its JSON payload.
    async fn create_previous_access_token(
        user_id: UserId,
        impersonated_user_id: Option<UserId>,
        db: &Database,
    ) -> Result<String> {
        let secret = rpc::auth::random_token();
        let secret_hash = previous_hash_access_token(&secret)?;
        let token_id = db
            .create_access_token(
                user_id,
                impersonated_user_id,
                &secret_hash,
                MAX_ACCESS_TOKENS_TO_STORE,
            )
            .await?;
        let payload = AccessTokenJson {
            version: 1,
            id: token_id,
            token: secret,
        };
        let serialized = serde_json::to_string(&payload)?;
        Ok(serialized)
    }

    /// Hashes `token` with scrypt, the scheme used before `hash_access_token` switched
    /// to sha256.
    fn previous_hash_access_token(token: &str) -> Result<String> {
        // Debug builds use trivially cheap scrypt parameters so the test stays fast.
        let (log_n, r) = if cfg!(debug_assertions) { (1, 1) } else { (14, 8) };
        let params = scrypt::Params::new(log_n, r, 1, scrypt::Params::RECOMMENDED_LEN).unwrap();
        let salt = SaltString::generate(thread_rng());

        let hash = Scrypt
            .hash_password_customized(token.as_bytes(), None, None, params, &salt)
            .map_err(anyhow::Error::new)?;
        Ok(hash.to_string())
    }
}