diff --git a/server/src/auth.rs b/server/src/auth.rs index e60802285ec602a058fb0c74c2b4038603d6312a..0578f2c9fafffe496c5fd8c68a502cc4efe85743 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -263,11 +263,13 @@ async fn post_sign_out(mut request: Request) -> tide::Result { Ok(tide::Redirect::new("/").into()) } +const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; + pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result { let access_token = zed_auth::random_token(); let access_token_hash = hash_access_token(&access_token).context("failed to hash access token")?; - db.create_access_token_hash(user_id, access_token_hash) + db.create_access_token_hash(user_id, &access_token_hash, MAX_ACCESS_TOKENS_TO_STORE) .await?; Ok(access_token) } diff --git a/server/src/db.rs b/server/src/db.rs index 002b82741c6de924cf9319753b3be31730613c5c..bd9fca85e744503ba66fc79ff8c3a350323b0e61 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -162,25 +162,48 @@ impl Db { pub async fn create_access_token_hash( &self, user_id: UserId, - access_token_hash: String, + access_token_hash: &str, + max_access_token_count: usize, ) -> Result<()> { test_support!(self, { - let query = " - INSERT INTO access_tokens (user_id, hash) - VALUES ($1, $2) - "; - sqlx::query(query) + let insert_query = " + INSERT INTO access_tokens (user_id, hash) + VALUES ($1, $2); + "; + let cleanup_query = " + DELETE FROM access_tokens + WHERE id IN ( + SELECT id from access_tokens + WHERE user_id = $1 + ORDER BY id DESC + OFFSET $3 + ) + "; + + let mut tx = self.pool.begin().await?; + sqlx::query(insert_query) .bind(user_id.0) .bind(access_token_hash) - .execute(&self.pool) - .await - .map(drop) + .execute(&mut tx) + .await?; + sqlx::query(cleanup_query) + .bind(user_id.0) + .bind(access_token_hash) + .bind(max_access_token_count as u32) + .execute(&mut tx) + .await?; + tx.commit().await }) } pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { test_support!(self, { - let query = "SELECT hash FROM access_tokens WHERE user_id = $1"; + let query = " + SELECT hash + FROM access_tokens + WHERE user_id = $1 + ORDER BY id DESC + "; sqlx::query_scalar(query) .bind(user_id.0) .fetch_all(&self.pool) @@ -636,4 +659,36 @@ pub mod tests { assert_eq!(msg1_id, msg3_id); assert_eq!(msg2_id, msg4_id); } + + #[gpui::test] + async fn test_create_access_tokens() { + let test_db = TestDb::new(); + let db = test_db.db(); + let user = db.create_user("the-user", false).await.unwrap(); + + db.create_access_token_hash(user, "h1", 3).await.unwrap(); + db.create_access_token_hash(user, "h2", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h2".to_string(), "h1".to_string()] + ); + + db.create_access_token_hash(user, "h3", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h3".to_string(), "h2".to_string(), "h1".to_string(),] + ); + + db.create_access_token_hash(user, "h4", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h4".to_string(), "h3".to_string(), "h2".to_string(),] + ); + + db.create_access_token_hash(user, "h5", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h5".to_string(), "h4".to_string(), "h3".to_string()] + ); + } }