@@ -13,11 +13,11 @@ use collections::HashMap;
use dashmap::DashMap;
use futures::StreamExt;
use rpc::{proto, ConnectionId};
-use sea_orm::ActiveValue;
use sea_orm::{
entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
TransactionTrait,
};
+use sea_orm::{ActiveValue, IntoActiveModel};
use sea_query::OnConflict;
use serde::{Deserialize, Serialize};
use sqlx::migrate::{Migrate, Migration, MigrationSource};
@@ -31,7 +31,7 @@ use tokio::sync::{Mutex, OwnedMutexGuard};
pub use user::Model as User;
pub struct Database {
- url: String,
+ options: ConnectOptions,
pool: DatabaseConnection,
rooms: DashMap<RoomId, Arc<Mutex<()>>>,
#[cfg(test)]
@@ -41,11 +41,9 @@ pub struct Database {
}
impl Database {
- pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
- let mut options = ConnectOptions::new(url.into());
- options.max_connections(max_connections);
+ pub async fn new(options: ConnectOptions) -> Result<Self> {
Ok(Self {
- url: url.into(),
+ options: options.clone(),
pool: sea_orm::Database::connect(options).await?,
rooms: DashMap::with_capacity(16384),
#[cfg(test)]
@@ -59,12 +57,12 @@ impl Database {
&self,
migrations_path: &Path,
ignore_checksum_mismatch: bool,
- ) -> anyhow::Result<(sqlx::AnyConnection, Vec<(Migration, Duration)>)> {
+ ) -> anyhow::Result<Vec<(Migration, Duration)>> {
let migrations = MigrationSource::resolve(migrations_path)
.await
.map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
- let mut connection = sqlx::AnyConnection::connect(&self.url).await?;
+ let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
connection.ensure_migrations_table().await?;
let applied_migrations: HashMap<_, _> = connection
@@ -93,7 +91,7 @@ impl Database {
}
}
- Ok((connection, new_migrations))
+ Ok(new_migrations)
}
pub async fn create_user(
@@ -142,6 +140,43 @@ impl Database {
.await
}
+ pub async fn get_user_by_github_account(
+ &self,
+ github_login: &str,
+ github_user_id: Option<i32>,
+ ) -> Result<Option<User>> {
+ self.transact(|tx| async {
+ let tx = tx;
+ if let Some(github_user_id) = github_user_id {
+ if let Some(user_by_github_user_id) = user::Entity::find()
+ .filter(user::Column::GithubUserId.eq(github_user_id))
+ .one(&tx)
+ .await?
+ {
+ let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
+ user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
+ Ok(Some(user_by_github_user_id.update(&tx).await?))
+ } else if let Some(user_by_github_login) = user::Entity::find()
+ .filter(user::Column::GithubLogin.eq(github_login))
+ .one(&tx)
+ .await?
+ {
+ let mut user_by_github_login = user_by_github_login.into_active_model();
+ user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
+ Ok(Some(user_by_github_login.update(&tx).await?))
+ } else {
+ Ok(None)
+ }
+ } else {
+ Ok(user::Entity::find()
+ .filter(user::Column::GithubLogin.eq(github_login))
+ .one(&tx)
+ .await?)
+ }
+ })
+ .await
+ }
+
pub async fn share_project(
&self,
room_id: RoomId,
@@ -545,7 +580,9 @@ mod test {
.unwrap();
let mut db = runtime.block_on(async {
- let db = Database::new(&url, 5).await.unwrap();
+ let mut options = ConnectOptions::new(url);
+ options.max_connections(5);
+ let db = Database::new(options).await.unwrap();
let sql = include_str!(concat!(
env!("CARGO_MANIFEST_DIR"),
"/migrations.sqlite/20221109000000_test_schema.sql"
@@ -590,7 +627,11 @@ mod test {
sqlx::Postgres::create_database(&url)
.await
.expect("failed to create test db");
- let db = Database::new(&url, 5).await.unwrap();
+ let mut options = ConnectOptions::new(url);
+ options
+ .max_connections(5)
+ .idle_timeout(Duration::from_secs(0));
+ let db = Database::new(options).await.unwrap();
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
db.migrate(Path::new(migrations_path), false).await.unwrap();
db
@@ -610,11 +651,31 @@ mod test {
}
}
- // TODO: Implement drop
- // impl Drop for PostgresTestDb {
- // fn drop(&mut self) {
- // let db = self.db.take().unwrap();
- // db.teardown(&self.url);
- // }
- // }
+ impl Drop for TestDb {
+ fn drop(&mut self) {
+ let db = self.db.take().unwrap();
+ if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
+ db.runtime.as_ref().unwrap().block_on(async {
+ use util::ResultExt;
+ let query = "
+ SELECT pg_terminate_backend(pg_stat_activity.pid)
+ FROM pg_stat_activity
+ WHERE
+ pg_stat_activity.datname = current_database() AND
+ pid <> pg_backend_pid();
+ ";
+ db.pool
+ .execute(sea_orm::Statement::from_string(
+ db.pool.get_database_backend(),
+ query.into(),
+ ))
+ .await
+ .log_err();
+ sqlx::Postgres::drop_database(db.options.get_url())
+ .await
+ .log_err();
+ })
+ }
+ }
+ }
}
@@ -88,63 +88,63 @@ test_both_dbs!(
}
);
-// test_both_dbs!(
-// test_get_user_by_github_account_postgres,
-// test_get_user_by_github_account_sqlite,
-// db,
-// {
-// let user_id1 = db
-// .create_user(
-// "user1@example.com",
-// false,
-// NewUserParams {
-// github_login: "login1".into(),
-// github_user_id: 101,
-// invite_count: 0,
-// },
-// )
-// .await
-// .unwrap()
-// .user_id;
-// let user_id2 = db
-// .create_user(
-// "user2@example.com",
-// false,
-// NewUserParams {
-// github_login: "login2".into(),
-// github_user_id: 102,
-// invite_count: 0,
-// },
-// )
-// .await
-// .unwrap()
-// .user_id;
-
-// let user = db
-// .get_user_by_github_account("login1", None)
-// .await
-// .unwrap()
-// .unwrap();
-// assert_eq!(user.id, user_id1);
-// assert_eq!(&user.github_login, "login1");
-// assert_eq!(user.github_user_id, Some(101));
-
-// assert!(db
-// .get_user_by_github_account("non-existent-login", None)
-// .await
-// .unwrap()
-// .is_none());
-
-// let user = db
-// .get_user_by_github_account("the-new-login2", Some(102))
-// .await
-// .unwrap()
-// .unwrap();
-// assert_eq!(user.id, user_id2);
-// assert_eq!(&user.github_login, "the-new-login2");
-// assert_eq!(user.github_user_id, Some(102));
-// }
-// );
+test_both_dbs!(
+ test_get_user_by_github_account_postgres,
+ test_get_user_by_github_account_sqlite,
+ db,
+ {
+ let user_id1 = db
+ .create_user(
+ "user1@example.com",
+ false,
+ NewUserParams {
+ github_login: "login1".into(),
+ github_user_id: 101,
+ invite_count: 0,
+ },
+ )
+ .await
+ .unwrap()
+ .user_id;
+ let user_id2 = db
+ .create_user(
+ "user2@example.com",
+ false,
+ NewUserParams {
+ github_login: "login2".into(),
+ github_user_id: 102,
+ invite_count: 0,
+ },
+ )
+ .await
+ .unwrap()
+ .user_id;
+
+ let user = db
+ .get_user_by_github_account("login1", None)
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(user.id, user_id1);
+ assert_eq!(&user.github_login, "login1");
+ assert_eq!(user.github_user_id, Some(101));
+
+ assert!(db
+ .get_user_by_github_account("non-existent-login", None)
+ .await
+ .unwrap()
+ .is_none());
+
+ let user = db
+ .get_user_by_github_account("the-new-login2", Some(102))
+ .await
+ .unwrap()
+ .unwrap();
+ assert_eq!(user.id, user_id2);
+ assert_eq!(&user.github_login, "the-new-login2");
+ assert_eq!(user.github_user_id, Some(102));
+ }
+);
// test_both_dbs!(
// test_create_access_tokens_postgres,