1use crate::{
2 db::{self, AccessTokenId, Database, UserId},
3 rpc::Principal,
4 AppState, Error, Result,
5};
6use anyhow::{anyhow, Context};
7use axum::{
8 http::{self, Request, StatusCode},
9 middleware::Next,
10 response::IntoResponse,
11};
12use base64::prelude::*;
13use prometheus::{exponential_buckets, register_histogram, Histogram};
14pub use rpc::auth::random_token;
15use scrypt::{
16 password_hash::{PasswordHash, PasswordVerifier},
17 Scrypt,
18};
19use serde::{Deserialize, Serialize};
20use sha2::Digest;
21use std::sync::OnceLock;
22use std::{sync::Arc, time::Instant};
23use subtle::ConstantTimeEq;
24
25/// Validates the authorization header and adds an Extension<Principal> to the request.
26/// Authorization: <user-id> <token>
27/// <token> can be an access_token attached to that user, or an access token of an admin
28/// or (in development) the string ADMIN:<config.api_token>.
29/// Authorization: "dev-server-token" <token>
30pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
31 let mut auth_header = req
32 .headers()
33 .get(http::header::AUTHORIZATION)
34 .and_then(|header| header.to_str().ok())
35 .ok_or_else(|| {
36 Error::http(
37 StatusCode::UNAUTHORIZED,
38 "missing authorization header".to_string(),
39 )
40 })?
41 .split_whitespace();
42
43 let state = req.extensions().get::<Arc<AppState>>().unwrap();
44
45 let first = auth_header.next().unwrap_or("");
46 if first == "dev-server-token" {
47 Err(Error::http(
48 StatusCode::UNAUTHORIZED,
49 "Dev servers were removed in Zed 0.157 please upgrade to SSH remoting".to_string(),
50 ))?;
51 }
52
53 let user_id = UserId(first.parse().map_err(|_| {
54 Error::http(
55 StatusCode::BAD_REQUEST,
56 "missing user id in authorization header".to_string(),
57 )
58 })?);
59
60 let access_token = auth_header.next().ok_or_else(|| {
61 Error::http(
62 StatusCode::BAD_REQUEST,
63 "missing access token in authorization header".to_string(),
64 )
65 })?;
66
67 // In development, allow impersonation using the admin API token.
68 // Don't allow this in production because we can't tell who is doing
69 // the impersonating.
70 let validate_result = if let (Some(admin_token), true) = (
71 access_token.strip_prefix("ADMIN_TOKEN:"),
72 state.config.is_development(),
73 ) {
74 Ok(VerifyAccessTokenResult {
75 is_valid: state.config.api_token == admin_token,
76 impersonator_id: None,
77 })
78 } else {
79 verify_access_token(access_token, user_id, &state.db).await
80 };
81
82 if let Ok(validate_result) = validate_result {
83 if validate_result.is_valid {
84 let user = state
85 .db
86 .get_user_by_id(user_id)
87 .await?
88 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
89
90 if let Some(impersonator_id) = validate_result.impersonator_id {
91 let admin = state
92 .db
93 .get_user_by_id(impersonator_id)
94 .await?
95 .ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
96 req.extensions_mut()
97 .insert(Principal::Impersonated { user, admin });
98 } else {
99 req.extensions_mut().insert(Principal::User(user));
100 };
101 return Ok::<_, Error>(next.run(req).await);
102 }
103 }
104
105 Err(Error::http(
106 StatusCode::UNAUTHORIZED,
107 "invalid credentials".to_string(),
108 ))
109}
110
/// Limit handed to `Database::create_access_token` whenever a token is created.
///
/// NOTE(review): presumably the number of most recent tokens kept per user —
/// confirm against the database layer.
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
112
113#[derive(Serialize, Deserialize)]
114struct AccessTokenJson {
115 version: usize,
116 id: AccessTokenId,
117 token: String,
118}
119
120/// Creates a new access token to identify the given user. before returning it, you should
121/// encrypt it with the user's public key.
122pub async fn create_access_token(
123 db: &db::Database,
124 user_id: UserId,
125 impersonated_user_id: Option<UserId>,
126) -> Result<String> {
127 const VERSION: usize = 1;
128 let access_token = rpc::auth::random_token();
129 let access_token_hash = hash_access_token(&access_token);
130 let id = db
131 .create_access_token(
132 user_id,
133 impersonated_user_id,
134 &access_token_hash,
135 MAX_ACCESS_TOKENS_TO_STORE,
136 )
137 .await?;
138 Ok(serde_json::to_string(&AccessTokenJson {
139 version: VERSION,
140 id,
141 token: access_token,
142 })?)
143}
144
145/// Hashing prevents anyone with access to the database being able to login.
146/// As the token is randomly generated, we don't need to worry about scrypt-style
147/// protection.
148pub fn hash_access_token(token: &str) -> String {
149 let digest = sha2::Sha256::digest(token);
150 format!("$sha256${}", BASE64_URL_SAFE.encode(digest))
151}
152
153/// Encrypts the given access token with the given public key to avoid leaking it on the way
154/// to the client.
155pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
156 use rpc::auth::EncryptionFormat;
157
158 /// The encryption format to use for the access token.
159 ///
160 /// Currently we're using the original encryption format to avoid
161 /// breaking compatibility with older clients.
162 ///
163 /// Once enough clients are capable of decrypting the newer encryption
164 /// format we can start encrypting with `EncryptionFormat::V1`.
165 const ENCRYPTION_FORMAT: EncryptionFormat = EncryptionFormat::V0;
166
167 let native_app_public_key =
168 rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
169 let encrypted_access_token = native_app_public_key
170 .encrypt_string(access_token, ENCRYPTION_FORMAT)
171 .context("failed to encrypt access token with public key")?;
172 Ok(encrypted_access_token)
173}
174
175pub struct VerifyAccessTokenResult {
176 pub is_valid: bool,
177 pub impersonator_id: Option<UserId>,
178}
179
180/// Checks that the given access token is valid for the given user.
181pub async fn verify_access_token(
182 token: &str,
183 user_id: UserId,
184 db: &Arc<Database>,
185) -> Result<VerifyAccessTokenResult> {
186 static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
187 let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
188 register_histogram!(
189 "access_token_hashing_time",
190 "time spent hashing access tokens",
191 exponential_buckets(10.0, 2.0, 10).unwrap(),
192 )
193 .unwrap()
194 });
195
196 let token: AccessTokenJson = serde_json::from_str(token)?;
197
198 let db_token = db.get_access_token(token.id).await?;
199 let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
200 if token_user_id != user_id {
201 return Err(anyhow!("no such access token"))?;
202 }
203 let t0 = Instant::now();
204
205 let is_valid = if db_token.hash.starts_with("$scrypt$") {
206 let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
207 Scrypt
208 .verify_password(token.token.as_bytes(), &db_hash)
209 .is_ok()
210 } else {
211 let token_hash = hash_access_token(&token.token);
212 db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
213 };
214
215 let duration = t0.elapsed();
216 log::info!("hashed access token in {:?}", duration);
217 metric_access_token_hashing_time.observe(duration.as_millis() as f64);
218
219 if is_valid && db_token.hash.starts_with("$scrypt$") {
220 let new_hash = hash_access_token(&token.token);
221 db.update_access_token_hash(db_token.id, &new_hash).await?;
222 }
223
224 Ok(VerifyAccessTokenResult {
225 is_valid,
226 impersonator_id: if db_token.impersonated_user_id.is_some() {
227 Some(db_token.user_id)
228 } else {
229 None
230 },
231 })
232}
233
#[cfg(test)]
mod test {
    use rand::thread_rng;
    use scrypt::password_hash::{PasswordHasher, SaltString};
    use sea_orm::EntityTrait;

