@@ -1,11 +1,12 @@
use anyhow::Context;
+use anyhow::Result;
+pub use async_sqlx_session::PostgresSessionStore as SessionStore;
use async_std::task::{block_on, yield_now};
+use async_trait::async_trait;
use serde::Serialize;
-use sqlx::{types::Uuid, FromRow, Result};
-use time::OffsetDateTime;
-
-pub use async_sqlx_session::PostgresSessionStore as SessionStore;
pub use sqlx::postgres::PgPoolOptions as DbOptions;
+use sqlx::{types::Uuid, FromRow};
+use time::OffsetDateTime;
macro_rules! test_support {
($self:ident, { $($token:tt)* }) => {{
@@ -21,13 +22,77 @@ macro_rules! test_support {
}};
}
-#[derive(Clone)]
-pub struct Db {
+#[async_trait]
+pub trait Db: Send + Sync {
+ async fn create_signup(
+ &self,
+ github_login: &str,
+ email_address: &str,
+ about: &str,
+ wants_releases: bool,
+ wants_updates: bool,
+ wants_community: bool,
+ ) -> Result<SignupId>;
+ async fn get_all_signups(&self) -> Result<Vec<Signup>>;
+ async fn destroy_signup(&self, id: SignupId) -> Result<()>;
+ async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
+ async fn get_all_users(&self) -> Result<Vec<User>>;
+ async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
+ async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
+ async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
+ async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
+ async fn destroy_user(&self, id: UserId) -> Result<()>;
+ async fn create_access_token_hash(
+ &self,
+ user_id: UserId,
+ access_token_hash: &str,
+ max_access_token_count: usize,
+ ) -> Result<()>;
+ async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
+ #[cfg(any(test, feature = "seed-support"))]
+ async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
+ #[cfg(any(test, feature = "seed-support"))]
+ async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
+ #[cfg(any(test, feature = "seed-support"))]
+ async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
+ #[cfg(any(test, feature = "seed-support"))]
+ async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
+ #[cfg(any(test, feature = "seed-support"))]
+ async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
+ async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
+ async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
+ -> Result<bool>;
+ #[cfg(any(test, feature = "seed-support"))]
+ async fn add_channel_member(
+ &self,
+ channel_id: ChannelId,
+ user_id: UserId,
+ is_admin: bool,
+ ) -> Result<()>;
+ async fn create_channel_message(
+ &self,
+ channel_id: ChannelId,
+ sender_id: UserId,
+ body: &str,
+ timestamp: OffsetDateTime,
+ nonce: u128,
+ ) -> Result<MessageId>;
+ async fn get_channel_messages(
+ &self,
+ channel_id: ChannelId,
+ count: usize,
+ before_id: Option<MessageId>,
+ ) -> Result<Vec<ChannelMessage>>;
+ #[cfg(test)]
+ async fn teardown(&self, name: &str, url: &str);
+}
+
+pub struct PostgresDb {
pool: sqlx::PgPool,
test_mode: bool,
}
-impl Db {
+impl PostgresDb {
pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
let pool = DbOptions::new()
.max_connections(max_connections)
@@ -39,10 +104,12 @@ impl Db {
test_mode: false,
})
}
+}
+#[async_trait]
+impl Db for PostgresDb {
// signups
-
- pub async fn create_signup(
+ async fn create_signup(
&self,
github_login: &str,
email_address: &str,
@@ -64,7 +131,7 @@ impl Db {
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id
";
- sqlx::query_scalar(query)
+ Ok(sqlx::query_scalar(query)
.bind(github_login)
.bind(email_address)
.bind(about)
@@ -73,31 +140,31 @@ impl Db {
.bind(wants_community)
.fetch_one(&self.pool)
.await
- .map(SignupId)
+ .map(SignupId)?)
})
}
- pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
+ async fn get_all_signups(&self) -> Result<Vec<Signup>> {
test_support!(self, {
let query = "SELECT * FROM signups ORDER BY github_login ASC";
- sqlx::query_as(query).fetch_all(&self.pool).await
+ Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
})
}
- pub async fn destroy_signup(&self, id: SignupId) -> Result<()> {
+ async fn destroy_signup(&self, id: SignupId) -> Result<()> {
test_support!(self, {
let query = "DELETE FROM signups WHERE id = $1";
- sqlx::query(query)
+ Ok(sqlx::query(query)
.bind(id.0)
.execute(&self.pool)
.await
- .map(drop)
+ .map(drop)?)
})
}
// users
- pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
+ async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
test_support!(self, {
let query = "
INSERT INTO users (github_login, admin)
@@ -105,31 +172,28 @@ impl Db {
ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
RETURNING id
";
- sqlx::query_scalar(query)
+ Ok(sqlx::query_scalar(query)
.bind(github_login)
.bind(admin)
.fetch_one(&self.pool)
.await
- .map(UserId)
+ .map(UserId)?)
})
}
- pub async fn get_all_users(&self) -> Result<Vec<User>> {
+ async fn get_all_users(&self) -> Result<Vec<User>> {
test_support!(self, {
let query = "SELECT * FROM users ORDER BY github_login ASC";
- sqlx::query_as(query).fetch_all(&self.pool).await
+ Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
})
}
- pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
- let users = self.get_users_by_ids([id]).await?;
+ async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
+ let users = self.get_users_by_ids(vec![id]).await?;
Ok(users.into_iter().next())
}
- pub async fn get_users_by_ids(
- &self,
- ids: impl IntoIterator<Item = UserId>,
- ) -> Result<Vec<User>> {
+ async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
test_support!(self, {
let query = "
@@ -138,33 +202,36 @@ impl Db {
WHERE users.id = ANY ($1)
";
- sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
+ Ok(sqlx::query_as(query)
+ .bind(&ids)
+ .fetch_all(&self.pool)
+ .await?)
})
}
- pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
+ async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
test_support!(self, {
let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
- sqlx::query_as(query)
+ Ok(sqlx::query_as(query)
.bind(github_login)
.fetch_optional(&self.pool)
- .await
+ .await?)
})
}
- pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
+ async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
test_support!(self, {
let query = "UPDATE users SET admin = $1 WHERE id = $2";
- sqlx::query(query)
+ Ok(sqlx::query(query)
.bind(is_admin)
.bind(id.0)
.execute(&self.pool)
.await
- .map(drop)
+ .map(drop)?)
})
}
- pub async fn destroy_user(&self, id: UserId) -> Result<()> {
+ async fn destroy_user(&self, id: UserId) -> Result<()> {
test_support!(self, {
let query = "DELETE FROM access_tokens WHERE user_id = $1;";
sqlx::query(query)
@@ -173,17 +240,17 @@ impl Db {
.await
.map(drop)?;
let query = "DELETE FROM users WHERE id = $1;";
- sqlx::query(query)
+ Ok(sqlx::query(query)
.bind(id.0)
.execute(&self.pool)
.await
- .map(drop)
+ .map(drop)?)
})
}
// access tokens
- pub async fn create_access_token_hash(
+ async fn create_access_token_hash(
&self,
user_id: UserId,
access_token_hash: &str,
@@ -216,11 +283,11 @@ impl Db {
.bind(max_access_token_count as u32)
.execute(&mut tx)
.await?;
- tx.commit().await
+ Ok(tx.commit().await?)
})
}
- pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
+ async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
test_support!(self, {
let query = "
SELECT hash
@@ -228,10 +295,10 @@ impl Db {
WHERE user_id = $1
ORDER BY id DESC
";
- sqlx::query_scalar(query)
+ Ok(sqlx::query_scalar(query)
.bind(user_id.0)
.fetch_all(&self.pool)
- .await
+ .await?)
})
}
@@ -239,82 +306,77 @@ impl Db {
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
- pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
+ async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
test_support!(self, {
let query = "
SELECT *
FROM orgs
WHERE slug = $1
";
- sqlx::query_as(query)
+ Ok(sqlx::query_as(query)
.bind(slug)
.fetch_optional(&self.pool)
- .await
+ .await?)
})
}
#[cfg(any(test, feature = "seed-support"))]
- pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
+ async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
test_support!(self, {
let query = "
INSERT INTO orgs (name, slug)
VALUES ($1, $2)
RETURNING id
";
- sqlx::query_scalar(query)
+ Ok(sqlx::query_scalar(query)
.bind(name)
.bind(slug)
.fetch_one(&self.pool)
.await
- .map(OrgId)
+ .map(OrgId)?)
})
}
#[cfg(any(test, feature = "seed-support"))]
- pub async fn add_org_member(
- &self,
- org_id: OrgId,
- user_id: UserId,
- is_admin: bool,
- ) -> Result<()> {
+ async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
test_support!(self, {
let query = "
INSERT INTO org_memberships (org_id, user_id, admin)
VALUES ($1, $2, $3)
ON CONFLICT DO NOTHING
";
- sqlx::query(query)
+ Ok(sqlx::query(query)
.bind(org_id.0)
.bind(user_id.0)
.bind(is_admin)
.execute(&self.pool)
.await
- .map(drop)
+ .map(drop)?)
})
}
// channels
#[cfg(any(test, feature = "seed-support"))]
- pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
+ async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
test_support!(self, {
let query = "
INSERT INTO channels (owner_id, owner_is_user, name)
VALUES ($1, false, $2)
RETURNING id
";
- sqlx::query_scalar(query)
+ Ok(sqlx::query_scalar(query)
.bind(org_id.0)
.bind(name)
.fetch_one(&self.pool)
.await
- .map(ChannelId)
+ .map(ChannelId)?)
})
}
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
- pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
+ async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
test_support!(self, {
let query = "
SELECT *
@@ -323,32 +385,32 @@ impl Db {
channels.owner_is_user = false AND
channels.owner_id = $1
";
- sqlx::query_as(query)
+ Ok(sqlx::query_as(query)
.bind(org_id.0)
.fetch_all(&self.pool)
- .await
+ .await?)
})
}
- pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
+ async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
test_support!(self, {
let query = "
SELECT
- channels.id, channels.name
+ channels.*
FROM
channel_memberships, channels
WHERE
channel_memberships.user_id = $1 AND
channel_memberships.channel_id = channels.id
";
- sqlx::query_as(query)
+ Ok(sqlx::query_as(query)
.bind(user_id.0)
.fetch_all(&self.pool)
- .await
+ .await?)
})
}
- pub async fn can_user_access_channel(
+ async fn can_user_access_channel(
&self,
user_id: UserId,
channel_id: ChannelId,
@@ -360,17 +422,17 @@ impl Db {
WHERE user_id = $1 AND channel_id = $2
LIMIT 1
";
- sqlx::query_scalar::<_, i32>(query)
+ Ok(sqlx::query_scalar::<_, i32>(query)
.bind(user_id.0)
.bind(channel_id.0)
.fetch_optional(&self.pool)
.await
- .map(|e| e.is_some())
+ .map(|e| e.is_some())?)
})
}
#[cfg(any(test, feature = "seed-support"))]
- pub async fn add_channel_member(
+ async fn add_channel_member(
&self,
channel_id: ChannelId,
user_id: UserId,
@@ -382,19 +444,19 @@ impl Db {
VALUES ($1, $2, $3)
ON CONFLICT DO NOTHING
";
- sqlx::query(query)
+ Ok(sqlx::query(query)
.bind(channel_id.0)
.bind(user_id.0)
.bind(is_admin)
.execute(&self.pool)
.await
- .map(drop)
+ .map(drop)?)
})
}
// messages
- pub async fn create_channel_message(
+ async fn create_channel_message(
&self,
channel_id: ChannelId,
sender_id: UserId,
@@ -409,7 +471,7 @@ impl Db {
ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
RETURNING id
";
- sqlx::query_scalar(query)
+ Ok(sqlx::query_scalar(query)
.bind(channel_id.0)
.bind(sender_id.0)
.bind(body)
@@ -417,11 +479,11 @@ impl Db {
.bind(Uuid::from_u128(nonce))
.fetch_one(&self.pool)
.await
- .map(MessageId)
+ .map(MessageId)?)
})
}
- pub async fn get_channel_messages(
+ async fn get_channel_messages(
&self,
channel_id: ChannelId,
count: usize,
@@ -431,7 +493,7 @@ impl Db {
let query = r#"
SELECT * FROM (
SELECT
- id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
+ id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
FROM
channel_messages
WHERE
@@ -442,12 +504,34 @@ impl Db {
) as recent_messages
ORDER BY id ASC
"#;
- sqlx::query_as(query)
+ Ok(sqlx::query_as(query)
.bind(channel_id.0)
.bind(before_id.unwrap_or(MessageId::MAX))
.bind(count as i64)
.fetch_all(&self.pool)
+ .await?)
+ })
+ }
+
+ #[cfg(test)]
+ async fn teardown(&self, name: &str, url: &str) {
+ use util::ResultExt;
+
+ test_support!(self, {
+ let query = "
+ SELECT pg_terminate_backend(pg_stat_activity.pid)
+ FROM pg_stat_activity
+ WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
+ ";
+ sqlx::query(query)
+ .bind(name)
+ .execute(&self.pool)
+ .await
+ .log_err();
+ self.pool.close().await;
+ <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
.await
+ .log_err();
})
}
}
@@ -479,7 +563,7 @@ macro_rules! id_type {
}
id_type!(UserId);
-#[derive(Debug, FromRow, Serialize, PartialEq)]
+#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
pub struct User {
pub id: UserId,
pub github_login: String,
@@ -507,16 +591,19 @@ pub struct Signup {
}
id_type!(ChannelId);
-#[derive(Debug, FromRow, Serialize)]
+#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
pub id: ChannelId,
pub name: String,
+ pub owner_id: i32,
+ pub owner_is_user: bool,
}
id_type!(MessageId);
-#[derive(Debug, FromRow)]
+#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
pub id: MessageId,
+ pub channel_id: ChannelId,
pub sender_id: UserId,
pub body: String,
pub sent_at: OffsetDateTime,
@@ -526,6 +613,9 @@ pub struct ChannelMessage {
#[cfg(test)]
pub mod tests {
use super::*;
+ use anyhow::anyhow;
+ use collections::BTreeMap;
+ use gpui::{executor::Background, TestAppContext};
use lazy_static::lazy_static;
use parking_lot::Mutex;
use rand::prelude::*;
@@ -533,227 +623,119 @@ pub mod tests {
migrate::{MigrateDatabase, Migrator},
Postgres,
};
- use std::{
- mem,
- path::Path,
- sync::atomic::{AtomicUsize, Ordering::SeqCst},
- thread,
- };
- use util::ResultExt as _;
-
- pub struct TestDb {
- pub db: Option<Db>,
- pub name: String,
- pub url: String,
- clean_pool_on_drop: bool,
- }
+ use std::{path::Path, sync::Arc};
+ use util::post_inc;
- lazy_static! {
- static ref DB_POOL: Mutex<Vec<TestDb>> = Default::default();
- static ref DB_COUNT: AtomicUsize = Default::default();
- }
-
- impl TestDb {
- pub fn new() -> Self {
- DB_COUNT.fetch_add(1, SeqCst);
- let mut pool = DB_POOL.lock();
- if let Some(db) = pool.pop() {
- db.truncate();
- db
- } else {
- let mut rng = StdRng::from_entropy();
- let name = format!("zed-test-{}", rng.gen::<u128>());
- let url = format!("postgres://postgres@localhost/{}", name);
- let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
- let db = block_on(async {
- Postgres::create_database(&url)
- .await
- .expect("failed to create test db");
- let mut db = Db::new(&url, 5).await.unwrap();
- db.test_mode = true;
- let migrator = Migrator::new(migrations_path).await.unwrap();
- migrator.run(&db.pool).await.unwrap();
- db
- });
-
- Self {
- db: Some(db),
- name,
- url,
- clean_pool_on_drop: false,
- }
- }
- }
-
- pub fn set_clean_pool_on_drop(&mut self, delete_on_drop: bool) {
- self.clean_pool_on_drop = delete_on_drop;
- }
+ #[gpui::test]
+ async fn test_get_users_by_ids(cx: TestAppContext) {
+ for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
+ let db = test_db.db();
- pub fn db(&self) -> &Db {
- self.db.as_ref().unwrap()
- }
+ let user = db.create_user("user", false).await.unwrap();
+ let friend1 = db.create_user("friend-1", false).await.unwrap();
+ let friend2 = db.create_user("friend-2", false).await.unwrap();
+ let friend3 = db.create_user("friend-3", false).await.unwrap();
- fn truncate(&self) {
- block_on(async {
- let query = "
- SELECT tablename FROM pg_tables
- WHERE schemaname = 'public';
- ";
- let table_names = sqlx::query_scalar::<_, String>(query)
- .fetch_all(&self.db().pool)
+ assert_eq!(
+ db.get_users_by_ids(vec![user, friend1, friend2, friend3])
.await
- .unwrap();
- sqlx::query(&format!(
- "TRUNCATE TABLE {} RESTART IDENTITY",
- table_names.join(", ")
- ))
- .execute(&self.db().pool)
- .await
- .unwrap();
- })
- }
-
- async fn teardown(mut self) -> Result<()> {
- let db = self.db.take().unwrap();
- let query = "
- SELECT pg_terminate_backend(pg_stat_activity.pid)
- FROM pg_stat_activity
- WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
- ";
- sqlx::query(query)
- .bind(&self.name)
- .execute(&db.pool)
- .await?;
- db.pool.close().await;
- Postgres::drop_database(&self.url).await?;
- Ok(())
- }
- }
-
- impl Drop for TestDb {
- fn drop(&mut self) {
- if let Some(db) = self.db.take() {
- DB_POOL.lock().push(TestDb {
- db: Some(db),
- name: mem::take(&mut self.name),
- url: mem::take(&mut self.url),
- clean_pool_on_drop: true,
- });
- if DB_COUNT.fetch_sub(1, SeqCst) == 1
- && (self.clean_pool_on_drop || thread::panicking())
- {
- block_on(async move {
- let mut pool = DB_POOL.lock();
- for db in pool.drain(..) {
- db.teardown().await.log_err();
- }
- });
- }
- }
+ .unwrap(),
+ vec![
+ User {
+ id: user,
+ github_login: "user".to_string(),
+ admin: false,
+ },
+ User {
+ id: friend1,
+ github_login: "friend-1".to_string(),
+ admin: false,
+ },
+ User {
+ id: friend2,
+ github_login: "friend-2".to_string(),
+ admin: false,
+ },
+ User {
+ id: friend3,
+ github_login: "friend-3".to_string(),
+ admin: false,
+ }
+ ]
+ );
}
}
#[gpui::test]
- async fn test_get_users_by_ids() {
- let test_db = TestDb::new();
- let db = test_db.db();
-
- let user = db.create_user("user", false).await.unwrap();
- let friend1 = db.create_user("friend-1", false).await.unwrap();
- let friend2 = db.create_user("friend-2", false).await.unwrap();
- let friend3 = db.create_user("friend-3", false).await.unwrap();
-
- assert_eq!(
- db.get_users_by_ids([user, friend1, friend2, friend3])
+ async fn test_recent_channel_messages(cx: TestAppContext) {
+ for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
+ let db = test_db.db();
+ let user = db.create_user("user", false).await.unwrap();
+ let org = db.create_org("org", "org").await.unwrap();
+ let channel = db.create_org_channel(org, "channel").await.unwrap();
+ for i in 0..10 {
+ db.create_channel_message(
+ channel,
+ user,
+ &i.to_string(),
+ OffsetDateTime::now_utc(),
+ i,
+ )
.await
- .unwrap(),
- vec![
- User {
- id: user,
- github_login: "user".to_string(),
- admin: false,
- },
- User {
- id: friend1,
- github_login: "friend-1".to_string(),
- admin: false,
- },
- User {
- id: friend2,
- github_login: "friend-2".to_string(),
- admin: false,
- },
- User {
- id: friend3,
- github_login: "friend-3".to_string(),
- admin: false,
- }
- ]
- );
- }
+ .unwrap();
+ }
- #[gpui::test]
- async fn test_recent_channel_messages() {
- let test_db = TestDb::new();
- let db = test_db.db();
- let user = db.create_user("user", false).await.unwrap();
- let org = db.create_org("org", "org").await.unwrap();
- let channel = db.create_org_channel(org, "channel").await.unwrap();
- for i in 0..10 {
- db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
+ let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
+ assert_eq!(
+ messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
+ ["5", "6", "7", "8", "9"]
+ );
+
+ let prev_messages = db
+ .get_channel_messages(channel, 4, Some(messages[0].id))
.await
.unwrap();
+ assert_eq!(
+ prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
+ ["1", "2", "3", "4"]
+ );
}
-
- let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
- assert_eq!(
- messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
- ["5", "6", "7", "8", "9"]
- );
-
- let prev_messages = db
- .get_channel_messages(channel, 4, Some(messages[0].id))
- .await
- .unwrap();
- assert_eq!(
- prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
- ["1", "2", "3", "4"]
- );
}
#[gpui::test]
- async fn test_channel_message_nonces() {
- let test_db = TestDb::new();
- let db = test_db.db();
- let user = db.create_user("user", false).await.unwrap();
- let org = db.create_org("org", "org").await.unwrap();
- let channel = db.create_org_channel(org, "channel").await.unwrap();
-
- let msg1_id = db
- .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
- .await
- .unwrap();
- let msg2_id = db
- .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
- .await
- .unwrap();
- let msg3_id = db
- .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
- .await
- .unwrap();
- let msg4_id = db
- .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
- .await
- .unwrap();
+ async fn test_channel_message_nonces(cx: TestAppContext) {
+ for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
+ let db = test_db.db();
+ let user = db.create_user("user", false).await.unwrap();
+ let org = db.create_org("org", "org").await.unwrap();
+ let channel = db.create_org_channel(org, "channel").await.unwrap();
+
+ let msg1_id = db
+ .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
+ .await
+ .unwrap();
+ let msg2_id = db
+ .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
+ .await
+ .unwrap();
+ let msg3_id = db
+ .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
+ .await
+ .unwrap();
+ let msg4_id = db
+ .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
+ .await
+ .unwrap();
- assert_ne!(msg1_id, msg2_id);
- assert_eq!(msg1_id, msg3_id);
- assert_eq!(msg2_id, msg4_id);
+ assert_ne!(msg1_id, msg2_id);
+ assert_eq!(msg1_id, msg3_id);
+ assert_eq!(msg2_id, msg4_id);
+ }
}
#[gpui::test]
async fn test_create_access_tokens() {
- let test_db = TestDb::new();
+ let test_db = TestDb::postgres();
let db = test_db.db();
let user = db.create_user("the-user", false).await.unwrap();
@@ -782,4 +764,359 @@ pub mod tests {
&["h5".to_string(), "h4".to_string(), "h3".to_string()]
);
}
+
+ pub struct TestDb {
+ pub db: Option<Arc<dyn Db>>,
+ pub name: String,
+ pub url: String,
+ }
+
+ impl TestDb {
+ pub fn postgres() -> Self {
+ lazy_static! {
+ static ref LOCK: Mutex<()> = Mutex::new(());
+ }
+
+ let _guard = LOCK.lock();
+ let mut rng = StdRng::from_entropy();
+ let name = format!("zed-test-{}", rng.gen::<u128>());
+ let url = format!("postgres://postgres@localhost/{}", name);
+ let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
+ let db = block_on(async {
+ Postgres::create_database(&url)
+ .await
+ .expect("failed to create test db");
+ let mut db = PostgresDb::new(&url, 5).await.unwrap();
+ db.test_mode = true;
+ let migrator = Migrator::new(migrations_path).await.unwrap();
+ migrator.run(&db.pool).await.unwrap();
+ db
+ });
+ Self {
+ db: Some(Arc::new(db)),
+ name,
+ url,
+ }
+ }
+
+ pub fn fake(background: Arc<Background>) -> Self {
+ Self {
+ db: Some(Arc::new(FakeDb::new(background))),
+ name: "fake".to_string(),
+ url: "fake".to_string(),
+ }
+ }
+
+ pub fn db(&self) -> &Arc<dyn Db> {
+ self.db.as_ref().unwrap()
+ }
+ }
+
+ impl Drop for TestDb {
+ fn drop(&mut self) {
+ if let Some(db) = self.db.take() {
+ block_on(db.teardown(&self.name, &self.url));
+ }
+ }
+ }
+
+ pub struct FakeDb {
+ background: Arc<Background>,
+ users: Mutex<BTreeMap<UserId, User>>,
+ next_user_id: Mutex<i32>,
+ orgs: Mutex<BTreeMap<OrgId, Org>>,
+ next_org_id: Mutex<i32>,
+ org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
+ channels: Mutex<BTreeMap<ChannelId, Channel>>,
+ next_channel_id: Mutex<i32>,
+ channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
+ channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
+ next_channel_message_id: Mutex<i32>,
+ }
+
+ impl FakeDb {
+ pub fn new(background: Arc<Background>) -> Self {
+ Self {
+ background,
+ users: Default::default(),
+ next_user_id: Mutex::new(1),
+ orgs: Default::default(),
+ next_org_id: Mutex::new(1),
+ org_memberships: Default::default(),
+ channels: Default::default(),
+ next_channel_id: Mutex::new(1),
+ channel_memberships: Default::default(),
+ channel_messages: Default::default(),
+ next_channel_message_id: Mutex::new(1),
+ }
+ }
+ }
+
+ #[async_trait]
+ impl Db for FakeDb {
+ async fn create_signup(
+ &self,
+ _github_login: &str,
+ _email_address: &str,
+ _about: &str,
+ _wants_releases: bool,
+ _wants_updates: bool,
+ _wants_community: bool,
+ ) -> Result<SignupId> {
+ unimplemented!()
+ }
+
+ async fn get_all_signups(&self) -> Result<Vec<Signup>> {
+ unimplemented!()
+ }
+
+ async fn destroy_signup(&self, _id: SignupId) -> Result<()> {
+ unimplemented!()
+ }
+
+ async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
+ self.background.simulate_random_delay().await;
+
+ let mut users = self.users.lock();
+ if let Some(user) = users
+ .values()
+ .find(|user| user.github_login == github_login)
+ {
+ Ok(user.id)
+ } else {
+ let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
+ users.insert(
+ user_id,
+ User {
+ id: user_id,
+ github_login: github_login.to_string(),
+ admin,
+ },
+ );
+ Ok(user_id)
+ }
+ }
+
+ async fn get_all_users(&self) -> Result<Vec<User>> {
+ unimplemented!()
+ }
+
+ async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
+ Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
+ }
+
+ async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
+ self.background.simulate_random_delay().await;
+ let users = self.users.lock();
+ Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
+ }
+
+ async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
+ unimplemented!()
+ }
+
+ async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
+ unimplemented!()
+ }
+
+ async fn destroy_user(&self, _id: UserId) -> Result<()> {
+ unimplemented!()
+ }
+
+ async fn create_access_token_hash(
+ &self,
+ _user_id: UserId,
+ _access_token_hash: &str,
+ _max_access_token_count: usize,
+ ) -> Result<()> {
+ unimplemented!()
+ }
+
+ async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
+ unimplemented!()
+ }
+
+ async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
+ unimplemented!()
+ }
+
+ async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
+ self.background.simulate_random_delay().await;
+ let mut orgs = self.orgs.lock();
+ if orgs.values().any(|org| org.slug == slug) {
+ Err(anyhow!("org already exists"))
+ } else {
+ let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
+ orgs.insert(
+ org_id,
+ Org {
+ id: org_id,
+ name: name.to_string(),
+ slug: slug.to_string(),
+ },
+ );
+ Ok(org_id)
+ }
+ }
+
+ async fn add_org_member(
+ &self,
+ org_id: OrgId,
+ user_id: UserId,
+ is_admin: bool,
+ ) -> Result<()> {
+ self.background.simulate_random_delay().await;
+ if !self.orgs.lock().contains_key(&org_id) {
+ return Err(anyhow!("org does not exist"));
+ }
+ if !self.users.lock().contains_key(&user_id) {
+ return Err(anyhow!("user does not exist"));
+ }
+
+ self.org_memberships
+ .lock()
+ .entry((org_id, user_id))
+ .or_insert(is_admin);
+ Ok(())
+ }
+
+ async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
+ self.background.simulate_random_delay().await;
+ if !self.orgs.lock().contains_key(&org_id) {
+ return Err(anyhow!("org does not exist"));
+ }
+
+ let mut channels = self.channels.lock();
+ let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
+ channels.insert(
+ channel_id,
+ Channel {
+ id: channel_id,
+ name: name.to_string(),
+ owner_id: org_id.0,
+ owner_is_user: false,
+ },
+ );
+ Ok(channel_id)
+ }
+
+ async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
+ self.background.simulate_random_delay().await;
+ Ok(self
+ .channels
+ .lock()
+ .values()
+ .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
+ .cloned()
+ .collect())
+ }
+
+ async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
+ self.background.simulate_random_delay().await;
+ let channels = self.channels.lock();
+ let memberships = self.channel_memberships.lock();
+ Ok(channels
+ .values()
+ .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
+ .cloned()
+ .collect())
+ }
+
+ async fn can_user_access_channel(
+ &self,
+ user_id: UserId,
+ channel_id: ChannelId,
+ ) -> Result<bool> {
+ self.background.simulate_random_delay().await;
+ Ok(self
+ .channel_memberships
+ .lock()
+ .contains_key(&(channel_id, user_id)))
+ }
+
+ async fn add_channel_member(
+ &self,
+ channel_id: ChannelId,
+ user_id: UserId,
+ is_admin: bool,
+ ) -> Result<()> {
+ self.background.simulate_random_delay().await;
+ if !self.channels.lock().contains_key(&channel_id) {
+ return Err(anyhow!("channel does not exist"));
+ }
+ if !self.users.lock().contains_key(&user_id) {
+ return Err(anyhow!("user does not exist"));
+ }
+
+ self.channel_memberships
+ .lock()
+ .entry((channel_id, user_id))
+ .or_insert(is_admin);
+ Ok(())
+ }
+
+ async fn create_channel_message(
+ &self,
+ channel_id: ChannelId,
+ sender_id: UserId,
+ body: &str,
+ timestamp: OffsetDateTime,
+ nonce: u128,
+ ) -> Result<MessageId> {
+ self.background.simulate_random_delay().await;
+ if !self.channels.lock().contains_key(&channel_id) {
+ return Err(anyhow!("channel does not exist"));
+ }
+ if !self.users.lock().contains_key(&sender_id) {
+ return Err(anyhow!("user does not exist"));
+ }
+
+ let mut messages = self.channel_messages.lock();
+ if let Some(message) = messages
+ .values()
+ .find(|message| message.nonce.as_u128() == nonce)
+ {
+ Ok(message.id)
+ } else {
+ let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
+ messages.insert(
+ message_id,
+ ChannelMessage {
+ id: message_id,
+ channel_id,
+ sender_id,
+ body: body.to_string(),
+ sent_at: timestamp,
+ nonce: Uuid::from_u128(nonce),
+ },
+ );
+ Ok(message_id)
+ }
+ }
+
+ async fn get_channel_messages(
+ &self,
+ channel_id: ChannelId,
+ count: usize,
+ before_id: Option<MessageId>,
+ ) -> Result<Vec<ChannelMessage>> {
+ let mut messages = self
+ .channel_messages
+ .lock()
+ .values()
+ .rev()
+ .filter(|message| {
+ message.channel_id == channel_id
+ && message.id < before_id.unwrap_or(MessageId::MAX)
+ })
+ .take(count)
+ .cloned()
+ .collect::<Vec<_>>();
+ dbg!(count, before_id, &messages);
+ messages.sort_unstable_by_key(|message| message.id);
+ Ok(messages)
+ }
+
+ async fn teardown(&self, _name: &str, _url: &str) {}
+ }
}
@@ -785,7 +785,12 @@ impl Server {
self: Arc<Server>,
request: TypedEnvelope<proto::GetUsers>,
) -> tide::Result<proto::GetUsersResponse> {
- let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
+ let user_ids = request
+ .payload
+ .user_ids
+ .into_iter()
+ .map(UserId::from_proto)
+ .collect();
let users = self
.app_state
.db
@@ -1139,18 +1144,14 @@ mod tests {
}
#[gpui::test(iterations = 10)]
- async fn test_share_project(
- mut cx_a: TestAppContext,
- mut cx_b: TestAppContext,
- last_iteration: bool,
- ) {
+ async fn test_share_project(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
let (window_b, _) = cx_b.add_window(|_| EmptyView);
let lang_registry = Arc::new(LanguageRegistry::new());
let fs = Arc::new(FakeFs::new(cx_a.background()));
cx_a.foreground().forbid_parking();
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -1282,17 +1283,13 @@ mod tests {
}
#[gpui::test(iterations = 10)]
- async fn test_unshare_project(
- mut cx_a: TestAppContext,
- mut cx_b: TestAppContext,
- last_iteration: bool,
- ) {
+ async fn test_unshare_project(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
let lang_registry = Arc::new(LanguageRegistry::new());
let fs = Arc::new(FakeFs::new(cx_a.background()));
cx_a.foreground().forbid_parking();
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -1387,14 +1384,13 @@ mod tests {
mut cx_a: TestAppContext,
mut cx_b: TestAppContext,
mut cx_c: TestAppContext,
- last_iteration: bool,
) {
let lang_registry = Arc::new(LanguageRegistry::new());
let fs = Arc::new(FakeFs::new(cx_a.background()));
cx_a.foreground().forbid_parking();
// Connect to a server as 3 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
let client_c = server.create_client(&mut cx_c, "user_c").await;
@@ -1566,17 +1562,13 @@ mod tests {
}
#[gpui::test(iterations = 10)]
- async fn test_buffer_conflict_after_save(
- mut cx_a: TestAppContext,
- mut cx_b: TestAppContext,
- last_iteration: bool,
- ) {
+ async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
cx_a.foreground().forbid_parking();
let lang_registry = Arc::new(LanguageRegistry::new());
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -1658,17 +1650,13 @@ mod tests {
}
#[gpui::test(iterations = 10)]
- async fn test_buffer_reloading(
- mut cx_a: TestAppContext,
- mut cx_b: TestAppContext,
- last_iteration: bool,
- ) {
+ async fn test_buffer_reloading(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
cx_a.foreground().forbid_parking();
let lang_registry = Arc::new(LanguageRegistry::new());
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -1747,14 +1735,13 @@ mod tests {
async fn test_editing_while_guest_opens_buffer(
mut cx_a: TestAppContext,
mut cx_b: TestAppContext,
- last_iteration: bool,
) {
cx_a.foreground().forbid_parking();
let lang_registry = Arc::new(LanguageRegistry::new());
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -1830,14 +1817,13 @@ mod tests {
async fn test_leaving_worktree_while_opening_buffer(
mut cx_a: TestAppContext,
mut cx_b: TestAppContext,
- last_iteration: bool,
) {
cx_a.foreground().forbid_parking();
let lang_registry = Arc::new(LanguageRegistry::new());
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -1906,17 +1892,13 @@ mod tests {
}
#[gpui::test(iterations = 10)]
- async fn test_peer_disconnection(
- mut cx_a: TestAppContext,
- mut cx_b: TestAppContext,
- last_iteration: bool,
- ) {
+ async fn test_peer_disconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
cx_a.foreground().forbid_parking();
let lang_registry = Arc::new(LanguageRegistry::new());
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -1984,7 +1966,6 @@ mod tests {
async fn test_collaborating_with_diagnostics(
mut cx_a: TestAppContext,
mut cx_b: TestAppContext,
- last_iteration: bool,
) {
cx_a.foreground().forbid_parking();
let mut lang_registry = Arc::new(LanguageRegistry::new());
@@ -2005,7 +1986,7 @@ mod tests {
)));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -2209,7 +2190,6 @@ mod tests {
async fn test_collaborating_with_completion(
mut cx_a: TestAppContext,
mut cx_b: TestAppContext,
- last_iteration: bool,
) {
cx_a.foreground().forbid_parking();
let mut lang_registry = Arc::new(LanguageRegistry::new());
@@ -2237,7 +2217,7 @@ mod tests {
)));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -2419,11 +2399,7 @@ mod tests {
}
#[gpui::test(iterations = 10)]
- async fn test_formatting_buffer(
- mut cx_a: TestAppContext,
- mut cx_b: TestAppContext,
- last_iteration: bool,
- ) {
+ async fn test_formatting_buffer(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
cx_a.foreground().forbid_parking();
let mut lang_registry = Arc::new(LanguageRegistry::new());
let fs = Arc::new(FakeFs::new(cx_a.background()));
@@ -2443,7 +2419,7 @@ mod tests {
)));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -2525,11 +2501,7 @@ mod tests {
}
#[gpui::test(iterations = 10)]
- async fn test_definition(
- mut cx_a: TestAppContext,
- mut cx_b: TestAppContext,
- last_iteration: bool,
- ) {
+ async fn test_definition(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
cx_a.foreground().forbid_parking();
let mut lang_registry = Arc::new(LanguageRegistry::new());
let fs = Arc::new(FakeFs::new(cx_a.background()));
@@ -2564,7 +2536,7 @@ mod tests {
)));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -2682,7 +2654,6 @@ mod tests {
mut cx_a: TestAppContext,
mut cx_b: TestAppContext,
mut rng: StdRng,
- last_iteration: bool,
) {
cx_a.foreground().forbid_parking();
let mut lang_registry = Arc::new(LanguageRegistry::new());
@@ -2713,7 +2684,7 @@ mod tests {
)));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -2792,7 +2763,6 @@ mod tests {
async fn test_collaborating_with_code_actions(
mut cx_a: TestAppContext,
mut cx_b: TestAppContext,
- last_iteration: bool,
) {
cx_a.foreground().forbid_parking();
let mut lang_registry = Arc::new(LanguageRegistry::new());
@@ -2815,7 +2785,7 @@ mod tests {
)));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -3032,15 +3002,11 @@ mod tests {
}
#[gpui::test(iterations = 10)]
- async fn test_basic_chat(
- mut cx_a: TestAppContext,
- mut cx_b: TestAppContext,
- last_iteration: bool,
- ) {
+ async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
cx_a.foreground().forbid_parking();
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -3176,10 +3142,10 @@ mod tests {
}
#[gpui::test(iterations = 10)]
- async fn test_chat_message_validation(mut cx_a: TestAppContext, last_iteration: bool) {
+ async fn test_chat_message_validation(mut cx_a: TestAppContext) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let db = &server.app_state.db;
@@ -3236,15 +3202,11 @@ mod tests {
}
#[gpui::test(iterations = 10)]
- async fn test_chat_reconnection(
- mut cx_a: TestAppContext,
- mut cx_b: TestAppContext,
- last_iteration: bool,
- ) {
+ async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
cx_a.foreground().forbid_parking();
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
let mut status_b = client_b.status();
@@ -3456,14 +3418,13 @@ mod tests {
mut cx_a: TestAppContext,
mut cx_b: TestAppContext,
mut cx_c: TestAppContext,
- last_iteration: bool,
) {
cx_a.foreground().forbid_parking();
let lang_registry = Arc::new(LanguageRegistry::new());
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Connect to a server as 3 clients.
- let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
let client_c = server.create_client(&mut cx_c, "user_c").await;
@@ -3595,7 +3556,7 @@ mod tests {
}
#[gpui::test(iterations = 100)]
- async fn test_random_collaboration(cx: TestAppContext, rng: StdRng, last_iteration: bool) {
+ async fn test_random_collaboration(cx: TestAppContext, rng: StdRng) {
cx.foreground().forbid_parking();
let max_peers = env::var("MAX_PEERS")
.map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
@@ -3654,7 +3615,7 @@ mod tests {
.await;
let operations = Rc::new(Cell::new(0));
- let mut server = TestServer::start(cx.foreground(), last_iteration).await;
+ let mut server = TestServer::start(cx.foreground(), cx.background()).await;
let mut clients = Vec::new();
let mut next_entity_id = 100000;
@@ -3849,9 +3810,11 @@ mod tests {
}
impl TestServer {
- async fn start(foreground: Rc<executor::Foreground>, clean_db_pool_on_drop: bool) -> Self {
- let mut test_db = TestDb::new();
- test_db.set_clean_pool_on_drop(clean_db_pool_on_drop);
+ async fn start(
+ foreground: Rc<executor::Foreground>,
+ background: Arc<executor::Background>,
+ ) -> Self {
+ let test_db = TestDb::fake(background);
let app_state = Self::build_app_state(&test_db).await;
let peer = Peer::new();
let notifications = mpsc::unbounded();