@@ -21,8 +21,9 @@ macro_rules! test_support {
}};
}
+#[derive(Clone)]
pub struct Db {
- db: sqlx::PgPool,
+ pool: sqlx::PgPool,
test_mode: bool,
}
@@ -57,13 +58,13 @@ pub struct ChannelMessage {
impl Db {
pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
- let db = DbOptions::new()
+ let pool = DbOptions::new()
.max_connections(max_connections)
.connect(url)
.await
.context("failed to connect to postgres database")?;
Ok(Self {
- db,
+ pool,
test_mode: false,
})
}
@@ -86,7 +87,7 @@ impl Db {
.bind(github_login)
.bind(email_address)
.bind(about)
- .fetch_one(&self.db)
+ .fetch_one(&self.pool)
.await
.map(SignupId)
})
@@ -95,7 +96,7 @@ impl Db {
pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
test_support!(self, {
let query = "SELECT * FROM users ORDER BY github_login ASC";
- sqlx::query_as(query).fetch_all(&self.db).await
+ sqlx::query_as(query).fetch_all(&self.pool).await
})
}
@@ -104,7 +105,7 @@ impl Db {
let query = "DELETE FROM signups WHERE id = $1";
sqlx::query(query)
.bind(id.0)
- .execute(&self.db)
+ .execute(&self.pool)
.await
.map(drop)
})
@@ -122,7 +123,7 @@ impl Db {
sqlx::query_scalar(query)
.bind(github_login)
.bind(admin)
- .fetch_one(&self.db)
+ .fetch_one(&self.pool)
.await
.map(UserId)
})
@@ -131,7 +132,7 @@ impl Db {
pub async fn get_all_users(&self) -> Result<Vec<User>> {
test_support!(self, {
let query = "SELECT * FROM users ORDER BY github_login ASC";
- sqlx::query_as(query).fetch_all(&self.db).await
+ sqlx::query_as(query).fetch_all(&self.pool).await
})
}
@@ -159,7 +160,7 @@ impl Db {
sqlx::query_as(query)
.bind(&ids.map(|id| id.0).collect::<Vec<_>>())
.bind(requester_id)
- .fetch_all(&self.db)
+ .fetch_all(&self.pool)
.await
})
}
@@ -169,7 +170,7 @@ impl Db {
let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
sqlx::query_as(query)
.bind(github_login)
- .fetch_optional(&self.db)
+ .fetch_optional(&self.pool)
.await
})
}
@@ -180,7 +181,7 @@ impl Db {
sqlx::query(query)
.bind(is_admin)
.bind(id.0)
- .execute(&self.db)
+ .execute(&self.pool)
.await
.map(drop)
})
@@ -191,7 +192,7 @@ impl Db {
let query = "DELETE FROM users WHERE id = $1;";
sqlx::query(query)
.bind(id.0)
- .execute(&self.db)
+ .execute(&self.pool)
.await
.map(drop)
})
@@ -212,7 +213,7 @@ impl Db {
sqlx::query(query)
.bind(user_id.0)
.bind(access_token_hash)
- .execute(&self.db)
+ .execute(&self.pool)
.await
.map(drop)
})
@@ -223,7 +224,7 @@ impl Db {
let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
sqlx::query_scalar(query)
.bind(user_id.0)
- .fetch_all(&self.db)
+ .fetch_all(&self.pool)
.await
})
}
@@ -241,7 +242,7 @@ impl Db {
sqlx::query_scalar(query)
.bind(name)
.bind(slug)
- .fetch_one(&self.db)
+ .fetch_one(&self.pool)
.await
.map(OrgId)
})
@@ -263,7 +264,7 @@ impl Db {
.bind(org_id.0)
.bind(user_id.0)
.bind(is_admin)
- .execute(&self.db)
+ .execute(&self.pool)
.await
.map(drop)
})
@@ -282,7 +283,7 @@ impl Db {
sqlx::query_scalar(query)
.bind(org_id.0)
.bind(name)
- .fetch_one(&self.db)
+ .fetch_one(&self.pool)
.await
.map(ChannelId)
})
@@ -301,7 +302,7 @@ impl Db {
";
sqlx::query_as(query)
.bind(user_id.0)
- .fetch_all(&self.db)
+ .fetch_all(&self.pool)
.await
})
}
@@ -321,7 +322,7 @@ impl Db {
sqlx::query_scalar::<_, i32>(query)
.bind(user_id.0)
.bind(channel_id.0)
- .fetch_optional(&self.db)
+ .fetch_optional(&self.pool)
.await
.map(|e| e.is_some())
})
@@ -343,7 +344,7 @@ impl Db {
.bind(channel_id.0)
.bind(user_id.0)
.bind(is_admin)
- .execute(&self.db)
+ .execute(&self.pool)
.await
.map(drop)
})
@@ -369,7 +370,7 @@ impl Db {
.bind(sender_id.0)
.bind(body)
.bind(timestamp)
- .fetch_one(&self.db)
+ .fetch_one(&self.pool)
.await
.map(MessageId)
})
@@ -382,36 +383,23 @@ impl Db {
) -> Result<Vec<ChannelMessage>> {
test_support!(self, {
let query = r#"
- SELECT
- id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
- FROM
- channel_messages
- WHERE
- channel_id = $1
- LIMIT $2
+ SELECT * FROM (
+ SELECT
+ id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
+ FROM
+ channel_messages
+ WHERE
+ channel_id = $1
+ ORDER BY id DESC
+ LIMIT $2
+ ) as recent_messages
+ ORDER BY id ASC
"#;
sqlx::query_as(query)
.bind(channel_id.0)
.bind(count as i64)
- .fetch_all(&self.db)
- .await
- })
- }
-
- #[cfg(test)]
- pub async fn close(&self, db_name: &str) {
- test_support!(self, {
- let query = "
- SELECT pg_terminate_backend(pg_stat_activity.pid)
- FROM pg_stat_activity
- WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
- ";
- sqlx::query(query)
- .bind(db_name)
- .execute(&self.db)
+ .fetch_all(&self.pool)
.await
- .unwrap();
- self.db.close().await;
})
}
}
@@ -454,12 +442,13 @@ pub mod tests {
use std::path::Path;
pub struct TestDb {
+ pub db: Db,
pub name: String,
pub url: String,
}
impl TestDb {
- pub fn new() -> (Self, Db) {
+ pub fn new() -> Self {
// Enable tests to run in parallel by serializing the creation of each test database.
lazy_static::lazy_static! {
static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
@@ -479,17 +468,54 @@ pub mod tests {
let mut db = Db::new(&url, 5).await.unwrap();
db.test_mode = true;
let migrator = Migrator::new(migrations_path).await.unwrap();
- migrator.run(&db.db).await.unwrap();
+ migrator.run(&db.pool).await.unwrap();
db
});
- (Self { name, url }, db)
+ Self { db, name, url }
+ }
+
+ pub fn db(&self) -> &Db {
+ &self.db
}
}
impl Drop for TestDb {
fn drop(&mut self) {
- block_on(Postgres::drop_database(&self.url)).unwrap();
+ block_on(async {
+ let query = "
+ SELECT pg_terminate_backend(pg_stat_activity.pid)
+ FROM pg_stat_activity
+ WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
+ ";
+ sqlx::query(query)
+ .bind(&self.name)
+ .execute(&self.db.pool)
+ .await
+ .unwrap();
+ self.db.pool.close().await;
+ Postgres::drop_database(&self.url).await.unwrap();
+ });
}
}
+
+ #[gpui::test]
+ async fn test_recent_channel_messages() {
+ let test_db = TestDb::new();
+ let db = test_db.db();
+ let user = db.create_user("user", false).await.unwrap();
+ let org = db.create_org("org", "org").await.unwrap();
+ let channel = db.create_org_channel(org, "channel").await.unwrap();
+ for i in 0..10 {
+ db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc())
+ .await
+ .unwrap();
+ }
+
+ let messages = db.get_recent_channel_messages(channel, 5).await.unwrap();
+ assert_eq!(
+ messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
+ ["5", "6", "7", "8", "9"]
+ );
+ }
}
@@ -919,7 +919,7 @@ mod tests {
use super::*;
use crate::{
auth,
- db::{tests::TestDb, Db, UserId},
+ db::{tests::TestDb, UserId},
github, AppState, Config,
};
use async_std::{sync::RwLockReadGuard, task};
@@ -1529,14 +1529,14 @@ mod tests {
peer: Arc<Peer>,
app_state: Arc<AppState>,
server: Arc<Server>,
- test_db: TestDb,
notifications: mpsc::Receiver<()>,
+ _test_db: TestDb,
}
impl TestServer {
async fn start() -> Self {
- let (test_db, db) = TestDb::new();
- let app_state = Self::build_app_state(&test_db, db).await;
+ let test_db = TestDb::new();
+ let app_state = Self::build_app_state(&test_db).await;
let peer = Peer::new();
let notifications = mpsc::channel(128);
let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
@@ -1544,8 +1544,8 @@ mod tests {
peer,
app_state,
server,
- test_db,
notifications: notifications.1,
+ _test_db: test_db,
}
}
@@ -1570,13 +1570,13 @@ mod tests {
(user_id, client)
}
- async fn build_app_state(test_db: &TestDb, db: Db) -> Arc<AppState> {
+ async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
let mut config = Config::default();
config.session_secret = "a".repeat(32);
config.database_url = test_db.url.clone();
let github_client = github::AppClient::test();
Arc::new(AppState {
- db,
+ db: test_db.db().clone(),
handlebars: Default::default(),
auth_client: auth::build_client("", ""),
repo_client: github::RepoClient::test(&github_client),
@@ -1605,10 +1605,7 @@ mod tests {
impl Drop for TestServer {
fn drop(&mut self) {
- task::block_on(async {
- self.peer.reset().await;
- self.app_state.db.close(&self.test_db.name).await;
- });
+ task::block_on(self.peer.reset());
}
}