    use super::*;
    use crate::db::{access_token, NewUserParams};

    /// Exercises `verify_access_token` with a current-format token and with a
    /// scrypt-hashed token, checking that the latter is rewritten to the
    /// `$sha256$` format on first successful verification and keeps working.
    #[gpui::test]
    async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
        let test_db = crate::db::TestDb::sqlite(cx.executor().clone());
        let db = test_db.db();

        let user_params = NewUserParams {
            github_login: "example".into(),
            github_user_id: 1,
        };
        let user = db
            .create_user("example@example.com", false, user_params)
            .await
            .unwrap();
        let user_id = user.user_id;

        // A token in the current format verifies straight away.
        let current_token = create_access_token(db, user_id, None).await.unwrap();
        assert_valid_without_impersonator(&current_token, user_id, db).await;

        // A token stored with the previous (scrypt) hash format.
        let legacy_token = create_previous_access_token(user_id, None, db)
            .await
            .unwrap();
        let legacy_token_id = serde_json::from_str::<AccessTokenJson>(&legacy_token)
            .unwrap()
            .id;

        let hash_before = db
            .transaction(|tx| async move {
                Ok(access_token::Entity::find_by_id(legacy_token_id)
                    .one(&*tx)
                    .await?)
            })
            .await
            .unwrap()
            .unwrap()
            .hash;
        assert!(hash_before.starts_with("$scrypt$"));

