Return a 400, not a 500 when token validation fails

Max Brunsfeld and Antonio Scandurra created

Co-authored-by: Antonio Scandurra <antonio@zed.dev>

Change summary

crates/collab/src/auth.rs | 28 ++++++++++++++++------------
1 file changed, 16 insertions(+), 12 deletions(-)

Detailed changes

crates/collab/src/auth.rs 🔗

@@ -1,5 +1,5 @@
 use crate::{
-    db::{self, AccessTokenId, UserId},
+    db::{self, AccessTokenId, Database, UserId},
     AppState, Error, Result,
 };
 use anyhow::{anyhow, Context};
@@ -47,14 +47,9 @@ pub async fn validate_header<B>(mut req: Request<B>, next: Next<B>) -> impl Into
     let credentials_valid = if let Some(admin_token) = access_token.strip_prefix("ADMIN_TOKEN:") {
         state.config.api_token == admin_token
     } else {
-        let access_token: AccessTokenJson = serde_json::from_str(&access_token)?;
-
-        let token = state.db.get_access_token(access_token.id).await?;
-        if token.user_id != user_id {
-            return Err(anyhow!("no such access token"))?;
-        }
-
-        verify_access_token(&access_token.token, &token.hash)?
+        verify_access_token(&access_token, user_id, &state.db)
+            .await
+            .unwrap_or(false)
     };
 
     if credentials_valid {
@@ -125,7 +120,16 @@ pub fn encrypt_access_token(access_token: &str, public_key: String) -> Result<St
     Ok(encrypted_access_token)
 }
 
-pub fn verify_access_token(token: &str, hash: &str) -> Result<bool> {
-    let hash = PasswordHash::new(hash).map_err(anyhow::Error::new)?;
-    Ok(Scrypt.verify_password(token.as_bytes(), &hash).is_ok())
+pub async fn verify_access_token(token: &str, user_id: UserId, db: &Arc<Database>) -> Result<bool> {
+    let token: AccessTokenJson = serde_json::from_str(&token)?;
+
+    let db_token = db.get_access_token(token.id).await?;
+    if db_token.user_id != user_id {
+        return Err(anyhow!("no such access token"))?;
+    }
+
+    let db_hash = PasswordHash::new(&db_token.hash).map_err(anyhow::Error::new)?;
+    Ok(Scrypt
+        .verify_password(token.token.as_bytes(), &db_hash)
+        .is_ok())
 }