auth.rs

  1use crate::{
  2    db::{self, dev_server, AccessTokenId, Database, DevServerId, UserId},
  3    rpc::Principal,
  4    AppState, Error, Result,
  5};
  6use anyhow::{anyhow, Context};
  7use axum::{
  8    http::{self, Request, StatusCode},
  9    middleware::Next,
 10    response::IntoResponse,
 11};
 12use prometheus::{exponential_buckets, register_histogram, Histogram};
 13use scrypt::{
 14    password_hash::{PasswordHash, PasswordVerifier},
 15    Scrypt,
 16};
 17use serde::{Deserialize, Serialize};
 18use sha2::Digest;
 19use std::sync::OnceLock;
 20use std::{sync::Arc, time::Instant};
 21use subtle::ConstantTimeEq;
 22
 23/// Validates the authorization header and adds an Extension<Principal> to the request.
 24/// Authorization: <user-id> <token>
 25///   <token> can be an access_token attached to that user, or an access token of an admin
 26///   or (in development) the string ADMIN:<config.api_token>.
 27/// Authorization: "dev-server-token" <token>
 28pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl IntoResponse {
 29    let mut auth_header = req
 30        .headers()
 31        .get(http::header::AUTHORIZATION)
 32        .and_then(|header| header.to_str().ok())
 33        .ok_or_else(|| {
 34            Error::Http(
 35                StatusCode::UNAUTHORIZED,
 36                "missing authorization header".to_string(),
 37            )
 38        })?
 39        .split_whitespace();
 40
 41    let state = req.extensions().get::<Arc<AppState>>().unwrap();
 42
 43    let first = auth_header.next().unwrap_or("");
 44    if first == "dev-server-token" {
 45        let dev_server_token = auth_header.next().ok_or_else(|| {
 46            Error::Http(
 47                StatusCode::BAD_REQUEST,
 48                "missing dev-server-token token in authorization header".to_string(),
 49            )
 50        })?;
 51        let dev_server = verify_dev_server_token(dev_server_token, &state.db)
 52            .await
 53            .map_err(|e| Error::Http(StatusCode::UNAUTHORIZED, format!("{}", e)))?;
 54
 55        req.extensions_mut()
 56            .insert(Principal::DevServer(dev_server));
 57        return Ok::<_, Error>(next.run(req).await);
 58    }
 59
 60    let user_id = UserId(first.parse().map_err(|_| {
 61        Error::Http(
 62            StatusCode::BAD_REQUEST,
 63            "missing user id in authorization header".to_string(),
 64        )
 65    })?);
 66
 67    let access_token = auth_header.next().ok_or_else(|| {
 68        Error::Http(
 69            StatusCode::BAD_REQUEST,
 70            "missing access token in authorization header".to_string(),
 71        )
 72    })?;
 73
 74    // In development, allow impersonation using the admin API token.
 75    // Don't allow this in production because we can't tell who is doing
 76    // the impersonating.
 77    let validate_result = if let (Some(admin_token), true) = (
 78        access_token.strip_prefix("ADMIN_TOKEN:"),
 79        state.config.is_development(),
 80    ) {
 81        Ok(VerifyAccessTokenResult {
 82            is_valid: state.config.api_token == admin_token,
 83            impersonator_id: None,
 84        })
 85    } else {
 86        verify_access_token(&access_token, user_id, &state.db).await
 87    };
 88
 89    if let Ok(validate_result) = validate_result {
 90        if validate_result.is_valid {
 91            let user = state
 92                .db
 93                .get_user_by_id(user_id)
 94                .await?
 95                .ok_or_else(|| anyhow!("user {} not found", user_id))?;
 96
 97            if let Some(impersonator_id) = validate_result.impersonator_id {
 98                let admin = state
 99                    .db
100                    .get_user_by_id(impersonator_id)
101                    .await?
102                    .ok_or_else(|| anyhow!("user {} not found", impersonator_id))?;
103                req.extensions_mut()
104                    .insert(Principal::Impersonated { user, admin });
105            } else {
106                req.extensions_mut().insert(Principal::User(user));
107            };
108            return Ok::<_, Error>(next.run(req).await);
109        }
110    }
111
112    Err(Error::Http(
113        StatusCode::UNAUTHORIZED,
114        "invalid credentials".to_string(),
115    ))
116}
117
/// Limit passed to `Database::create_access_token` every time a new access token is created.
// NOTE(review): presumably the database layer discards a user's oldest tokens beyond this
// count — confirm in the `db` module.
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
119
/// JSON payload returned by `create_access_token` and presented back by clients as their
/// access token; `verify_access_token` parses it and looks up the stored hash by `id`.
#[derive(Serialize, Deserialize)]
struct AccessTokenJson {
    // Payload format version; tokens are created with 1. It is not checked during verification.
    version: usize,
    // Database id of the stored access token row.
    id: AccessTokenId,
    // The plaintext random secret. Only its hash is stored in the database.
    token: String,
}
126
127/// Creates a new access token to identify the given user. before returning it, you should
128/// encrypt it with the user's public key.
129pub async fn create_access_token(
130    db: &db::Database,
131    user_id: UserId,
132    impersonated_user_id: Option<UserId>,
133) -> Result<String> {
134    const VERSION: usize = 1;
135    let access_token = rpc::auth::random_token();
136    let access_token_hash = hash_access_token(&access_token);
137    let id = db
138        .create_access_token(
139            user_id,
140            impersonated_user_id,
141            &access_token_hash,
142            MAX_ACCESS_TOKENS_TO_STORE,
143        )
144        .await?;
145    Ok(serde_json::to_string(&AccessTokenJson {
146        version: VERSION,
147        id,
148        token: access_token,
149    })?)
150}
151
152/// Hashing prevents anyone with access to the database being able to login.
153/// As the token is randomly generated, we don't need to worry about scrypt-style
154/// protection.
155fn hash_access_token(token: &str) -> String {
156    let digest = sha2::Sha256::digest(token);
157    format!(
158        "$sha256${}",
159        base64::encode_config(digest, base64::URL_SAFE)
160    )
161}
162
163/// Encrypts the given access token with the given public key to avoid leaking it on the way
164/// to the client.
165pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<String> {
166    let native_app_public_key =
167        rpc::auth::PublicKey::try_from(public_key).context("failed to parse app public key")?;
168    let encrypted_access_token = native_app_public_key
169        .encrypt_string(access_token)
170        .context("failed to encrypt access token with public key")?;
171    Ok(encrypted_access_token)
172}
173
174pub struct VerifyAccessTokenResult {
175    pub is_valid: bool,
176    pub impersonator_id: Option<UserId>,
177}
178
179/// Checks that the given access token is valid for the given user.
180pub async fn verify_access_token(
181    token: &str,
182    user_id: UserId,
183    db: &Arc<Database>,
184) -> Result<VerifyAccessTokenResult> {
185    static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
186    let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
187        register_histogram!(
188            "access_token_hashing_time",
189            "time spent hashing access tokens",
190            exponential_buckets(10.0, 2.0, 10).unwrap(),
191        )
192        .unwrap()
193    });
194
195    let token: AccessTokenJson = serde_json::from_str(&token)?;
196
197    let db_token = db.get_access_token(token.id).await?;
198    let token_user_id = db_token.impersonated_user_id.unwrap_or(db_token.user_id);
199    if token_user_id != user_id {
200        return Err(anyhow!("no such access token"))?;
201    }
202    let t0 = Instant::now();
203
204    let is_valid = if db_token.hash.starts_with("$scrypt$") {
205        let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
206        Scrypt
207            .verify_password(token.token.as_bytes(), &db_hash)
208            .is_ok()
209    } else {
210        let token_hash = hash_access_token(&token.token);
211        db_token.hash.as_bytes().ct_eq(token_hash.as_ref()).into()
212    };
213
214    let duration = t0.elapsed();
215    log::info!("hashed access token in {:?}", duration);
216    metric_access_token_hashing_time.observe(duration.as_millis() as f64);
217
218    if is_valid && db_token.hash.starts_with("$scrypt$") {
219        let new_hash = hash_access_token(&token.token);
220        db.update_access_token_hash(db_token.id, &new_hash).await?;
221    }
222
223    Ok(VerifyAccessTokenResult {
224        is_valid,
225        impersonator_id: if db_token.impersonated_user_id.is_some() {
226            Some(db_token.user_id)
227        } else {
228            None
229        },
230    })
231}
232
233// a dev_server_token has the format <id>.<base64>. This is to make them
234// relatively easy to copy/paste around.
235pub async fn verify_dev_server_token(
236    dev_server_token: &str,
237    db: &Arc<Database>,
238) -> anyhow::Result<dev_server::Model> {
239    let mut parts = dev_server_token.splitn(2, '.');
240    let id = DevServerId(parts.next().unwrap_or_default().parse()?);
241    let token = parts
242        .next()
243        .ok_or_else(|| anyhow!("invalid dev server token format"))?;
244
245    let token_hash = hash_access_token(&token);
246    let server = db.get_dev_server(id).await?;
247
248    if server
249        .hashed_token
250        .as_bytes()
251        .ct_eq(token_hash.as_ref())
252        .into()
253    {
254        Ok(server)
255    } else {
256        Err(anyhow!("wrong token for dev server"))
257    }
258}
259
#[cfg(test)]
mod test {
    use rand::thread_rng;
    use scrypt::password_hash::{PasswordHasher, SaltString};
    use sea_orm::EntityTrait;

