1use crate::{
2 db::{self, AccessTokenId, Database, UserId},
3 rpc::Principal,
4 AppState, Error, Result,
5};
6use anyhow::{anyhow, Context as _};
7use axum::{
8 http::{self, Request, StatusCode},
9 middleware::Next,
10 response::IntoResponse,
11};
12use base64::prelude::*;
13use prometheus::{exponential_buckets, register_histogram, Histogram};
14pub use rpc::auth::random_token;
15use scrypt::{
16 password_hash::{PasswordHash, PasswordVerifier},
17 Scrypt,
18};
19use serde::{Deserialize, Serialize};
20use sha2::Digest;
21use std::sync::OnceLock;
22use std::{sync::Arc, time::Instant};
23use subtle::ConstantTimeEq;
24
25/// Validates the authorization header and adds an Extension<Principal> to the request.
26/// Authorization: <user-id> <token>
27/// <token> can be an access_token attached to that user, or an access token of an admin
28/// or (in development) the string ADMIN:<config.api_token>.
29/// Authorization: "dev-server-token" <token>
30pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
31 let mut auth_header = req
32 .headers()
33 .get(http::header::AUTHORIZATION)
34 .and_then(|header| header.to_str().ok())
35 .ok_or_else(|| {
36 Error::http(
37 StatusCode::UNAUTHORIZED,
38 "missing authorization header".to_string(),
39 )
40 })?
41 .split_whitespace();
42
43 let state = req.extensions().get::<Arc<AppState>>().unwrap();
44
45 let first = auth_header.next().unwrap_or("");
46 if first == "dev-server-token" {
47 Err(Error::http(
48 StatusCode::UNAUTHORIZED,
49 "Dev servers were removed in Zed 0.157 please upgrade to SSH remoting".to_string(),
50 ))?;
51 }
52
53 let user_id = UserId(first.parse().map_err(|_| {
54 Error::http(
55 StatusCode::BAD_REQUEST,
56 "missing user id in authorization header".to_string(),
57 )
58 })?);
59
60 let access_token = auth_header.next().ok_or_else(|| {
61 Error::http(
62 StatusCode::BAD_REQUEST,
63 "missing access token in authorization header".to_string(),
64 )
65 })?;
66
67 // In development, allow impersonation using the admin API token.
68 // Don't allow this in production because we can't tell who is doing
69 // the impersonating.
70 let validate_result = if let (Some(admin_token), true) = (
71 access_token.strip_prefix("ADMIN_TOKEN:"),
72 state.config.is_development(),
73 ) {
74 Ok(VerifyAccessTokenResult {
75 is_valid: state.config.api_token == admin_token,
76 impersonator_id: None,
77 })
78 } else {
79 verify_access_token(access_token, user_id, &state.db).await
80 };
81
82 if let Ok(validate_result) = validate_result {
83 if validate_result.is_valid {
84 let user = state
85 .db
86 .get_user_by_id(user_id)
87 .await?
88 .ok_or_else(|| anyhow!("user {} not found", user_id))?;
89
90 if let Some(impersonator_id) = validate_result.impersonator_id {
91 let admin = state
92 .db
93 .get_user_by_id(impersonator_id)
94 .await?
95 .ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
96 req.extensions_mut()
97 .insert(Principal::Impersonated { user, admin });
98 } else {
99 req.extensions_mut().insert(Principal::User(user));
100 };
101 return Ok::<_, Error>(next.run(req).await);
102 }
103 }
104
105 Err(Error::http(
106 StatusCode::UNAUTHORIZED,
107 "invalid credentials".to_string(),
108 ))
109}
110
/// Limit passed to `Database::create_access_token` when storing a new token.
/// NOTE(review): presumably the maximum number of tokens retained per user — confirm in the db layer.
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
112
/// JSON payload handed to clients as an access token.
///
/// Produced by `create_access_token` and parsed back by `verify_access_token`. Field order
/// determines the serialized JSON layout.
#[derive(Serialize, Deserialize)]
struct AccessTokenJson {
    /// Payload format version; `create_access_token` currently writes `1`.
    /// NOTE(review): `verify_access_token` does not inspect this field.
    version: usize,
    /// Id of the stored access-token row, used to look up its hash.
    id: AccessTokenId,
    /// The random secret itself; only its hash is persisted in the database.
    token: String,
}
119
120/// Creates a new access token to identify the given user. before returning it, you should
121/// encrypt it with the user's public key.
122pub async fn create_access_token(
123 db: &db::Database,
124 user_id: UserId,
125 impersonated_user_id: Option<UserId>,
126) -> Result<String> {
127 const VERSION: usize = 1;
128 let access_token = rpc::auth::random_token();
129 let access_token_hash = hash_access_token(&access_token);
130 let id = db
131 .create_access_token(
132 user_id,
133 impersonated_user_id,
134 &access_token_hash,
135 MAX_ACCESS_TOKENS_TO_STORE,
136 )
137 .await?;
138 Ok(serde_json::to_string(&AccessTokenJson {
139 version: VERSION,
140 id,
141 token: access_token,
142 })?)
143}
144
145/// Hashing prevents anyone with access to the database being able to login.
146/// As the token is randomly generated, we don't need to worry about scrypt-style
147/// protection.
148pub fn hash_access_token(token: &str) -> String {
149 let digest = sha2::Sha256::digest(token);
150 format!("$sha256${}", BASE64_URL_SAFE.encode(digest))
151}
152
153/// Encrypts the given access token with the given public key to avoid leaking it on the way
154/// to the client.
155pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
156 use rpc::auth::EncryptionFormat;
157
158 /// The encryption format to use for the access token.
159 ///
160 /// Currently we're using the original encryption format to avoid
161 /// breaking compatibility with older clients.
162 ///
163 /// Once enough clients are capable of decrypting the newer encryption
164 /// format we can start encrypting with `EncryptionFormat::V1`.
165 const ENCRYPTION_FORMAT: EncryptionFormat = EncryptionFormat::V0;
166
167 let native_app_public_key =
168 rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
169 let encrypted_access_token = native_app_public_key
170 .encrypt_string(access_token, ENCRYPTION_FORMAT)
171 .context("failed to encrypt access token with public key")?;
172 Ok(encrypted_access_token)
173}
174
175pub struct VerifyAccessTokenResult {
176 pub is_valid: bool,
177 pub impersonator_id: Option<UserId>,
178}
179
180/// Checks that the given access token is valid for the given user.
181pub async fn verify_access_token(
182 token: &str,
183 user_id: UserId,
184 db: &Arc<Database>,
185) -> Result<VerifyAccessTokenResult> {
186 static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
187 let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
188 register_histogram!(
189 "access_token_hashing_time",
190 "time spent hashing access tokens",
191 exponential_buckets(10.0, 2.0, 10).unwrap(),
192 )
193 .unwrap()
194 });
195
196 let token: AccessTokenJson = serde_json::from_str(token)?;
197
198 let db_token = db.get_access_token(token.id).await?;
199 let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
200 if token_user_id != user_id {
201 return Err(anyhow!("no such access token"))?;
202 }
203 let t0 = Instant::now();
204
205 let is_valid = if db_token.hash.starts_with("$scrypt$") {
206 let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
207 Scrypt
208 .verify_password(token.token.as_bytes(), &db_hash)
209 .is_ok()
210 } else {
211 let token_hash = hash_access_token(&token.token);
212 db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
213 };
214
215 let duration = t0.elapsed();
216 log::info!("hashed access token in {:?}", duration);
217 metric_access_token_hashing_time.observe(duration.as_millis() as f64);
218
219 if is_valid && db_token.hash.starts_with("$scrypt$") {
220 let new_hash = hash_access_token(&token.token);
221 db.update_access_token_hash(db_token.id, &new_hash).await?;
222 }
223
224 Ok(VerifyAccessTokenResult {
225 is_valid,
226 impersonator_id: if db_token.impersonated_user_id.is_some() {
227 Some(db_token.user_id)
228 } else {
229 None
230 },
231 })
232}
233
#[cfg(test)]
mod test {
    use rand::thread_rng;
    use scrypt::password_hash::{PasswordHasher, SaltString};
    use sea_orm::EntityTrait;

