auth.rs

  1use crate::{
  2    db::{self, AccessTokenId, Database, UserId},
  3    rpc::Principal,
  4    AppState, Error, Result,
  5};
  6use anyhow::{anyhow, Context as _};
  7use axum::{
  8    http::{self, Request, StatusCode},
  9    middleware::Next,
 10    response::IntoResponse,
 11};
 12use base64::prelude::*;
 13use prometheus::{exponential_buckets, register_histogram, Histogram};
 14pub use rpc::auth::random_token;
 15use scrypt::{
 16    password_hash::{PasswordHash, PasswordVerifier},
 17    Scrypt,
 18};
 19use serde::{Deserialize, Serialize};
 20use sha2::Digest;
 21use std::sync::OnceLock;
 22use std::{sync::Arc, time::Instant};
 23use subtle::ConstantTimeEq;
 24
 25/// Validates the authorization header and adds an Extension<Principal> to the request.
 26/// Authorization: <user-id> <token>
 27///   <token> can be an access_token attached to that user, or an access token of an admin
 28///   or (in development) the string ADMIN:<config.api_token>.
 29/// Authorization: "dev-server-token" <token>
 30pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
 31    let mut auth_header = req
 32        .headers()
 33        .get(http::header::AUTHORIZATION)
 34        .and_then(|header| header.to_str().ok())
 35        .ok_or_else(|| {
 36            Error::http(
 37                StatusCode::UNAUTHORIZED,
 38                "missing authorization header".to_string(),
 39            )
 40        })?
 41        .split_whitespace();
 42
 43    let state = req.extensions().get::<Arc<AppState>>().unwrap();
 44
 45    let first = auth_header.next().unwrap_or("");
 46    if first == "dev-server-token" {
 47        Err(Error::http(
 48            StatusCode::UNAUTHORIZED,
 49            "Dev servers were removed in Zed 0.157 please upgrade to SSH remoting".to_string(),
 50        ))?;
 51    }
 52
 53    let user_id = UserId(first.parse().map_err(|_| {
 54        Error::http(
 55            StatusCode::BAD_REQUEST,
 56            "missing user id in authorization header".to_string(),
 57        )
 58    })?);
 59
 60    let access_token = auth_header.next().ok_or_else(|| {
 61        Error::http(
 62            StatusCode::BAD_REQUEST,
 63            "missing access token in authorization header".to_string(),
 64        )
 65    })?;
 66
 67    // In development, allow impersonation using the admin API token.
 68    // Don't allow this in production because we can't tell who is doing
 69    // the impersonating.
 70    let validate_result = if let (Some(admin_token), true) = (
 71        access_token.strip_prefix("ADMIN_TOKEN:"),
 72        state.config.is_development(),
 73    ) {
 74        Ok(VerifyAccessTokenResult {
 75            is_valid: state.config.api_token == admin_token,
 76            impersonator_id: None,
 77        })
 78    } else {
 79        verify_access_token(access_token, user_id, &state.db).await
 80    };
 81
 82    if let Ok(validate_result) = validate_result {
 83        if validate_result.is_valid {
 84            let user = state
 85                .db
 86                .get_user_by_id(user_id)
 87                .await?
 88                .ok_or_else(|| anyhow!("user {} not found", user_id))?;
 89
 90            if let Some(impersonator_id) = validate_result.impersonator_id {
 91                let admin = state
 92                    .db
 93                    .get_user_by_id(impersonator_id)
 94                    .await?
 95                    .ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
 96                req.extensions_mut()
 97                    .insert(Principal::Impersonated { user, admin });
 98            } else {
 99                req.extensions_mut().insert(Principal::User(user));
100            };
101            return Ok::<_, Error>(next.run(req).await);
102        }
103    }
104
105    Err(Error::http(
106        StatusCode::UNAUTHORIZED,
107        "invalid credentials".to_string(),
108    ))
109}
110
/// Limit passed to `Database::create_access_token` whenever a new token is stored.
///
/// NOTE(review): presumably caps how many tokens are kept per user, with older ones
/// discarded — confirm against the db layer.
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
112
113#[derive(Serialize, Deserialize)]
114struct AccessTokenJson {
115    version: usize,
116    id: AccessTokenId,
117    token: String,
118}
119
120/// Creates a new access token to identify the given user. before returning it, you should
121/// encrypt it with the user's public key.
122pub async fn create_access_token(
123    db: &db::Database,
124    user_id: UserId,
125    impersonated_user_id: Option<UserId>,
126) -> Result<String> {
127    const VERSION: usize = 1;
128    let access_token = rpc::auth::random_token();
129    let access_token_hash = hash_access_token(&access_token);
130    let id = db
131        .create_access_token(
132            user_id,
133            impersonated_user_id,
134            &access_token_hash,
135            MAX_ACCESS_TOKENS_TO_STORE,
136        )
137        .await?;
138    Ok(serde_json::to_string(&AccessTokenJson {
139        version: VERSION,
140        id,
141        token: access_token,
142    })?)
143}
144
145/// Hashing prevents anyone with access to the database being able to login.
146/// As the token is randomly generated, we don't need to worry about scrypt-style
147/// protection.
148pub fn hash_access_token(token: &str) -> String {
149    let digest = sha2::Sha256::digest(token);
150    format!("$sha256${}", BASE64_URL_SAFE.encode(digest))
151}
152
153/// Encrypts the given access token with the given public key to avoid leaking it on the way
154/// to the client.
155pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
156    use rpc::auth::EncryptionFormat;
157
158    /// The encryption format to use for the access token.
159    ///
160    /// Currently we're using the original encryption format to avoid
161    /// breaking compatibility with older clients.
162    ///
163    /// Once enough clients are capable of decrypting the newer encryption
164    /// format we can start encrypting with `EncryptionFormat::V1`.
165    const ENCRYPTION_FORMAT: EncryptionFormat = EncryptionFormat::V0;
166
167    let native_app_public_key =
168        rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
169    let encrypted_access_token = native_app_public_key
170        .encrypt_string(access_token, ENCRYPTION_FORMAT)
171        .context("failed to encrypt access token with public key")?;
172    Ok(encrypted_access_token)
173}
174
175pub struct VerifyAccessTokenResult {
176    pub is_valid: bool,
177    pub impersonator_id: Option<UserId>,
178}
179
180/// Checks that the given access token is valid for the given user.
181pub async fn verify_access_token(
182    token: &str,
183    user_id: UserId,
184    db: &Arc<Database>,
185) -> Result<VerifyAccessTokenResult> {
186    static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
187    let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
188        register_histogram!(
189            "access_token_hashing_time",
190            "time spent hashing access tokens",
191            exponential_buckets(10.0, 2.0, 10).unwrap(),
192        )
193        .unwrap()
194    });
195
196    let token: AccessTokenJson = serde_json::from_str(token)?;
197
198    let db_token = db.get_access_token(token.id).await?;
199    let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
200    if token_user_id != user_id {
201        return Err(anyhow!("no such access token"))?;
202    }
203    let t0 = Instant::now();
204
205    let is_valid = if db_token.hash.starts_with("$scrypt$") {
206        let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
207        Scrypt
208            .verify_password(token.token.as_bytes(), &db_hash)
209            .is_ok()
210    } else {
211        let token_hash = hash_access_token(&token.token);
212        db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
213    };
214
215    let duration = t0.elapsed();
216    log::info!("hashed access token in {:?}", duration);
217    metric_access_token_hashing_time.observe(duration.as_millis() as f64);
218
219    if is_valid && db_token.hash.starts_with("$scrypt$") {
220        let new_hash = hash_access_token(&token.token);
221        db.update_access_token_hash(db_token.id, &new_hash).await?;
222    }
223
224    Ok(VerifyAccessTokenResult {
225        is_valid,
226        impersonator_id: if db_token.impersonated_user_id.is_some() {
227            Some(db_token.user_id)
228        } else {
229            None
230        },
231    })
232}
233
#[cfg(test)]
mod test {
    use rand::thread_rng;
    use scrypt::password_hash::{PasswordHasher, SaltString};
    use sea_orm::EntityTrait;