    use super::*;
    use crate::db::{access_token, NewUserParams};

    /// Asserts that `token` verifies for `user_id` as an ordinary, non-impersonating token.
    async fn assert_token_valid(token: &str, user_id: UserId, db: &Arc<Database>) {
        let result = verify_access_token(token, user_id, db).await.unwrap();
        assert!(matches!(
            result,
            VerifyAccessTokenResult {
                is_valid: true,
                impersonator_id: None,
            }
        ));
    }

    /// Reads the hash currently stored in the database for access token `id`.
    async fn stored_token_hash(db: &Database, id: AccessTokenId) -> String {
        db.transaction(|tx| async move {
            Ok(access_token::Entity::find_by_id(id).one(&*tx).await?)
        })
        .await
        .unwrap()
        .unwrap()
        .hash
    }

    #[gpui::test]
    async fn test_verify_access_token(cx: &mut gpui::TestAppContext) {
        let test_db = crate::db::TestDb::postgres(cx.executor().clone());
        let db = test_db.db();

        let user = db
            .create_user(
                "example@example.com",
                false,
                NewUserParams {
                    github_login: "example".into(),
                    github_user_id: 1,
                },
            )
            .await
            .unwrap();
        let user_id = user.user_id;

        // A freshly created token verifies.
        let token = create_access_token(&db, user_id, None).await.unwrap();
        assert_token_valid(&token, user_id, &db).await;

        // Tokens from the previous scheme are stored as scrypt hashes.
        let old_token = create_previous_access_token(user_id, None, &db)
            .await
            .unwrap();
        let old_token_id = serde_json::from_str::<AccessTokenJson>(&old_token)
            .unwrap()
            .id;
        assert!(stored_token_hash(&db, old_token_id)
            .await
            .starts_with("$scrypt$"));

        // Verifying an old token succeeds and upgrades its stored hash to SHA-256.
        assert_token_valid(&old_token, user_id, &db).await;
        assert!(stored_token_hash(&db, old_token_id)
            .await
            .starts_with("$sha256$"));

        // Both tokens keep verifying after the upgrade.
        assert_token_valid(&old_token, user_id, &db).await;
        assert_token_valid(&token, user_id, &db).await;
    }