    use super::*;
    use crate::db::{access_token, NewUserParams};

    /// Exercises `verify_access_token` for tokens created with the current `$sha256$` hashing
    /// and for legacy `$scrypt$`-hashed tokens. It also checks that verifying a legacy token
    /// rewrites its stored hash to the `$sha256$` format while the token keeps working.
    #[gpui::test]
    async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
        let test_db = crate::db::TestDb::sqlite(cx.executor().clone());
        let db = test_db.db();

        let user = db
            .create_user(
                "example@example.com",
                None,
                false,
                NewUserParams {
                    github_login: "example".into(),
                    github_user_id: 1,
                },
            )
            .await
            .unwrap();

        // A freshly created token verifies for its owner, with no impersonator.
        let token = create_access_token(db, user.user_id, None).await.unwrap();
        assert!(matches!(
            verify_access_token(&token, user.user_id, db).await.unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));

        // Create a second token whose stored hash uses the legacy scrypt scheme.
        let old_token = create_previous_access_token(user.user_id, None, db)
            .await
            .unwrap();

        let old_token_id = serde_json::from_str::<AccessTokenJson>(&old_token)
            .unwrap()
            .id;

        // Read the stored hash directly from the table: it starts out in scrypt format.
        let hash = db
            .transaction(|tx| async move {
                Ok(access_token::Entity::find_by_id(old_token_id)
                    .one(&*tx)
                    .await?)
            })
            .await
            .unwrap()
            .unwrap()
            .hash;
        assert!(hash.starts_with("$scrypt$"));

