1use crate::{
2 AppState, Error, Result,
3 db::{self, AccessTokenId, Database, UserId},
4 rpc::Principal,
5};
6use anyhow::Context as _;
7use axum::{
8 http::{self, Request, StatusCode},
9 middleware::Next,
10 response::IntoResponse,
11};
12use base64::prelude::*;
13use prometheus::{Histogram, exponential_buckets, register_histogram};
14pub use rpc::auth::random_token;
15use scrypt::{
16 Scrypt,
17 password_hash::{PasswordHash, PasswordVerifier},
18};
19use serde::{Deserialize, Serialize};
20use sha2::Digest;
21use std::sync::OnceLock;
22use std::{sync::Arc, time::Instant};
23use subtle::ConstantTimeEq;
24
25/// Validates the authorization header and adds an Extension<Principal> to the request.
26/// Authorization: <user-id> <token>
27/// <token> can be an access_token attached to that user, or an access token of an admin
28/// or (in development) the string ADMIN:<config.api_token>.
29/// Authorization: "dev-server-token" <token>
30pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
31 let mut auth_header = req
32 .headers()
33 .get(http::header::AUTHORIZATION)
34 .and_then(|header| header.to_str().ok())
35 .ok_or_else(|| {
36 Error::http(
37 StatusCode::UNAUTHORIZED,
38 "missing authorization header".to_string(),
39 )
40 })?
41 .split_whitespace();
42
43 let state = req.extensions().get::<Arc<AppState>>().unwrap();
44
45 let first = auth_header.next().unwrap_or("");
46 if first == "dev-server-token" {
47 Err(Error::http(
48 StatusCode::UNAUTHORIZED,
49 "Dev servers were removed in Zed 0.157 please upgrade to SSH remoting".to_string(),
50 ))?;
51 }
52
53 let user_id = UserId(first.parse().map_err(|_| {
54 Error::http(
55 StatusCode::BAD_REQUEST,
56 "missing user id in authorization header".to_string(),
57 )
58 })?);
59
60 let access_token = auth_header.next().ok_or_else(|| {
61 Error::http(
62 StatusCode::BAD_REQUEST,
63 "missing access token in authorization header".to_string(),
64 )
65 })?;
66
67 // In development, allow impersonation using the admin API token.
68 // Don't allow this in production because we can't tell who is doing
69 // the impersonating.
70 let validate_result = if let (Some(admin_token), true) = (
71 access_token.strip_prefix("ADMIN_TOKEN:"),
72 state.config.is_development(),
73 ) {
74 Ok(VerifyAccessTokenResult {
75 is_valid: state.config.api_token == admin_token,
76 impersonator_id: None,
77 })
78 } else {
79 verify_access_token(access_token, user_id, &state.db).await
80 };
81
82 if let Ok(validate_result) = validate_result
83 && validate_result.is_valid {
84 let user = state
85 .db
86 .get_user_by_id(user_id)
87 .await?
88 .with_context(|| format!("user {user_id} not found"))?;
89
90 if let Some(impersonator_id) = validate_result.impersonator_id {
91 let admin = state
92 .db
93 .get_user_by_id(impersonator_id)
94 .await?
95 .with_context(|| format!("user {impersonator_id} not found"))?;
96 req.extensions_mut()
97 .insert(Principal::Impersonated { user, admin });
98 } else {
99 req.extensions_mut().insert(Principal::User(user));
100 };
101 return Ok::<_, Error>(next.run(req).await);
102 }
103
104 Err(Error::http(
105 StatusCode::UNAUTHORIZED,
106 "invalid credentials".to_string(),
107 ))
108}
109
/// Limit handed to `Database::create_access_token` whenever a new token is stored.
/// NOTE(review): presumably caps how many tokens are kept per user; confirm in the db layer.
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
111
112#[derive(Serialize, Deserialize)]
113struct AccessTokenJson {
114 version: usize,
115 id: AccessTokenId,
116 token: String,
117}
118
119/// Creates a new access token to identify the given user. before returning it, you should
120/// encrypt it with the user's public key.
121pub async fn create_access_token(
122 db: &db::Database,
123 user_id: UserId,
124 impersonated_user_id: Option<UserId>,
125) -> Result<String> {
126 const VERSION: usize = 1;
127 let access_token = rpc::auth::random_token();
128 let access_token_hash = hash_access_token(&access_token);
129 let id = db
130 .create_access_token(
131 user_id,
132 impersonated_user_id,
133 &access_token_hash,
134 MAX_ACCESS_TOKENS_TO_STORE,
135 )
136 .await?;
137 Ok(serde_json::to_string(&AccessTokenJson {
138 version: VERSION,
139 id,
140 token: access_token,
141 })?)
142}
143
144/// Hashing prevents anyone with access to the database being able to login.
145/// As the token is randomly generated, we don't need to worry about scrypt-style
146/// protection.
147pub fn hash_access_token(token: &str) -> String {
148 let digest = sha2::Sha256::digest(token);
149 format!("$sha256${}", BASE64_URL_SAFE.encode(digest))
150}
151
152/// Encrypts the given access token with the given public key to avoid leaking it on the way
153/// to the client.
154pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
155 use rpc::auth::EncryptionFormat;
156
157 /// The encryption format to use for the access token.
158 const ENCRYPTION_FORMAT: EncryptionFormat = EncryptionFormat::V1;
159
160 let native_app_public_key =
161 rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
162 let encrypted_access_token = native_app_public_key
163 .encrypt_string(access_token, ENCRYPTION_FORMAT)
164 .context("failed to encrypt access token with public key")?;
165 Ok(encrypted_access_token)
166}
167
168pub struct VerifyAccessTokenResult {
169 pub is_valid: bool,
170 pub impersonator_id: Option<UserId>,
171}
172
173/// Checks that the given access token is valid for the given user.
174pub async fn verify_access_token(
175 token: &str,
176 user_id: UserId,
177 db: &Arc<Database>,
178) -> Result<VerifyAccessTokenResult> {
179 static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
180 let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
181 register_histogram!(
182 "access_token_hashing_time",
183 "time spent hashing access tokens",
184 exponential_buckets(10.0, 2.0, 10).unwrap(),
185 )
186 .unwrap()
187 });
188
189 let token: AccessTokenJson = serde_json::from_str(token)?;
190
191 let db_token = db.get_access_token(token.id).await?;
192 let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
193 if token_user_id != user_id {
194 return Err(anyhow::anyhow!("no such access token"))?;
195 }
196 let t0 = Instant::now();
197
198 let is_valid = if db_token.hash.starts_with("$scrypt$") {
199 let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
200 Scrypt
201 .verify_password(token.token.as_bytes(), &db_hash)
202 .is_ok()
203 } else {
204 let token_hash = hash_access_token(&token.token);
205 db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
206 };
207
208 let duration = t0.elapsed();
209 log::info!("hashed access token in {:?}", duration);
210 metric_access_token_hashing_time.observe(duration.as_millis() as f64);
211
212 if is_valid && db_token.hash.starts_with("$scrypt$") {
213 let new_hash = hash_access_token(&token.token);
214 db.update_access_token_hash(db_token.id, &new_hash).await?;
215 }
216
217 Ok(VerifyAccessTokenResult {
218 is_valid,
219 impersonator_id: if db_token.impersonated_user_id.is_some() {
220 Some(db_token.user_id)
221 } else {
222 None
223 },
224 })
225}
226
#[cfg(test)]
mod test {
    use rand::thread_rng;
    use scrypt::password_hash::{PasswordHasher, SaltString};
    use sea_orm::EntityTrait;

