1use crate::{
2 AppState, Error, Result,
3 db::{self, AccessTokenId, Database, UserId},
4 rpc::Principal,
5};
6use anyhow::Context as _;
7use axum::{
8 http::{self, Request, StatusCode},
9 middleware::Next,
10 response::IntoResponse,
11};
12use base64::prelude::*;
13use prometheus::{Histogram, exponential_buckets, register_histogram};
14pub use rpc::auth::random_token;
15use scrypt::{
16 Scrypt,
17 password_hash::{PasswordHash, PasswordVerifier},
18};
19use serde::{Deserialize, Serialize};
20use sha2::Digest;
21use std::sync::OnceLock;
22use std::{sync::Arc, time::Instant};
23use subtle::ConstantTimeEq;
24
25/// Validates the authorization header and adds an Extension<Principal> to the request.
26/// Authorization: <user-id> <token>
27/// <token> can be an access_token attached to that user, or an access token of an admin
28/// or (in development) the string ADMIN:<config.api_token>.
29/// Authorization: "dev-server-token" <token>
30pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
31 let mut auth_header = req
32 .headers()
33 .get(http::header::AUTHORIZATION)
34 .and_then(|header| header.to_str().ok())
35 .ok_or_else(|| {
36 Error::http(
37 StatusCode::UNAUTHORIZED,
38 "missing authorization header".to_string(),
39 )
40 })?
41 .split_whitespace();
42
43 let state = req.extensions().get::<Arc<AppState>>().unwrap();
44
45 let first = auth_header.next().unwrap_or("");
46 if first == "dev-server-token" {
47 Err(Error::http(
48 StatusCode::UNAUTHORIZED,
49 "Dev servers were removed in Zed 0.157 please upgrade to SSH remoting".to_string(),
50 ))?;
51 }
52
53 let user_id = UserId(first.parse().map_err(|_| {
54 Error::http(
55 StatusCode::BAD_REQUEST,
56 "missing user id in authorization header".to_string(),
57 )
58 })?);
59
60 let access_token = auth_header.next().ok_or_else(|| {
61 Error::http(
62 StatusCode::BAD_REQUEST,
63 "missing access token in authorization header".to_string(),
64 )
65 })?;
66
67 // In development, allow impersonation using the admin API token.
68 // Don't allow this in production because we can't tell who is doing
69 // the impersonating.
70 let validate_result = if let (Some(admin_token), true) = (
71 access_token.strip_prefix("ADMIN_TOKEN:"),
72 state.config.is_development(),
73 ) {
74 Ok(VerifyAccessTokenResult {
75 is_valid: state.config.api_token == admin_token,
76 impersonator_id: None,
77 })
78 } else {
79 verify_access_token(access_token, user_id, &state.db).await
80 };
81
82 if let Ok(validate_result) = validate_result
83 && validate_result.is_valid
84 {
85 let user = state
86 .db
87 .get_user_by_id(user_id)
88 .await?
89 .with_context(|| format!("user {user_id} not found"))?;
90
91 if let Some(impersonator_id) = validate_result.impersonator_id {
92 let admin = state
93 .db
94 .get_user_by_id(impersonator_id)
95 .await?
96 .with_context(|| format!("user {impersonator_id} not found"))?;
97 req.extensions_mut()
98 .insert(Principal::Impersonated { user, admin });
99 } else {
100 req.extensions_mut().insert(Principal::User(user));
101 };
102 return Ok::<_, Error>(next.run(req).await);
103 }
104
105 Err(Error::http(
106 StatusCode::UNAUTHORIZED,
107 "invalid credentials".to_string(),
108 ))
109}
110
/// Limit passed to `Database::create_access_token` whenever a new token is stored.
/// Presumably this caps how many tokens are kept per user before older ones are
/// pruned — TODO confirm against the db layer.
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
112
113#[derive(Serialize, Deserialize)]
114struct AccessTokenJson {
115 version: usize,
116 id: AccessTokenId,
117 token: String,
118}
119
120/// Creates a new access token to identify the given user. before returning it, you should
121/// encrypt it with the user's public key.
122pub async fn create_access_token(
123 db: &db::Database,
124 user_id: UserId,
125 impersonated_user_id: Option<UserId>,
126) -> Result<String> {
127 const VERSION: usize = 1;
128 let access_token = rpc::auth::random_token();
129 let access_token_hash = hash_access_token(&access_token);
130 let id = db
131 .create_access_token(
132 user_id,
133 impersonated_user_id,
134 &access_token_hash,
135 MAX_ACCESS_TOKENS_TO_STORE,
136 )
137 .await?;
138 Ok(serde_json::to_string(&AccessTokenJson {
139 version: VERSION,
140 id,
141 token: access_token,
142 })?)
143}
144
145/// Hashing prevents anyone with access to the database being able to login.
146/// As the token is randomly generated, we don't need to worry about scrypt-style
147/// protection.
148pub fn hash_access_token(token: &str) -> String {
149 let digest = sha2::Sha256::digest(token);
150 format!("$sha256${}", BASE64_URL_SAFE.encode(digest))
151}
152
153/// Encrypts the given access token with the given public key to avoid leaking it on the way
154/// to the client.
155pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
156 use rpc::auth::EncryptionFormat;
157
158 /// The encryption format to use for the access token.
159 const ENCRYPTION_FORMAT: EncryptionFormat = EncryptionFormat::V1;
160
161 let native_app_public_key =
162 rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
163 let encrypted_access_token = native_app_public_key
164 .encrypt_string(access_token, ENCRYPTION_FORMAT)
165 .context("failed to encrypt access token with public key")?;
166 Ok(encrypted_access_token)
167}
168
169pub struct VerifyAccessTokenResult {
170 pub is_valid: bool,
171 pub impersonator_id: Option<UserId>,
172}
173
174/// Checks that the given access token is valid for the given user.
175pub async fn verify_access_token(
176 token: &str,
177 user_id: UserId,
178 db: &Arc<Database>,
179) -> Result<VerifyAccessTokenResult> {
180 static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
181 let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
182 register_histogram!(
183 "access_token_hashing_time",
184 "time spent hashing access tokens",
185 exponential_buckets(10.0, 2.0, 10).unwrap(),
186 )
187 .unwrap()
188 });
189
190 let token: AccessTokenJson = serde_json::from_str(token)?;
191
192 let db_token = db.get_access_token(token.id).await?;
193 let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
194 if token_user_id != user_id {
195 return Err(anyhow::anyhow!("no such access token"))?;
196 }
197 let t0 = Instant::now();
198
199 let is_valid = if db_token.hash.starts_with("$scrypt$") {
200 let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
201 Scrypt
202 .verify_password(token.token.as_bytes(), &db_hash)
203 .is_ok()
204 } else {
205 let token_hash = hash_access_token(&token.token);
206 db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
207 };
208
209 let duration = t0.elapsed();
210 log::info!("hashed access token in {:?}", duration);
211 metric_access_token_hashing_time.observe(duration.as_millis() as f64);
212
213 if is_valid && db_token.hash.starts_with("$scrypt$") {
214 let new_hash = hash_access_token(&token.token);
215 db.update_access_token_hash(db_token.id, &new_hash).await?;
216 }
217
218 Ok(VerifyAccessTokenResult {
219 is_valid,
220 impersonator_id: if db_token.impersonated_user_id.is_some() {
221 Some(db_token.user_id)
222 } else {
223 None
224 },
225 })
226}
227
/// Tests for access-token verification, including the migration of legacy scrypt-hashed
/// tokens to SHA-256 hashes.
#[cfg(test)]
mod test {
    use rand::prelude::*;
    use scrypt::password_hash::{PasswordHasher, SaltString};
    use sea_orm::EntityTrait;

