Merge pull request #183 from zed-industries/speed-up-access-token-verification

Max Brunsfeld created

Speed up login by avoiding unnecessary access token verification

Change summary

server/src/auth.rs |  4 +
server/src/db.rs   | 75 +++++++++++++++++++++++++++++++++++++++++------
2 files changed, 68 insertions(+), 11 deletions(-)

Detailed changes

server/src/auth.rs 🔗

@@ -263,11 +263,13 @@ async fn post_sign_out(mut request: Request) -> tide::Result {
     Ok(tide::Redirect::new("/").into())
 }
 
+const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
+
 pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result<String> {
     let access_token = zed_auth::random_token();
     let access_token_hash =
         hash_access_token(&access_token).context("failed to hash access token")?;
-    db.create_access_token_hash(user_id, access_token_hash)
+    db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE)
         .await?;
     Ok(access_token)
 }

server/src/db.rs 🔗

@@ -162,25 +162,48 @@ impl Db {
     pub async fn create_access_token_hash(
         &self,
         user_id: UserId,
-        access_token_hash: String,
+        access_token_hash: &str,
+        max_access_token_count: usize,
     ) -> Result<()> {
         test_support!(self, {
-            let query = "
-            INSERT INTO access_tokens (user_id, hash)
-            VALUES ($1, $2)
-        ";
-            sqlx::query(query)
+            let insert_query = "
+                INSERT INTO access_tokens (user_id, hash)
+                VALUES ($1, $2);
+            ";
+            let cleanup_query = "
+                DELETE FROM access_tokens
+                WHERE id IN (
+                    SELECT id from access_tokens
+                    WHERE user_id = $1
+                    ORDER BY id DESC
+                    OFFSET $3
+                )
+            ";
+
+            let mut tx = self.pool.begin().await?;
+            sqlx::query(insert_query)
                 .bind(user_id.0)
                 .bind(access_token_hash)
-                .execute(&self.pool)
-                .await
-                .map(drop)
+                .execute(&mut tx)
+                .await?;
+            sqlx::query(cleanup_query)
+                .bind(user_id.0)
+                .bind(access_token_hash)
+                .bind(max_access_token_count as u32)
+                .execute(&mut tx)
+                .await?;
+            tx.commit().await
         })
     }
 
     pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
         test_support!(self, {
-            let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
+            let query = "
+                SELECT hash
+                FROM access_tokens
+                WHERE user_id = $1
+                ORDER BY id DESC
+            ";
             sqlx::query_scalar(query)
                 .bind(user_id.0)
                 .fetch_all(&self.pool)
@@ -636,4 +659,36 @@ pub mod tests {
         assert_eq!(msg1_id, msg3_id);
         assert_eq!(msg2_id, msg4_id);
     }
+
+    #[gpui::test]
+    async fn test_create_access_tokens() {
+        let test_db = TestDb::new();
+        let db = test_db.db();
+        let user = db.create_user("the-user", false).await.unwrap();
+
+        db.create_access_token_hash(user, "h1", 3).await.unwrap();
+        db.create_access_token_hash(user, "h2", 3).await.unwrap();
+        assert_eq!(
+            db.get_access_token_hashes(user).await.unwrap(),
+            &["h2".to_string(), "h1".to_string()]
+        );
+
+        db.create_access_token_hash(user, "h3", 3).await.unwrap();
+        assert_eq!(
+            db.get_access_token_hashes(user).await.unwrap(),
+            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
+        );
+
+        db.create_access_token_hash(user, "h4", 3).await.unwrap();
+        assert_eq!(
+            db.get_access_token_hashes(user).await.unwrap(),
+            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
+        );
+
+        db.create_access_token_hash(user, "h5", 3).await.unwrap();
+        assert_eq!(
+            db.get_access_token_hashes(user).await.unwrap(),
+            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
+        );
+    }
 }