Implement access tokens using sea-orm

Antonio Scandurra created

Change summary

crates/collab/src/db2.rs              | 73 +++++++++++++++++++++++++
crates/collab/src/db2/access_token.rs | 29 ++++++++++
crates/collab/src/db2/tests.rs        | 82 ++++++++++++++--------------
crates/collab/src/db2/user.rs         | 11 +++
4 files changed, 151 insertions(+), 44 deletions(-)

Detailed changes

crates/collab/src/db2.rs 🔗

@@ -1,3 +1,4 @@
+mod access_token;
 mod project;
 mod project_collaborator;
 mod room;
@@ -17,8 +18,8 @@ use sea_orm::{
     entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
     TransactionTrait,
 };
-use sea_orm::{ActiveValue, IntoActiveModel};
-use sea_query::OnConflict;
+use sea_orm::{ActiveValue, ConnectionTrait, IntoActiveModel, QueryOrder, QuerySelect};
+use sea_query::{OnConflict, Query};
 use serde::{Deserialize, Serialize};
 use sqlx::migrate::{Migrate, Migration, MigrationSource};
 use sqlx::Connection;
@@ -336,6 +337,63 @@ impl Database {
         })
     }
 
+    pub async fn create_access_token_hash(
+        &self,
+        user_id: UserId,
+        access_token_hash: &str,
+        max_access_token_count: usize,
+    ) -> Result<()> {
+        self.transact(|tx| async {
+            let tx = tx;
+
+            access_token::ActiveModel {
+                user_id: ActiveValue::set(user_id),
+                hash: ActiveValue::set(access_token_hash.into()),
+                ..Default::default()
+            }
+            .insert(&tx)
+            .await?;
+
+            access_token::Entity::delete_many()
+                .filter(
+                    access_token::Column::Id.in_subquery(
+                        Query::select()
+                            .column(access_token::Column::Id)
+                            .from(access_token::Entity)
+                            .and_where(access_token::Column::UserId.eq(user_id))
+                            .order_by(access_token::Column::Id, sea_orm::Order::Desc)
+                            .limit(10000)
+                            .offset(max_access_token_count as u64)
+                            .to_owned(),
+                    ),
+                )
+                .exec(&tx)
+                .await?;
+            tx.commit().await?;
+            Ok(())
+        })
+        .await
+    }
+
+    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
+        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+        enum QueryAs {
+            Hash,
+        }
+
+        self.transact(|tx| async move {
+            Ok(access_token::Entity::find()
+                .select_only()
+                .column(access_token::Column::Hash)
+                .filter(access_token::Column::UserId.eq(user_id))
+                .order_by_desc(access_token::Column::Id)
+                .into_values::<_, QueryAs>()
+                .all(&tx)
+                .await?)
+        })
+        .await
+    }
+
     async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
     where
         F: Send + Fn(DatabaseTransaction) -> Fut,
@@ -344,6 +402,16 @@ impl Database {
         let body = async {
             loop {
                 let tx = self.pool.begin().await?;
+
+                // In Postgres, serializable transactions are opt-in
+                if let sea_orm::DatabaseBackend::Postgres = self.pool.get_database_backend() {
+                    tx.execute(sea_orm::Statement::from_string(
+                        sea_orm::DatabaseBackend::Postgres,
+                        "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(),
+                    ))
+                    .await?;
+                }
+
                 match f(tx).await {
                     Ok(result) => return Ok(result),
                     Err(error) => match error {
@@ -544,6 +612,7 @@ macro_rules! id_type {
     };
 }
 
+id_type!(AccessTokenId);
 id_type!(UserId);
 id_type!(RoomId);
 id_type!(RoomParticipantId);

crates/collab/src/db2/access_token.rs 🔗

@@ -0,0 +1,29 @@
+use super::{AccessTokenId, UserId};
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "access_tokens")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: AccessTokenId,
+    pub user_id: UserId,
+    pub hash: String,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+    #[sea_orm(
+        belongs_to = "super::user::Entity",
+        from = "Column::UserId",
+        to = "super::user::Column::Id"
+    )]
+    User,
+}
+
+impl Related<super::user::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::User.def()
+    }
+}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/db2/tests.rs 🔗

@@ -146,51 +146,51 @@ test_both_dbs!(
     }
 );
 
-// test_both_dbs!(
-//     test_create_access_tokens_postgres,
-//     test_create_access_tokens_sqlite,
-//     db,
-//     {
-//         let user = db
-//             .create_user(
-//                 "u1@example.com",
-//                 false,
-//                 NewUserParams {
-//                     github_login: "u1".into(),
-//                     github_user_id: 1,
-//                     invite_count: 0,
-//                 },
-//             )
-//             .await
-//             .unwrap()
-//             .user_id;
+test_both_dbs!(
+    test_create_access_tokens_postgres,
+    test_create_access_tokens_sqlite,
+    db,
+    {
+        let user = db
+            .create_user(
+                "u1@example.com",
+                false,
+                NewUserParams {
+                    github_login: "u1".into(),
+                    github_user_id: 1,
+                    invite_count: 0,
+                },
+            )
+            .await
+            .unwrap()
+            .user_id;
 
-//         db.create_access_token_hash(user, "h1", 3).await.unwrap();
-//         db.create_access_token_hash(user, "h2", 3).await.unwrap();
-//         assert_eq!(
-//             db.get_access_token_hashes(user).await.unwrap(),
-//             &["h2".to_string(), "h1".to_string()]
-//         );
+        db.create_access_token_hash(user, "h1", 3).await.unwrap();
+        db.create_access_token_hash(user, "h2", 3).await.unwrap();
+        assert_eq!(
+            db.get_access_token_hashes(user).await.unwrap(),
+            &["h2".to_string(), "h1".to_string()]
+        );
 
-//         db.create_access_token_hash(user, "h3", 3).await.unwrap();
-//         assert_eq!(
-//             db.get_access_token_hashes(user).await.unwrap(),
-//             &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
-//         );
+        db.create_access_token_hash(user, "h3", 3).await.unwrap();
+        assert_eq!(
+            db.get_access_token_hashes(user).await.unwrap(),
+            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
+        );
 
-//         db.create_access_token_hash(user, "h4", 3).await.unwrap();
-//         assert_eq!(
-//             db.get_access_token_hashes(user).await.unwrap(),
-//             &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
-//         );
+        db.create_access_token_hash(user, "h4", 3).await.unwrap();
+        assert_eq!(
+            db.get_access_token_hashes(user).await.unwrap(),
+            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
+        );
 
-//         db.create_access_token_hash(user, "h5", 3).await.unwrap();
-//         assert_eq!(
-//             db.get_access_token_hashes(user).await.unwrap(),
-//             &["h5".to_string(), "h4".to_string(), "h3".to_string()]
-//         );
-//     }
-// );
+        db.create_access_token_hash(user, "h5", 3).await.unwrap();
+        assert_eq!(
+            db.get_access_token_hashes(user).await.unwrap(),
+            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
+        );
+    }
+);
 
 // test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, {
 //     let mut user_ids = Vec::new();

crates/collab/src/db2/user.rs 🔗

@@ -17,6 +17,15 @@ pub struct Model {
 }
 
 #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
-pub enum Relation {}
+pub enum Relation {
+    #[sea_orm(has_many = "super::access_token::Entity")]
+    AccessToken,
+}
+
+impl Related<super::access_token::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::AccessToken.def()
+    }
+}
 
 impl ActiveModelBehavior for ActiveModel {}