    use super::*;
    use crate::db::{NewUserParams, access_token};

    /// Verifies that new tokens validate, and that a legacy scrypt-hashed token validates,
    /// gets rehashed as `$sha256$`, and keeps validating after the rehash.
    #[gpui::test]
    async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
        let test_db = crate::db::TestDb::sqlite(cx.executor());
        let db = test_db.db();

        let user = db
            .create_user(
                "example@example.com",
                None,
                false,
                NewUserParams {
                    github_login: "example".into(),
                    github_user_id: 1,
                },
            )
            .await
            .unwrap();

        // A token created by the current code path (SHA-256 hash) verifies.
        let token = create_access_token(db, user.user_id, None).await.unwrap();
        assert!(matches!(
            verify_access_token(&token, user.user_id, db).await.unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));

        // Create a token the old way, with a stored scrypt hash.
        let old_token = create_previous_access_token(user.user_id, None, db)
            .await
            .unwrap();

        let old_token_id = serde_json::from_str::<AccessTokenJson>(&old_token)
            .unwrap()
            .id;

        // Before verification, the stored hash is still in scrypt format.
        let hash = db
            .transaction(|tx| async move {
                Ok(access_token::Entity::find_by_id(old_token_id)
                    .one(&*tx)
                    .await?)
            })
            .await
            .unwrap()
            .unwrap()
            .hash;
        assert!(hash.starts_with("$scrypt$"));

