Implement `db2::Database::get_user_by_github_account`

Antonio Scandurra created

Change summary

crates/collab/src/db2.rs       |  97 ++++++++++++++++++++++++-----
crates/collab/src/db2/tests.rs | 114 ++++++++++++++++++------------------
2 files changed, 136 insertions(+), 75 deletions(-)

Detailed changes

crates/collab/src/db2.rs 🔗

@@ -13,11 +13,11 @@ use collections::HashMap;
 use dashmap::DashMap;
 use futures::StreamExt;
 use rpc::{proto, ConnectionId};
-use sea_orm::ActiveValue;
 use sea_orm::{
     entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
     TransactionTrait,
 };
+use sea_orm::{ActiveValue, IntoActiveModel};
 use sea_query::OnConflict;
 use serde::{Deserialize, Serialize};
 use sqlx::migrate::{Migrate, Migration, MigrationSource};
@@ -31,7 +31,7 @@ use tokio::sync::{Mutex, OwnedMutexGuard};
 pub use user::Model as User;
 
 pub struct Database {
-    url: String,
+    options: ConnectOptions,
     pool: DatabaseConnection,
     rooms: DashMap<RoomId, Arc<Mutex<()>>>,
     #[cfg(test)]
@@ -41,11 +41,9 @@ pub struct Database {
 }
 
 impl Database {
-    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
-        let mut options = ConnectOptions::new(url.into());
-        options.max_connections(max_connections);
+    pub async fn new(options: ConnectOptions) -> Result<Self> {
         Ok(Self {
-            url: url.into(),
+            options: options.clone(),
             pool: sea_orm::Database::connect(options).await?,
             rooms: DashMap::with_capacity(16384),
             #[cfg(test)]
@@ -59,12 +57,12 @@ impl Database {
         &self,
         migrations_path: &Path,
         ignore_checksum_mismatch: bool,
-    ) -> anyhow::Result<(sqlx::AnyConnection, Vec<(Migration, Duration)>)> {
+    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
         let migrations = MigrationSource::resolve(migrations_path)
             .await
             .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 
-        let mut connection = sqlx::AnyConnection::connect(&self.url).await?;
+        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
 
         connection.ensure_migrations_table().await?;
         let applied_migrations: HashMap<_, _> = connection
@@ -93,7 +91,7 @@ impl Database {
             }
         }
 
-        Ok((connection, new_migrations))
+        Ok(new_migrations)
     }
 
     pub async fn create_user(
@@ -142,6 +140,43 @@ impl Database {
         .await
     }
 
+    pub async fn get_user_by_github_account(
+        &self,
+        github_login: &str,
+        github_user_id: Option<i32>,
+    ) -> Result<Option<User>> {
+        self.transact(|tx| async {
+            let tx = tx;
+            if let Some(github_user_id) = github_user_id {
+                if let Some(user_by_github_user_id) = user::Entity::find()
+                    .filter(user::Column::GithubUserId.eq(github_user_id))
+                    .one(&tx)
+                    .await?
+                {
+                    let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
+                    user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
+                    Ok(Some(user_by_github_user_id.update(&tx).await?))
+                } else if let Some(user_by_github_login) = user::Entity::find()
+                    .filter(user::Column::GithubLogin.eq(github_login))
+                    .one(&tx)
+                    .await?
+                {
+                    let mut user_by_github_login = user_by_github_login.into_active_model();
+                    user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
+                    Ok(Some(user_by_github_login.update(&tx).await?))
+                } else {
+                    Ok(None)
+                }
+            } else {
+                Ok(user::Entity::find()
+                    .filter(user::Column::GithubLogin.eq(github_login))
+                    .one(&tx)
+                    .await?)
+            }
+        })
+        .await
+    }
+
     pub async fn share_project(
         &self,
         room_id: RoomId,
@@ -545,7 +580,9 @@ mod test {
                 .unwrap();
 
             let mut db = runtime.block_on(async {
-                let db = Database::new(&url, 5).await.unwrap();
+                let mut options = ConnectOptions::new(url);
+                options.max_connections(5);
+                let db = Database::new(options).await.unwrap();
                 let sql = include_str!(concat!(
                     env!("CARGO_MANIFEST_DIR"),
                     "/migrations.sqlite/20221109000000_test_schema.sql"
@@ -590,7 +627,11 @@ mod test {
                 sqlx::Postgres::create_database(&url)
                     .await
                     .expect("failed to create test db");
-                let db = Database::new(&url, 5).await.unwrap();
+                let mut options = ConnectOptions::new(url);
+                options
+                    .max_connections(5)
+                    .idle_timeout(Duration::from_secs(0));
+                let db = Database::new(options).await.unwrap();
                 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                 db.migrate(Path::new(migrations_path), false).await.unwrap();
                 db
@@ -610,11 +651,31 @@ mod test {
         }
     }
 
-    // TODO: Implement drop
-    // impl Drop for PostgresTestDb {
-    //     fn drop(&mut self) {
-    //         let db = self.db.take().unwrap();
-    //         db.teardown(&self.url);
-    //     }
-    // }
+    impl Drop for TestDb {
+        fn drop(&mut self) {
+            let db = self.db.take().unwrap();
+            if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
+                db.runtime.as_ref().unwrap().block_on(async {
+                    use util::ResultExt;
+                    let query = "
+                        SELECT pg_terminate_backend(pg_stat_activity.pid)
+                        FROM pg_stat_activity
+                        WHERE
+                            pg_stat_activity.datname = current_database() AND
+                            pid <> pg_backend_pid();
+                    ";
+                    db.pool
+                        .execute(sea_orm::Statement::from_string(
+                            db.pool.get_database_backend(),
+                            query.into(),
+                        ))
+                        .await
+                        .log_err();
+                    sqlx::Postgres::drop_database(db.options.get_url())
+                        .await
+                        .log_err();
+                })
+            }
+        }
+    }
 }

crates/collab/src/db2/tests.rs 🔗

@@ -88,63 +88,63 @@ test_both_dbs!(
     }
 );
 
-// test_both_dbs!(
-//     test_get_user_by_github_account_postgres,
-//     test_get_user_by_github_account_sqlite,
-//     db,
-//     {
-//         let user_id1 = db
-//             .create_user(
-//                 "user1@example.com",
-//                 false,
-//                 NewUserParams {
-//                     github_login: "login1".into(),
-//                     github_user_id: 101,
-//                     invite_count: 0,
-//                 },
-//             )
-//             .await
-//             .unwrap()
-//             .user_id;
-//         let user_id2 = db
-//             .create_user(
-//                 "user2@example.com",
-//                 false,
-//                 NewUserParams {
-//                     github_login: "login2".into(),
-//                     github_user_id: 102,
-//                     invite_count: 0,
-//                 },
-//             )
-//             .await
-//             .unwrap()
-//             .user_id;
-
-//         let user = db
-//             .get_user_by_github_account("login1", None)
-//             .await
-//             .unwrap()
-//             .unwrap();
-//         assert_eq!(user.id, user_id1);
-//         assert_eq!(&user.github_login, "login1");
-//         assert_eq!(user.github_user_id, Some(101));
-
-//         assert!(db
-//             .get_user_by_github_account("non-existent-login", None)
-//             .await
-//             .unwrap()
-//             .is_none());
-
-//         let user = db
-//             .get_user_by_github_account("the-new-login2", Some(102))
-//             .await
-//             .unwrap()
-//             .unwrap();
-//         assert_eq!(user.id, user_id2);
-//         assert_eq!(&user.github_login, "the-new-login2");
-//         assert_eq!(user.github_user_id, Some(102));
-//     }
-// );
+test_both_dbs!(
+    test_get_user_by_github_account_postgres,
+    test_get_user_by_github_account_sqlite,
+    db,
+    {
+        let user_id1 = db
+            .create_user(
+                "user1@example.com",
+                false,
+                NewUserParams {
+                    github_login: "login1".into(),
+                    github_user_id: 101,
+                    invite_count: 0,
+                },
+            )
+            .await
+            .unwrap()
+            .user_id;
+        let user_id2 = db
+            .create_user(
+                "user2@example.com",
+                false,
+                NewUserParams {
+                    github_login: "login2".into(),
+                    github_user_id: 102,
+                    invite_count: 0,
+                },
+            )
+            .await
+            .unwrap()
+            .user_id;
+
+        let user = db
+            .get_user_by_github_account("login1", None)
+            .await
+            .unwrap()
+            .unwrap();
+        assert_eq!(user.id, user_id1);
+        assert_eq!(&user.github_login, "login1");
+        assert_eq!(user.github_user_id, Some(101));
+
+        assert!(db
+            .get_user_by_github_account("non-existent-login", None)
+            .await
+            .unwrap()
+            .is_none());
+
+        let user = db
+            .get_user_by_github_account("the-new-login2", Some(102))
+            .await
+            .unwrap()
+            .unwrap();
+        assert_eq!(user.id, user_id2);
+        assert_eq!(&user.github_login, "the-new-login2");
+        assert_eq!(user.github_user_id, Some(102));
+    }
+);
 
 // test_both_dbs!(
 //     test_create_access_tokens_postgres,