    /// Creates an access token the way the previous implementation did, storing the secret as
    /// an scrypt hash rather than a SHA-256 digest.
    async fn create_previous_access_token(
        user_id: UserId,
        impersonated_user_id: Option<UserId>,
        db: &Database,
    ) -> Result<String> {
        let secret = rpc::auth::random_token();
        let hashed_secret = previous_hash_access_token(&secret)?;
        let id = db
            .create_access_token(
                user_id,
                impersonated_user_id,
                &hashed_secret,
                MAX_ACCESS_TOKENS_TO_STORE,
            )
            .await?;
        let payload = AccessTokenJson {
            version: 1,
            id,
            token: secret,
        };
        Ok(serde_json::to_string(&payload)?)
    }

    /// Hashes `token` with scrypt, as the previous implementation did.
    fn previous_hash_access_token(token: &str) -> Result<String> {
        // Avoid slow hashing in debug mode.
        let (log_n, r, p) = if cfg!(debug_assertions) {
            (1, 1, 1)
        } else {
            (14, 8, 1)
        };
        let params = scrypt::Params::new(log_n, r, p).unwrap();
        let salt = SaltString::generate(thread_rng());
        let hash = Scrypt
            .hash_password(token.as_bytes(), None, params, &salt)
            .map_err(anyhow::Error::new)?;
        Ok(hash.to_string())
    }
}