        // The legacy token verifies...
        assert!(matches!(
            verify_access_token(&old_token, user.user_id, db)
                .await
                .unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));

        // ...and a successful verification rewrote its stored hash as SHA-256.
        let hash = db
            .transaction(|tx| async move {
                Ok(access_token::Entity::find_by_id(old_token_id)
                    .one(&*tx)
                    .await?)
            })
            .await
            .unwrap()
            .unwrap()
            .hash;
        assert!(hash.starts_with("$sha256$"));

        // The migrated token still verifies against its new hash.
        assert!(matches!(
            verify_access_token(&old_token, user.user_id, db)
                .await
                .unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));

        // The migration did not disturb the other token.
        assert!(matches!(
            verify_access_token(&token, user.user_id, db).await.unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));
    }

    /// Creates and stores an access token whose hash uses the legacy scrypt scheme, and
    /// returns it serialized like `create_access_token` does.
    async fn create_previous_access_token(
        user_id: UserId,
        impersonated_user_id: Option<UserId>,
        db: &Database,
    ) -> Result<String> {
        let access_token = rpc::auth::random_token();
        let access_token_hash = previous_hash_access_token(&access_token)?;
        let id = db
            .create_access_token(
                user_id,
                impersonated_user_id,
                &access_token_hash,
                MAX_ACCESS_TOKENS_TO_STORE,
            )
            .await?;
        Ok(serde_json::to_string(&AccessTokenJson {
            version: 1,
            id,
            token: access_token,
        })?)
    }

    /// Hashes `token` with scrypt and a random salt, returning the PHC-format string
    /// (the test above checks that it starts with `$scrypt$`).
    fn previous_hash_access_token(token: &str) -> Result<String> {
        // Avoid slow hashing in debug mode.
        let params = if cfg!(debug_assertions) {
            scrypt::Params::new(1, 1, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
        } else {
            scrypt::Params::new(14, 8, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
        };

        Ok(Scrypt
            .hash_password_customized(
                token.as_bytes(),
                None,
                None,
                params,
                &SaltString::generate(PasswordHashRngCompat::new()),
            )
            .map_err(anyhow::Error::new)?
            .to_string())
    }

    // Adapter exposing `rand`'s thread RNG through the `rand_core` traits re-exported by
    // `scrypt::password_hash`, which `SaltString::generate` requires.
    // TODO: remove once password_hash v0.6 is released.
    struct PasswordHashRngCompat(rand::rngs::ThreadRng);

    impl PasswordHashRngCompat {
        fn new() -> Self {
            Self(rand::rng())
        }
    }

    impl scrypt::password_hash::rand_core::RngCore for PasswordHashRngCompat {
        fn next_u32(&mut self) -> u32 {
            self.0.next_u32()
        }

        fn next_u64(&mut self) -> u64 {
            self.0.next_u64()
        }

        fn fill_bytes(&mut self, dest: &mut [u8]) {
            self.0.fill_bytes(dest);
        }

        // The wrapped RNG is filled via the infallible `fill_bytes`, so this always succeeds.
        fn try_fill_bytes(
            &mut self,
            dest: &mut [u8],
        ) -> Result<(), scrypt::password_hash::rand_core::Error> {
            self.fill_bytes(dest);
            Ok(())
        }
    }

    // Marker impl: the wrapped thread RNG is treated as cryptographically secure.
    impl scrypt::password_hash::rand_core::CryptoRng for PasswordHashRngCompat {}
}