From c865f8ad1a26e1e8c25ddb723e49a37f6486f537 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 26 Aug 2021 14:14:22 +0200 Subject: [PATCH] Fix retrieving recent channel messages --- server/src/db.rs | 126 ++++++++++++++++++++++++++++------------------ server/src/rpc.rs | 19 +++---- 2 files changed, 84 insertions(+), 61 deletions(-) diff --git a/server/src/db.rs b/server/src/db.rs index 7c063e140f3f7b5d5553a73249639562ebaab209..2f1cbc5fba181457d823aa1a5a98c7bc0cbafc67 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -21,8 +21,9 @@ macro_rules! test_support { }}; } +#[derive(Clone)] pub struct Db { - db: sqlx::PgPool, + pool: sqlx::PgPool, test_mode: bool, } @@ -57,13 +58,13 @@ pub struct ChannelMessage { impl Db { pub async fn new(url: &str, max_connections: u32) -> tide::Result { - let db = DbOptions::new() + let pool = DbOptions::new() .max_connections(max_connections) .connect(url) .await .context("failed to connect to postgres database")?; Ok(Self { - db, + pool, test_mode: false, }) } @@ -86,7 +87,7 @@ impl Db { .bind(github_login) .bind(email_address) .bind(about) - .fetch_one(&self.db) + .fetch_one(&self.pool) .await .map(SignupId) }) @@ -95,7 +96,7 @@ impl Db { pub async fn get_all_signups(&self) -> Result> { test_support!(self, { let query = "SELECT * FROM users ORDER BY github_login ASC"; - sqlx::query_as(query).fetch_all(&self.db).await + sqlx::query_as(query).fetch_all(&self.pool).await }) } @@ -104,7 +105,7 @@ impl Db { let query = "DELETE FROM signups WHERE id = $1"; sqlx::query(query) .bind(id.0) - .execute(&self.db) + .execute(&self.pool) .await .map(drop) }) @@ -122,7 +123,7 @@ impl Db { sqlx::query_scalar(query) .bind(github_login) .bind(admin) - .fetch_one(&self.db) + .fetch_one(&self.pool) .await .map(UserId) }) @@ -131,7 +132,7 @@ impl Db { pub async fn get_all_users(&self) -> Result> { test_support!(self, { let query = "SELECT * FROM users ORDER BY github_login ASC"; - sqlx::query_as(query).fetch_all(&self.db).await + sqlx::query_as(query).fetch_all(&self.pool).await }) } @@ -159,7 +160,7 @@ impl Db { sqlx::query_as(query) .bind(&ids.map(|id| id.0).collect::>()) .bind(requester_id) - .fetch_all(&self.db) + .fetch_all(&self.pool) .await }) } @@ -169,7 +170,7 @@ impl Db { let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1"; sqlx::query_as(query) .bind(github_login) - .fetch_optional(&self.db) + .fetch_optional(&self.pool) .await }) } @@ -180,7 +181,7 @@ impl Db { sqlx::query(query) .bind(is_admin) .bind(id.0) - .execute(&self.db) + .execute(&self.pool) .await .map(drop) }) @@ -191,7 +192,7 @@ impl Db { let query = "DELETE FROM users WHERE id = $1;"; sqlx::query(query) .bind(id.0) - .execute(&self.db) + .execute(&self.pool) .await .map(drop) }) @@ -212,7 +213,7 @@ impl Db { sqlx::query(query) .bind(user_id.0) .bind(access_token_hash) - .execute(&self.db) + .execute(&self.pool) .await .map(drop) }) @@ -223,7 +224,7 @@ impl Db { let query = "SELECT hash FROM access_tokens WHERE user_id = $1"; sqlx::query_scalar(query) .bind(user_id.0) - .fetch_all(&self.db) + .fetch_all(&self.pool) .await }) } @@ -241,7 +242,7 @@ impl Db { sqlx::query_scalar(query) .bind(name) .bind(slug) - .fetch_one(&self.db) + .fetch_one(&self.pool) .await .map(OrgId) }) @@ -263,7 +264,7 @@ impl Db { .bind(org_id.0) .bind(user_id.0) .bind(is_admin) - .execute(&self.db) + .execute(&self.pool) .await .map(drop) }) @@ -282,7 +283,7 @@ impl Db { sqlx::query_scalar(query) .bind(org_id.0) .bind(name) - .fetch_one(&self.db) + .fetch_one(&self.pool) .await .map(ChannelId) }) @@ -301,7 +302,7 @@ impl Db { "; sqlx::query_as(query) .bind(user_id.0) - .fetch_all(&self.db) + .fetch_all(&self.pool) .await }) } @@ -321,7 +322,7 @@ impl Db { sqlx::query_scalar::<_, i32>(query) .bind(user_id.0) .bind(channel_id.0) - .fetch_optional(&self.db) + .fetch_optional(&self.pool) .await .map(|e| e.is_some()) }) @@ -343,7 +344,7 @@ impl Db { .bind(channel_id.0) .bind(user_id.0) .bind(is_admin) - .execute(&self.db) + .execute(&self.pool) .await .map(drop) }) @@ -369,7 +370,7 @@ impl Db { .bind(sender_id.0) .bind(body) .bind(timestamp) - .fetch_one(&self.db) + .fetch_one(&self.pool) .await .map(MessageId) }) @@ -382,36 +383,23 @@ impl Db { ) -> Result> { test_support!(self, { let query = r#" - SELECT - id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at - FROM - channel_messages - WHERE - channel_id = $1 - LIMIT $2 + SELECT * FROM ( + SELECT + id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at + FROM + channel_messages + WHERE + channel_id = $1 + ORDER BY id DESC + LIMIT $2 + ) as recent_messages + ORDER BY id ASC "#; sqlx::query_as(query) .bind(channel_id.0) .bind(count as i64) - .fetch_all(&self.db) - .await - }) - } - - #[cfg(test)] - pub async fn close(&self, db_name: &str) { - test_support!(self, { - let query = " - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid(); - "; - sqlx::query(query) - .bind(db_name) - .execute(&self.db) + .fetch_all(&self.pool) .await - .unwrap(); - self.db.close().await; }) } } @@ -454,12 +442,13 @@ pub mod tests { use std::path::Path; pub struct TestDb { + pub db: Db, pub name: String, pub url: String, } impl TestDb { - pub fn new() -> (Self, Db) { + pub fn new() -> Self { // Enable tests to run in parallel by serializing the creation of each test database. lazy_static::lazy_static! { static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(()); @@ -479,17 +468,54 @@ pub mod tests { let mut db = Db::new(&url, 5).await.unwrap(); db.test_mode = true; let migrator = Migrator::new(migrations_path).await.unwrap(); - migrator.run(&db.db).await.unwrap(); + migrator.run(&db.pool).await.unwrap(); db }); - (Self { name, url }, db) + Self { db, name, url } + } + + pub fn db(&self) -> &Db { + &self.db } } impl Drop for TestDb { fn drop(&mut self) { - block_on(Postgres::drop_database(&self.url)).unwrap(); + block_on(async { + let query = " + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid(); + "; + sqlx::query(query) + .bind(&self.name) + .execute(&self.db.pool) + .await + .unwrap(); + self.db.pool.close().await; + Postgres::drop_database(&self.url).await.unwrap(); + }); } } + + #[gpui::test] + async fn test_recent_channel_messages() { + let test_db = TestDb::new(); + let db = test_db.db(); + let user = db.create_user("user", false).await.unwrap(); + let org = db.create_org("org", "org").await.unwrap(); + let channel = db.create_org_channel(org, "channel").await.unwrap(); + for i in 0..10 { + db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc()) + .await + .unwrap(); + } + + let messages = db.get_recent_channel_messages(channel, 5).await.unwrap(); + assert_eq!( + messages.iter().map(|m| &m.body).collect::>(), + ["5", "6", "7", "8", "9"] + ); + } } diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 0822ddbef1c42430b7d8f84937c566512e33d95c..1f329f5219071fd8daab38ee67b313f93d557557 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -919,7 +919,7 @@ mod tests { use super::*; use crate::{ auth, - db::{tests::TestDb, Db, UserId}, + db::{tests::TestDb, UserId}, github, AppState, Config, }; use async_std::{sync::RwLockReadGuard, task}; @@ -1529,14 +1529,14 @@ mod tests { peer: Arc, app_state: Arc, server: Arc, - test_db: TestDb, notifications: mpsc::Receiver<()>, + _test_db: TestDb, } impl TestServer { async fn start() -> Self { - let (test_db, db) = TestDb::new(); - let app_state = Self::build_app_state(&test_db, db).await; + let test_db = TestDb::new(); + let app_state = Self::build_app_state(&test_db).await; let peer = Peer::new(); let notifications = mpsc::channel(128); let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0)); @@ -1544,8 +1544,8 @@ mod tests { peer, app_state, server, - test_db, notifications: notifications.1, + _test_db: test_db, } } @@ -1570,13 +1570,13 @@ mod tests { (user_id, client) } - async fn build_app_state(test_db: &TestDb, db: Db) -> Arc { + async fn build_app_state(test_db: &TestDb) -> Arc { let mut config = Config::default(); config.session_secret = "a".repeat(32); config.database_url = test_db.url.clone(); let github_client = github::AppClient::test(); Arc::new(AppState { - db, + db: test_db.db().clone(), handlebars: Default::default(), auth_client: auth::build_client("", ""), repo_client: github::RepoClient::test(&github_client), @@ -1605,10 +1605,7 @@ mod tests { impl Drop for TestServer { fn drop(&mut self) { - task::block_on(async { - self.peer.reset().await; - self.app_state.db.close(&self.test_db.name).await; - }); + task::block_on(self.peer.reset()); } }