        // The legacy token still verifies...
        assert!(matches!(
            verify_access_token(&old_token, user.user_id, db)
                .await
                .unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));

        // ...and that verification rewrote its stored hash to the SHA-256 format.
        let hash = db
            .transaction(|tx| async move {
                Ok(access_token::Entity::find_by_id(old_token_id)
                    .one(&*tx)
                    .await?)
            })
            .await
            .unwrap()
            .unwrap()
            .hash;
        assert!(hash.starts_with("$sha256$"));

        // The migrated token keeps verifying, now via the SHA-256 path.
        assert!(matches!(
            verify_access_token(&old_token, user.user_id, db)
                .await
                .unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));

        // The first token is still valid after the migration.
        assert!(matches!(
            verify_access_token(&token, user.user_id, db).await.unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));
    }

    /// Creates an access token the way it was done before SHA-256 hashing was introduced.
    /// The JSON payload is the same version-1 shape, but the stored hash comes from
    /// `previous_hash_access_token` (scrypt) rather than `hash_access_token`.
    async fn create_previous_access_token(
        user_id: UserId,
        impersonated_user_id: Option<UserId>,
        db: &Database,
    ) -> Result<String> {
        let access_token = rpc::auth::random_token();
        let access_token_hash = previous_hash_access_token(&access_token)?;
        let id = db
            .create_access_token(
                user_id,
                impersonated_user_id,
                &access_token_hash,
                MAX_ACCESS_TOKENS_TO_STORE,
            )
            .await?;
        Ok(serde_json::to_string(&AccessTokenJson {
            version: 1,
            id,
            token: access_token,
        })?)
    }

    /// Hashes `token` with scrypt using a freshly generated salt, producing the legacy
    /// `$scrypt$...` hash string. Debug builds use minimal cost parameters to keep tests fast.
    fn previous_hash_access_token(token: &str) -> Result<String> {
        // Avoid slow hashing in debug mode.
        let params = if cfg!(debug_assertions) {
            scrypt::Params::new(1, 1, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
        } else {
            scrypt::Params::new(14, 8, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
        };

        Ok(Scrypt
            .hash_password_customized(
                token.as_bytes(),
                None,
                None,
                params,
                &SaltString::generate(thread_rng()),
            )
            .map_err(anyhow::Error::new)?
            .to_string())
    }
}