    use super::*;
    use crate::db::{access_token, NewUserParams};

    #[gpui::test]
    async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
        let test_db = crate::db::TestDb::sqlite(cx.executor().clone());
        let db = test_db.db();

        let user = db
            .create_user(
                "example@example.com",
                None,
                false,
                NewUserParams {
                    github_login: "example".into(),
                    github_user_id: 1,
                },
            )
            .await
            .unwrap();

        // A freshly created token verifies for its own user.
        let token = create_access_token(db, user.user_id, None).await.unwrap();
        assert!(matches!(
            verify_access_token(&token, user.user_id, db).await.unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));

        // The same token presented for a different user is rejected outright.
        let other_user_id = UserId(user.user_id.0 + 1);
        assert!(verify_access_token(&token, other_user_id, db)
            .await
            .is_err());

        // A token whose secret was altered is reported as invalid.
        let mut tampered = serde_json::from_str::<AccessTokenJson>(&token).unwrap();
        tampered.token.push('x');
        let tampered = serde_json::to_string(&tampered).unwrap();
        assert!(matches!(
            verify_access_token(&tampered, user.user_id, db)
                .await
                .unwrap(),
            VerifyAccessTokenResult {
                is_valid: false,
                ..
            }
        ));

        // A legacy scrypt-hashed token verifies and is migrated to SHA-256.
        let old_token = create_previous_access_token(user.user_id, None, db)
            .await
            .unwrap();