    use super::*;
    use crate::db::{NewUserParams, access_token};

    #[gpui::test]
    async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
        let test_db = crate::db::TestDb::sqlite(cx.executor().clone());
        let db = test_db.db();

        let user_id = db
            .create_user(
                "example@example.com",
                None,
                false,
                NewUserParams {
                    github_login: "example".into(),
                    github_user_id: 1,
                },
            )
            .await
            .unwrap()
            .user_id;

        // A token created with the current scheme verifies.
        let token = create_access_token(db, user_id, None).await.unwrap();
        assert_token_valid(&token, user_id, db).await;

        // A token stored with the legacy scrypt hash...
        let old_token = create_previous_access_token(user_id, None, db)
            .await
            .unwrap();
        let old_token_id = serde_json::from_str::<AccessTokenJson>(&old_token)
            .unwrap()
            .id;
        assert!(stored_hash(db, old_token_id).await.starts_with("$scrypt$"));

        // ...still verifies, and its stored hash is migrated to SHA-256.
        assert_token_valid(&old_token, user_id, db).await;
        assert!(stored_hash(db, old_token_id).await.starts_with("$sha256$"));

        // Both tokens keep verifying after the migration.
        assert_token_valid(&old_token, user_id, db).await;
        assert_token_valid(&token, user_id, db).await;
    }

    /// Asserts that `token` verifies for `user_id` as a non-impersonated token.
    async fn assert_token_valid(token: &str, user_id: UserId, db: &Arc<Database>) {
        let outcome = verify_access_token(token, user_id, db).await.unwrap();
        assert!(outcome.is_valid);
        assert!(outcome.impersonator_id.is_none());
    }

    /// Loads the hash currently stored for the access token with id `token_id`.
    async fn stored_hash(db: &Database, token_id: AccessTokenId) -> String {
        db.transaction(|tx| async move {
            Ok(access_token::Entity::find_by_id(token_id)
                .one(&*tx)
                .await?)
        })
        .await
        .unwrap()
        .unwrap()
        .hash
    }

    /// Creates a token the way older servers did, with an scrypt-hashed secret.
    async fn create_previous_access_token(
        user_id: UserId,
        impersonated_user_id: Option<UserId>,
        db: &Database,
    ) -> Result<String> {
        let secret = rpc::auth::random_token();
        let legacy_hash = previous_hash_access_token(&secret)?;
        let id = db
            .create_access_token(
                user_id,
                impersonated_user_id,
                &legacy_hash,
                MAX_ACCESS_TOKENS_TO_STORE,
            )
            .await?;
        let payload = AccessTokenJson {
            version: 1,
            id,
            token: secret,
        };
        Ok(serde_json::to_string(&payload)?)
    }

    /// Hashes `token` with scrypt, reproducing the legacy storage format.
    fn previous_hash_access_token(token: &str) -> Result<String> {
        // Avoid slow hashing in debug mode.
        let (log_n, r) = if cfg!(debug_assertions) { (1, 1) } else { (14, 8) };
        let params = scrypt::Params::new(log_n, r, 1, scrypt::Params::RECOMMENDED_LEN).unwrap();

        let salt = SaltString::generate(thread_rng());
        let hashed = Scrypt
            .hash_password_customized(token.as_bytes(), None, None, params, &salt)
            .map_err(anyhow::Error::new)?;
        Ok(hashed.to_string())
    }
}