        // The first successful verification rewrites the stored hash.
        assert_valid_without_impersonator(&legacy_token, user_id, db).await;

        let hash_after = db
            .transaction(|tx| async move {
                Ok(access_token::Entity::find_by_id(legacy_token_id)
                    .one(&*tx)
                    .await?)
            })
            .await
            .unwrap()
            .unwrap()
            .hash;
        assert!(hash_after.starts_with("$sha256$"));

        // Both tokens keep verifying after the rewrite.
        assert_valid_without_impersonator(&legacy_token, user_id, db).await;
        assert_valid_without_impersonator(&current_token, user_id, db).await;
    }

    /// Verifies `token` for `user_id` and asserts that it is accepted and
    /// carries no impersonator.
    async fn assert_valid_without_impersonator(token: &str, user_id: UserId, db: &Arc<Database>) {
        let outcome = verify_access_token(token, user_id, db).await.unwrap();
        assert!(outcome.is_valid);
        assert!(outcome.impersonator_id.is_none());
    }

    /// Same as `create_access_token`, but stores the hash produced by
    /// `previous_hash_access_token` (scrypt) instead of the current format.
    async fn create_previous_access_token(
        user_id: UserId,
        impersonated_user_id: Option<UserId>,
        db: &Database,
    ) -> Result<String> {
        let secret = rpc::auth::random_token();
        let stored_hash = previous_hash_access_token(&secret)?;
        let id = db
            .create_access_token(
                user_id,
                impersonated_user_id,
                &stored_hash,
                MAX_ACCESS_TOKENS_TO_STORE,
            )
            .await?;

        let envelope = AccessTokenJson {
            version: 1,
            id,
            token: secret,
        };
        Ok(serde_json::to_string(&envelope)?)
    }

    /// Hashes `token` with scrypt and a freshly generated salt, returning the
    /// encoded password-hash string.
    fn previous_hash_access_token(token: &str) -> Result<String> {
        // Avoid slow hashing in debug mode.
        let (log_n, r) = if cfg!(debug_assertions) {
            (1, 1)
        } else {
            (14, 8)
        };
        let params = scrypt::Params::new(log_n, r, 1, scrypt::Params::RECOMMENDED_LEN).unwrap();

        let salt = SaltString::generate(thread_rng());
        let hashed = Scrypt
            .hash_password_customized(token.as_bytes(), None, None, params, &salt)
            .map_err(anyhow::Error::new)?;
        Ok(hashed.to_string())
    }
}