        let old_token_id = serde_json::from_str::<AccessTokenJson>(&old_token)
            .unwrap()
            .id;

        assert!(stored_hash(db, old_token_id)
            .await
            .starts_with("$scrypt$"));

        assert!(matches!(
            verify_access_token(&old_token, user.user_id, db)
                .await
                .unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));

        assert!(stored_hash(db, old_token_id)
            .await
            .starts_with("$sha256$"));

        // Both tokens keep verifying after the migration.
        assert!(matches!(
            verify_access_token(&old_token, user.user_id, db)
                .await
                .unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));

        assert!(matches!(
            verify_access_token(&token, user.user_id, db).await.unwrap(),
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));
    }

    /// Reads the hash currently stored for the access token with the given id.
    async fn stored_hash(db: &Arc<Database>, token_id: AccessTokenId) -> String {
        db.transaction(|tx| async move {
            Ok(access_token::Entity::find_by_id(token_id)
                .one(&*tx)
                .await?)
        })
        .await
        .unwrap()
        .unwrap()
        .hash
    }

    /// Creates an access token whose hash uses the legacy scrypt format, to exercise the
    /// migration path in `verify_access_token`.
    async fn create_previous_access_token(
        user_id: UserId,
        impersonated_user_id: Option<UserId>,
        db: &Database,
    ) -> Result<String> {
        let access_token = rpc::auth::random_token();
        let access_token_hash = previous_hash_access_token(&access_token)?;
        let id = db
            .create_access_token(
                user_id,
                impersonated_user_id,
                &access_token_hash,
                MAX_ACCESS_TOKENS_TO_STORE,
            )
            .await?;
        Ok(serde_json::to_string(&AccessTokenJson {
            version: 1,
            id,
            token: access_token,
        })?)
    }

    /// Hashes `token` with scrypt in PHC string format (`$scrypt$...`), as tokens were
    /// hashed before the switch to SHA-256.
    fn previous_hash_access_token(token: &str) -> Result<String> {
        // Avoid slow hashing in debug mode.
        let params = if cfg!(debug_assertions) {
            scrypt::Params::new(1, 1, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
        } else {
            scrypt::Params::new(14, 8, 1, scrypt::Params::RECOMMENDED_LEN).unwrap()
        };

        Ok(Scrypt
            .hash_password_customized(
                token.as_bytes(),
                None,
                None,
                params,
                &SaltString::generate(thread_rng()),
            )
            .map_err(anyhow::Error::new)?
            .to_string())
    }
}