Use a fake database in tests

Antonio Scandurra created

Change summary

crates/server/src/api.rs  |   2 
crates/server/src/auth.rs |   6 
crates/server/src/db.rs   | 899 ++++++++++++++++++++++++++++------------
crates/server/src/main.rs |  12 
crates/server/src/rpc.rs  | 119 +---
5 files changed, 669 insertions(+), 369 deletions(-)

Detailed changes

crates/server/src/api.rs 🔗

@@ -111,7 +111,7 @@ async fn create_access_token(request: Request) -> tide::Result {
         .get_user_by_github_login(request.param("github_login")?)
         .await?
         .ok_or_else(|| surf::Error::from_str(StatusCode::NotFound, "user not found"))?;
-    let access_token = auth::create_access_token(request.db(), user.id).await?;
+    let access_token = auth::create_access_token(request.db().as_ref(), user.id).await?;
 
     #[derive(Deserialize)]
     struct QueryParams {

crates/server/src/auth.rs 🔗

@@ -234,7 +234,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
         let mut user_id = user.id;
         if let Some(impersonated_login) = app_sign_in_params.impersonate {
             log::info!("attempting to impersonate user @{}", impersonated_login);
-            if let Some(user) = request.db().get_users_by_ids([user_id]).await?.first() {
+            if let Some(user) = request.db().get_users_by_ids(vec![user_id]).await?.first() {
                 if user.admin {
                     user_id = request.db().create_user(&impersonated_login, false).await?;
                     log::info!("impersonating user {}", user_id.0);
@@ -244,7 +244,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
             }
         }
 
-        let access_token = create_access_token(request.db(), user_id).await?;
+        let access_token = create_access_token(request.db().as_ref(), user_id).await?;
         let encrypted_access_token = encrypt_access_token(
             &access_token,
             app_sign_in_params.native_app_public_key.clone(),
@@ -267,7 +267,7 @@ async fn post_sign_out(mut request: Request) -> tide::Result {
 
 const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
 
-pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result<String> {
+pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> tide::Result<String> {
     let access_token = zed_auth::random_token();
     let access_token_hash =
         hash_access_token(&access_token).context("failed to hash access token")?;

crates/server/src/db.rs 🔗

@@ -1,11 +1,12 @@
 use anyhow::Context;
+use anyhow::Result;
+pub use async_sqlx_session::PostgresSessionStore as SessionStore;
 use async_std::task::{block_on, yield_now};
+use async_trait::async_trait;
 use serde::Serialize;
-use sqlx::{types::Uuid, FromRow, Result};
-use time::OffsetDateTime;
-
-pub use async_sqlx_session::PostgresSessionStore as SessionStore;
 pub use sqlx::postgres::PgPoolOptions as DbOptions;
+use sqlx::{types::Uuid, FromRow};
+use time::OffsetDateTime;
 
 macro_rules! test_support {
     ($self:ident, { $($token:tt)* }) => {{
@@ -21,13 +22,77 @@ macro_rules! test_support {
     }};
 }
 
-#[derive(Clone)]
-pub struct Db {
+#[async_trait]
+pub trait Db: Send + Sync {
+    async fn create_signup(
+        &self,
+        github_login: &str,
+        email_address: &str,
+        about: &str,
+        wants_releases: bool,
+        wants_updates: bool,
+        wants_community: bool,
+    ) -> Result<SignupId>;
+    async fn get_all_signups(&self) -> Result<Vec<Signup>>;
+    async fn destroy_signup(&self, id: SignupId) -> Result<()>;
+    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
+    async fn get_all_users(&self) -> Result<Vec<User>>;
+    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
+    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
+    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
+    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
+    async fn destroy_user(&self, id: UserId) -> Result<()>;
+    async fn create_access_token_hash(
+        &self,
+        user_id: UserId,
+        access_token_hash: &str,
+        max_access_token_count: usize,
+    ) -> Result<()>;
+    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
+    #[cfg(any(test, feature = "seed-support"))]
+    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
+    #[cfg(any(test, feature = "seed-support"))]
+    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
+    #[cfg(any(test, feature = "seed-support"))]
+    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
+    #[cfg(any(test, feature = "seed-support"))]
+    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
+    #[cfg(any(test, feature = "seed-support"))]
+    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
+    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
+    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
+        -> Result<bool>;
+    #[cfg(any(test, feature = "seed-support"))]
+    async fn add_channel_member(
+        &self,
+        channel_id: ChannelId,
+        user_id: UserId,
+        is_admin: bool,
+    ) -> Result<()>;
+    async fn create_channel_message(
+        &self,
+        channel_id: ChannelId,
+        sender_id: UserId,
+        body: &str,
+        timestamp: OffsetDateTime,
+        nonce: u128,
+    ) -> Result<MessageId>;
+    async fn get_channel_messages(
+        &self,
+        channel_id: ChannelId,
+        count: usize,
+        before_id: Option<MessageId>,
+    ) -> Result<Vec<ChannelMessage>>;
+    #[cfg(test)]
+    async fn teardown(&self, name: &str, url: &str);
+}
+
+pub struct PostgresDb {
     pool: sqlx::PgPool,
     test_mode: bool,
 }
 
-impl Db {
+impl PostgresDb {
     pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
         let pool = DbOptions::new()
             .max_connections(max_connections)
@@ -39,10 +104,12 @@ impl Db {
             test_mode: false,
         })
     }
+}
 
+#[async_trait]
+impl Db for PostgresDb {
     // signups
-
-    pub async fn create_signup(
+    async fn create_signup(
         &self,
         github_login: &str,
         email_address: &str,
@@ -64,7 +131,7 @@ impl Db {
                 VALUES ($1, $2, $3, $4, $5, $6)
                 RETURNING id
             ";
-            sqlx::query_scalar(query)
+            Ok(sqlx::query_scalar(query)
                 .bind(github_login)
                 .bind(email_address)
                 .bind(about)
@@ -73,31 +140,31 @@ impl Db {
                 .bind(wants_community)
                 .fetch_one(&self.pool)
                 .await
-                .map(SignupId)
+                .map(SignupId)?)
         })
     }
 
-    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
+    async fn get_all_signups(&self) -> Result<Vec<Signup>> {
         test_support!(self, {
             let query = "SELECT * FROM signups ORDER BY github_login ASC";
-            sqlx::query_as(query).fetch_all(&self.pool).await
+            Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
         })
     }
 
-    pub async fn destroy_signup(&self, id: SignupId) -> Result<()> {
+    async fn destroy_signup(&self, id: SignupId) -> Result<()> {
         test_support!(self, {
             let query = "DELETE FROM signups WHERE id = $1";
-            sqlx::query(query)
+            Ok(sqlx::query(query)
                 .bind(id.0)
                 .execute(&self.pool)
                 .await
-                .map(drop)
+                .map(drop)?)
         })
     }
 
     // users
 
-    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
+    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
         test_support!(self, {
             let query = "
                 INSERT INTO users (github_login, admin)
@@ -105,31 +172,28 @@ impl Db {
                 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                 RETURNING id
             ";
-            sqlx::query_scalar(query)
+            Ok(sqlx::query_scalar(query)
                 .bind(github_login)
                 .bind(admin)
                 .fetch_one(&self.pool)
                 .await
-                .map(UserId)
+                .map(UserId)?)
         })
     }
 
-    pub async fn get_all_users(&self) -> Result<Vec<User>> {
+    async fn get_all_users(&self) -> Result<Vec<User>> {
         test_support!(self, {
             let query = "SELECT * FROM users ORDER BY github_login ASC";
-            sqlx::query_as(query).fetch_all(&self.pool).await
+            Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
         })
     }
 
-    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
-        let users = self.get_users_by_ids([id]).await?;
+    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
+        let users = self.get_users_by_ids(vec![id]).await?;
         Ok(users.into_iter().next())
     }
 
-    pub async fn get_users_by_ids(
-        &self,
-        ids: impl IntoIterator<Item = UserId>,
-    ) -> Result<Vec<User>> {
+    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
         let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
         test_support!(self, {
             let query = "
@@ -138,33 +202,36 @@ impl Db {
                 WHERE users.id = ANY ($1)
             ";
 
-            sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
+            Ok(sqlx::query_as(query)
+                .bind(&ids)
+                .fetch_all(&self.pool)
+                .await?)
         })
     }
 
-    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
+    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
         test_support!(self, {
             let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
-            sqlx::query_as(query)
+            Ok(sqlx::query_as(query)
                 .bind(github_login)
                 .fetch_optional(&self.pool)
-                .await
+                .await?)
         })
     }
 
-    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
+    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
         test_support!(self, {
             let query = "UPDATE users SET admin = $1 WHERE id = $2";
-            sqlx::query(query)
+            Ok(sqlx::query(query)
                 .bind(is_admin)
                 .bind(id.0)
                 .execute(&self.pool)
                 .await
-                .map(drop)
+                .map(drop)?)
         })
     }
 
-    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
+    async fn destroy_user(&self, id: UserId) -> Result<()> {
         test_support!(self, {
             let query = "DELETE FROM access_tokens WHERE user_id = $1;";
             sqlx::query(query)
@@ -173,17 +240,17 @@ impl Db {
                 .await
                 .map(drop)?;
             let query = "DELETE FROM users WHERE id = $1;";
-            sqlx::query(query)
+            Ok(sqlx::query(query)
                 .bind(id.0)
                 .execute(&self.pool)
                 .await
-                .map(drop)
+                .map(drop)?)
         })
     }
 
     // access tokens
 
-    pub async fn create_access_token_hash(
+    async fn create_access_token_hash(
         &self,
         user_id: UserId,
         access_token_hash: &str,
@@ -216,11 +283,11 @@ impl Db {
                 .bind(max_access_token_count as u32)
                 .execute(&mut tx)
                 .await?;
-            tx.commit().await
+            Ok(tx.commit().await?)
         })
     }
 
-    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
+    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
         test_support!(self, {
             let query = "
                 SELECT hash
@@ -228,10 +295,10 @@ impl Db {
                 WHERE user_id = $1
                 ORDER BY id DESC
             ";
-            sqlx::query_scalar(query)
+            Ok(sqlx::query_scalar(query)
                 .bind(user_id.0)
                 .fetch_all(&self.pool)
-                .await
+                .await?)
         })
     }
 
@@ -239,82 +306,77 @@ impl Db {
 
     #[allow(unused)] // Help rust-analyzer
     #[cfg(any(test, feature = "seed-support"))]
-    pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
+    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
         test_support!(self, {
             let query = "
                 SELECT *
                 FROM orgs
                 WHERE slug = $1
             ";
-            sqlx::query_as(query)
+            Ok(sqlx::query_as(query)
                 .bind(slug)
                 .fetch_optional(&self.pool)
-                .await
+                .await?)
         })
     }
 
     #[cfg(any(test, feature = "seed-support"))]
-    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
+    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
         test_support!(self, {
             let query = "
                 INSERT INTO orgs (name, slug)
                 VALUES ($1, $2)
                 RETURNING id
             ";
-            sqlx::query_scalar(query)
+            Ok(sqlx::query_scalar(query)
                 .bind(name)
                 .bind(slug)
                 .fetch_one(&self.pool)
                 .await
-                .map(OrgId)
+                .map(OrgId)?)
         })
     }
 
     #[cfg(any(test, feature = "seed-support"))]
-    pub async fn add_org_member(
-        &self,
-        org_id: OrgId,
-        user_id: UserId,
-        is_admin: bool,
-    ) -> Result<()> {
+    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
         test_support!(self, {
             let query = "
                 INSERT INTO org_memberships (org_id, user_id, admin)
                 VALUES ($1, $2, $3)
                 ON CONFLICT DO NOTHING
             ";
-            sqlx::query(query)
+            Ok(sqlx::query(query)
                 .bind(org_id.0)
                 .bind(user_id.0)
                 .bind(is_admin)
                 .execute(&self.pool)
                 .await
-                .map(drop)
+                .map(drop)?)
         })
     }
 
     // channels
 
     #[cfg(any(test, feature = "seed-support"))]
-    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
+    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
         test_support!(self, {
             let query = "
                 INSERT INTO channels (owner_id, owner_is_user, name)
                 VALUES ($1, false, $2)
                 RETURNING id
             ";
-            sqlx::query_scalar(query)
+            Ok(sqlx::query_scalar(query)
                 .bind(org_id.0)
                 .bind(name)
                 .fetch_one(&self.pool)
                 .await
-                .map(ChannelId)
+                .map(ChannelId)?)
         })
     }
 
     #[allow(unused)] // Help rust-analyzer
     #[cfg(any(test, feature = "seed-support"))]
-    pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
+    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
         test_support!(self, {
             let query = "
                 SELECT *
@@ -323,32 +385,32 @@ impl Db {
                     channels.owner_is_user = false AND
                     channels.owner_id = $1
             ";
-            sqlx::query_as(query)
+            Ok(sqlx::query_as(query)
                 .bind(org_id.0)
                 .fetch_all(&self.pool)
-                .await
+                .await?)
         })
     }
 
-    pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
+    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
         test_support!(self, {
             let query = "
                 SELECT
-                    channels.id, channels.name
+                    channels.*
                 FROM
                     channel_memberships, channels
                 WHERE
                     channel_memberships.user_id = $1 AND
                     channel_memberships.channel_id = channels.id
             ";
-            sqlx::query_as(query)
+            Ok(sqlx::query_as(query)
                 .bind(user_id.0)
                 .fetch_all(&self.pool)
-                .await
+                .await?)
         })
     }
 
-    pub async fn can_user_access_channel(
+    async fn can_user_access_channel(
         &self,
         user_id: UserId,
         channel_id: ChannelId,
@@ -360,17 +422,17 @@ impl Db {
                 WHERE user_id = $1 AND channel_id = $2
                 LIMIT 1
             ";
-            sqlx::query_scalar::<_, i32>(query)
+            Ok(sqlx::query_scalar::<_, i32>(query)
                 .bind(user_id.0)
                 .bind(channel_id.0)
                 .fetch_optional(&self.pool)
                 .await
-                .map(|e| e.is_some())
+                .map(|e| e.is_some())?)
         })
     }
 
     #[cfg(any(test, feature = "seed-support"))]
-    pub async fn add_channel_member(
+    async fn add_channel_member(
         &self,
         channel_id: ChannelId,
         user_id: UserId,
@@ -382,19 +444,19 @@ impl Db {
                 VALUES ($1, $2, $3)
                 ON CONFLICT DO NOTHING
             ";
-            sqlx::query(query)
+            Ok(sqlx::query(query)
                 .bind(channel_id.0)
                 .bind(user_id.0)
                 .bind(is_admin)
                 .execute(&self.pool)
                 .await
-                .map(drop)
+                .map(drop)?)
         })
     }
 
     // messages
 
-    pub async fn create_channel_message(
+    async fn create_channel_message(
         &self,
         channel_id: ChannelId,
         sender_id: UserId,
@@ -409,7 +471,7 @@ impl Db {
                 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
                 RETURNING id
             ";
-            sqlx::query_scalar(query)
+            Ok(sqlx::query_scalar(query)
                 .bind(channel_id.0)
                 .bind(sender_id.0)
                 .bind(body)
@@ -417,11 +479,11 @@ impl Db {
                 .bind(Uuid::from_u128(nonce))
                 .fetch_one(&self.pool)
                 .await
-                .map(MessageId)
+                .map(MessageId)?)
         })
     }
 
-    pub async fn get_channel_messages(
+    async fn get_channel_messages(
         &self,
         channel_id: ChannelId,
         count: usize,
@@ -431,7 +493,7 @@ impl Db {
             let query = r#"
                 SELECT * FROM (
                     SELECT
-                        id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
+                        id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
                     FROM
                         channel_messages
                     WHERE
@@ -442,12 +504,34 @@ impl Db {
                 ) as recent_messages
                 ORDER BY id ASC
             "#;
-            sqlx::query_as(query)
+            Ok(sqlx::query_as(query)
                 .bind(channel_id.0)
                 .bind(before_id.unwrap_or(MessageId::MAX))
                 .bind(count as i64)
                 .fetch_all(&self.pool)
+                .await?)
+        })
+    }
+
+    #[cfg(test)]
+    async fn teardown(&self, name: &str, url: &str) {
+        use util::ResultExt;
+
+        test_support!(self, {
+            let query = "
+                SELECT pg_terminate_backend(pg_stat_activity.pid)
+                FROM pg_stat_activity
+                WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
+            ";
+            sqlx::query(query)
+                .bind(name)
+                .execute(&self.pool)
+                .await
+                .log_err();
+            self.pool.close().await;
+            <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
                 .await
+                .log_err();
         })
     }
 }
@@ -479,7 +563,7 @@ macro_rules! id_type {
 }
 
 id_type!(UserId);
-#[derive(Debug, FromRow, Serialize, PartialEq)]
+#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
 pub struct User {
     pub id: UserId,
     pub github_login: String,
@@ -507,16 +591,19 @@ pub struct Signup {
 }
 
 id_type!(ChannelId);
-#[derive(Debug, FromRow, Serialize)]
+#[derive(Clone, Debug, FromRow, Serialize)]
 pub struct Channel {
     pub id: ChannelId,
     pub name: String,
+    pub owner_id: i32,
+    pub owner_is_user: bool,
 }
 
 id_type!(MessageId);
-#[derive(Debug, FromRow)]
+#[derive(Clone, Debug, FromRow)]
 pub struct ChannelMessage {
     pub id: MessageId,
+    pub channel_id: ChannelId,
     pub sender_id: UserId,
     pub body: String,
     pub sent_at: OffsetDateTime,
@@ -526,6 +613,9 @@ pub struct ChannelMessage {
 #[cfg(test)]
 pub mod tests {
     use super::*;
+    use anyhow::anyhow;
+    use collections::BTreeMap;
+    use gpui::{executor::Background, TestAppContext};
     use lazy_static::lazy_static;
     use parking_lot::Mutex;
     use rand::prelude::*;
@@ -533,227 +623,119 @@ pub mod tests {
         migrate::{MigrateDatabase, Migrator},
         Postgres,
     };
-    use std::{
-        mem,
-        path::Path,
-        sync::atomic::{AtomicUsize, Ordering::SeqCst},
-        thread,
-    };
-    use util::ResultExt as _;
-
-    pub struct TestDb {
-        pub db: Option<Db>,
-        pub name: String,
-        pub url: String,
-        clean_pool_on_drop: bool,
-    }
+    use std::{path::Path, sync::Arc};
+    use util::post_inc;
 
-    lazy_static! {
-        static ref DB_POOL: Mutex<Vec<TestDb>> = Default::default();
-        static ref DB_COUNT: AtomicUsize = Default::default();
-    }
-
-    impl TestDb {
-        pub fn new() -> Self {
-            DB_COUNT.fetch_add(1, SeqCst);
-            let mut pool = DB_POOL.lock();
-            if let Some(db) = pool.pop() {
-                db.truncate();
-                db
-            } else {
-                let mut rng = StdRng::from_entropy();
-                let name = format!("zed-test-{}", rng.gen::<u128>());
-                let url = format!("postgres://postgres@localhost/{}", name);
-                let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
-                let db = block_on(async {
-                    Postgres::create_database(&url)
-                        .await
-                        .expect("failed to create test db");
-                    let mut db = Db::new(&url, 5).await.unwrap();
-                    db.test_mode = true;
-                    let migrator = Migrator::new(migrations_path).await.unwrap();
-                    migrator.run(&db.pool).await.unwrap();
-                    db
-                });
-
-                Self {
-                    db: Some(db),
-                    name,
-                    url,
-                    clean_pool_on_drop: false,
-                }
-            }
-        }
-
-        pub fn set_clean_pool_on_drop(&mut self, delete_on_drop: bool) {
-            self.clean_pool_on_drop = delete_on_drop;
-        }
+    #[gpui::test]
+    async fn test_get_users_by_ids(cx: TestAppContext) {
+        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
+            let db = test_db.db();
 
-        pub fn db(&self) -> &Db {
-            self.db.as_ref().unwrap()
-        }
+            let user = db.create_user("user", false).await.unwrap();
+            let friend1 = db.create_user("friend-1", false).await.unwrap();
+            let friend2 = db.create_user("friend-2", false).await.unwrap();
+            let friend3 = db.create_user("friend-3", false).await.unwrap();
 
-        fn truncate(&self) {
-            block_on(async {
-                let query = "
-                    SELECT tablename FROM pg_tables
-                    WHERE schemaname = 'public';
-                ";
-                let table_names = sqlx::query_scalar::<_, String>(query)
-                    .fetch_all(&self.db().pool)
+            assert_eq!(
+                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
                     .await
-                    .unwrap();
-                sqlx::query(&format!(
-                    "TRUNCATE TABLE {} RESTART IDENTITY",
-                    table_names.join(", ")
-                ))
-                .execute(&self.db().pool)
-                .await
-                .unwrap();
-            })
-        }
-
-        async fn teardown(mut self) -> Result<()> {
-            let db = self.db.take().unwrap();
-            let query = "
-                SELECT pg_terminate_backend(pg_stat_activity.pid)
-                FROM pg_stat_activity
-                WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
-            ";
-            sqlx::query(query)
-                .bind(&self.name)
-                .execute(&db.pool)
-                .await?;
-            db.pool.close().await;
-            Postgres::drop_database(&self.url).await?;
-            Ok(())
-        }
-    }
-
-    impl Drop for TestDb {
-        fn drop(&mut self) {
-            if let Some(db) = self.db.take() {
-                DB_POOL.lock().push(TestDb {
-                    db: Some(db),
-                    name: mem::take(&mut self.name),
-                    url: mem::take(&mut self.url),
-                    clean_pool_on_drop: true,
-                });
-                if DB_COUNT.fetch_sub(1, SeqCst) == 1
-                    && (self.clean_pool_on_drop || thread::panicking())
-                {
-                    block_on(async move {
-                        let mut pool = DB_POOL.lock();
-                        for db in pool.drain(..) {
-                            db.teardown().await.log_err();
-                        }
-                    });
-                }
-            }
+                    .unwrap(),
+                vec![
+                    User {
+                        id: user,
+                        github_login: "user".to_string(),
+                        admin: false,
+                    },
+                    User {
+                        id: friend1,
+                        github_login: "friend-1".to_string(),
+                        admin: false,
+                    },
+                    User {
+                        id: friend2,
+                        github_login: "friend-2".to_string(),
+                        admin: false,
+                    },
+                    User {
+                        id: friend3,
+                        github_login: "friend-3".to_string(),
+                        admin: false,
+                    }
+                ]
+            );
         }
     }
 
     #[gpui::test]
-    async fn test_get_users_by_ids() {
-        let test_db = TestDb::new();
-        let db = test_db.db();
-
-        let user = db.create_user("user", false).await.unwrap();
-        let friend1 = db.create_user("friend-1", false).await.unwrap();
-        let friend2 = db.create_user("friend-2", false).await.unwrap();
-        let friend3 = db.create_user("friend-3", false).await.unwrap();
-
-        assert_eq!(
-            db.get_users_by_ids([user, friend1, friend2, friend3])
+    async fn test_recent_channel_messages(cx: TestAppContext) {
+        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
+            let db = test_db.db();
+            let user = db.create_user("user", false).await.unwrap();
+            let org = db.create_org("org", "org").await.unwrap();
+            let channel = db.create_org_channel(org, "channel").await.unwrap();
+            for i in 0..10 {
+                db.create_channel_message(
+                    channel,
+                    user,
+                    &i.to_string(),
+                    OffsetDateTime::now_utc(),
+                    i,
+                )
                 .await
-                .unwrap(),
-            vec![
-                User {
-                    id: user,
-                    github_login: "user".to_string(),
-                    admin: false,
-                },
-                User {
-                    id: friend1,
-                    github_login: "friend-1".to_string(),
-                    admin: false,
-                },
-                User {
-                    id: friend2,
-                    github_login: "friend-2".to_string(),
-                    admin: false,
-                },
-                User {
-                    id: friend3,
-                    github_login: "friend-3".to_string(),
-                    admin: false,
-                }
-            ]
-        );
-    }
+                .unwrap();
+            }
 
-    #[gpui::test]
-    async fn test_recent_channel_messages() {
-        let test_db = TestDb::new();
-        let db = test_db.db();
-        let user = db.create_user("user", false).await.unwrap();
-        let org = db.create_org("org", "org").await.unwrap();
-        let channel = db.create_org_channel(org, "channel").await.unwrap();
-        for i in 0..10 {
-            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
+            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
+            assert_eq!(
+                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
+                ["5", "6", "7", "8", "9"]
+            );
+
+            let prev_messages = db
+                .get_channel_messages(channel, 4, Some(messages[0].id))
                 .await
                 .unwrap();
+            assert_eq!(
+                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
+                ["1", "2", "3", "4"]
+            );
         }
-
-        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
-        assert_eq!(
-            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
-            ["5", "6", "7", "8", "9"]
-        );
-
-        let prev_messages = db
-            .get_channel_messages(channel, 4, Some(messages[0].id))
-            .await
-            .unwrap();
-        assert_eq!(
-            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
-            ["1", "2", "3", "4"]
-        );
     }
 
     #[gpui::test]
-    async fn test_channel_message_nonces() {
-        let test_db = TestDb::new();
-        let db = test_db.db();
-        let user = db.create_user("user", false).await.unwrap();
-        let org = db.create_org("org", "org").await.unwrap();
-        let channel = db.create_org_channel(org, "channel").await.unwrap();
-
-        let msg1_id = db
-            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
-            .await
-            .unwrap();
-        let msg2_id = db
-            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
-            .await
-            .unwrap();
-        let msg3_id = db
-            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
-            .await
-            .unwrap();
-        let msg4_id = db
-            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
-            .await
-            .unwrap();
+    async fn test_channel_message_nonces(cx: TestAppContext) {
+        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
+            let db = test_db.db();
+            let user = db.create_user("user", false).await.unwrap();
+            let org = db.create_org("org", "org").await.unwrap();
+            let channel = db.create_org_channel(org, "channel").await.unwrap();
+
+            let msg1_id = db
+                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
+                .await
+                .unwrap();
+            let msg2_id = db
+                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
+                .await
+                .unwrap();
+            let msg3_id = db
+                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
+                .await
+                .unwrap();
+            let msg4_id = db
+                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
+                .await
+                .unwrap();
 
-        assert_ne!(msg1_id, msg2_id);
-        assert_eq!(msg1_id, msg3_id);
-        assert_eq!(msg2_id, msg4_id);
+            assert_ne!(msg1_id, msg2_id);
+            assert_eq!(msg1_id, msg3_id);
+            assert_eq!(msg2_id, msg4_id);
+        }
     }
 
     #[gpui::test]
     async fn test_create_access_tokens() {
-        let test_db = TestDb::new();
+        let test_db = TestDb::postgres();
         let db = test_db.db();
         let user = db.create_user("the-user", false).await.unwrap();
 
@@ -782,4 +764,359 @@ pub mod tests {
             &["h5".to_string(), "h4".to_string(), "h3".to_string()]
         );
     }
+
+    pub struct TestDb {
+        pub db: Option<Arc<dyn Db>>,
+        pub name: String,
+        pub url: String,
+    }
+
+    impl TestDb {
+        pub fn postgres() -> Self {
+            lazy_static! {
+                static ref LOCK: Mutex<()> = Mutex::new(());
+            }
+
+            let _guard = LOCK.lock();
+            let mut rng = StdRng::from_entropy();
+            let name = format!("zed-test-{}", rng.gen::<u128>());
+            let url = format!("postgres://postgres@localhost/{}", name);
+            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
+            let db = block_on(async {
+                Postgres::create_database(&url)
+                    .await
+                    .expect("failed to create test db");
+                let mut db = PostgresDb::new(&url, 5).await.unwrap();
+                db.test_mode = true;
+                let migrator = Migrator::new(migrations_path).await.unwrap();
+                migrator.run(&db.pool).await.unwrap();
+                db
+            });
+            Self {
+                db: Some(Arc::new(db)),
+                name,
+                url,
+            }
+        }
+
+        pub fn fake(background: Arc<Background>) -> Self {
+            Self {
+                db: Some(Arc::new(FakeDb::new(background))),
+                name: "fake".to_string(),
+                url: "fake".to_string(),
+            }
+        }
+
+        pub fn db(&self) -> &Arc<dyn Db> {
+            self.db.as_ref().unwrap()
+        }
+    }
+
+    impl Drop for TestDb {
+        fn drop(&mut self) {
+            if let Some(db) = self.db.take() {
+                block_on(db.teardown(&self.name, &self.url));
+            }
+        }
+    }
+
+    pub struct FakeDb {
+        background: Arc<Background>,
+        users: Mutex<BTreeMap<UserId, User>>,
+        next_user_id: Mutex<i32>,
+        orgs: Mutex<BTreeMap<OrgId, Org>>,
+        next_org_id: Mutex<i32>,
+        org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
+        channels: Mutex<BTreeMap<ChannelId, Channel>>,
+        next_channel_id: Mutex<i32>,
+        channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
+        channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
+        next_channel_message_id: Mutex<i32>,
+    }
+
+    impl FakeDb {
+        pub fn new(background: Arc<Background>) -> Self {
+            Self {
+                background,
+                users: Default::default(),
+                next_user_id: Mutex::new(1),
+                orgs: Default::default(),
+                next_org_id: Mutex::new(1),
+                org_memberships: Default::default(),
+                channels: Default::default(),
+                next_channel_id: Mutex::new(1),
+                channel_memberships: Default::default(),
+                channel_messages: Default::default(),
+                next_channel_message_id: Mutex::new(1),
+            }
+        }
+    }
+
+    #[async_trait]
+    impl Db for FakeDb {
+        async fn create_signup(
+            &self,
+            _github_login: &str,
+            _email_address: &str,
+            _about: &str,
+            _wants_releases: bool,
+            _wants_updates: bool,
+            _wants_community: bool,
+        ) -> Result<SignupId> {
+            unimplemented!()
+        }
+
+        async fn get_all_signups(&self) -> Result<Vec<Signup>> {
+            unimplemented!()
+        }
+
+        async fn destroy_signup(&self, _id: SignupId) -> Result<()> {
+            unimplemented!()
+        }
+
+        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
+            self.background.simulate_random_delay().await;
+
+            let mut users = self.users.lock();
+            if let Some(user) = users
+                .values()
+                .find(|user| user.github_login == github_login)
+            {
+                Ok(user.id)
+            } else {
+                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
+                users.insert(
+                    user_id,
+                    User {
+                        id: user_id,
+                        github_login: github_login.to_string(),
+                        admin,
+                    },
+                );
+                Ok(user_id)
+            }
+        }
+
+        async fn get_all_users(&self) -> Result<Vec<User>> {
+            unimplemented!()
+        }
+
+        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
+            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
+        }
+
+        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
+            self.background.simulate_random_delay().await;
+            let users = self.users.lock();
+            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
+        }
+
+        async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
+            unimplemented!()
+        }
+
+        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
+            unimplemented!()
+        }
+
+        async fn destroy_user(&self, _id: UserId) -> Result<()> {
+            unimplemented!()
+        }
+
+        async fn create_access_token_hash(
+            &self,
+            _user_id: UserId,
+            _access_token_hash: &str,
+            _max_access_token_count: usize,
+        ) -> Result<()> {
+            unimplemented!()
+        }
+
+        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
+            unimplemented!()
+        }
+
+        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
+            unimplemented!()
+        }
+
+        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
+            self.background.simulate_random_delay().await;
+            let mut orgs = self.orgs.lock();
+            if orgs.values().any(|org| org.slug == slug) {
+                Err(anyhow!("org already exists"))
+            } else {
+                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
+                orgs.insert(
+                    org_id,
+                    Org {
+                        id: org_id,
+                        name: name.to_string(),
+                        slug: slug.to_string(),
+                    },
+                );
+                Ok(org_id)
+            }
+        }
+
+        async fn add_org_member(
+            &self,
+            org_id: OrgId,
+            user_id: UserId,
+            is_admin: bool,
+        ) -> Result<()> {
+            self.background.simulate_random_delay().await;
+            if !self.orgs.lock().contains_key(&org_id) {
+                return Err(anyhow!("org does not exist"));
+            }
+            if !self.users.lock().contains_key(&user_id) {
+                return Err(anyhow!("user does not exist"));
+            }
+
+            self.org_memberships
+                .lock()
+                .entry((org_id, user_id))
+                .or_insert(is_admin);
+            Ok(())
+        }
+
+        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
+            self.background.simulate_random_delay().await;
+            if !self.orgs.lock().contains_key(&org_id) {
+                return Err(anyhow!("org does not exist"));
+            }
+
+            let mut channels = self.channels.lock();
+            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
+            channels.insert(
+                channel_id,
+                Channel {
+                    id: channel_id,
+                    name: name.to_string(),
+                    owner_id: org_id.0,
+                    owner_is_user: false,
+                },
+            );
+            Ok(channel_id)
+        }
+
+        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
+            self.background.simulate_random_delay().await;
+            Ok(self
+                .channels
+                .lock()
+                .values()
+                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
+                .cloned()
+                .collect())
+        }
+
+        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
+            self.background.simulate_random_delay().await;
+            let channels = self.channels.lock();
+            let memberships = self.channel_memberships.lock();
+            Ok(channels
+                .values()
+                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
+                .cloned()
+                .collect())
+        }
+
+        async fn can_user_access_channel(
+            &self,
+            user_id: UserId,
+            channel_id: ChannelId,
+        ) -> Result<bool> {
+            self.background.simulate_random_delay().await;
+            Ok(self
+                .channel_memberships
+                .lock()
+                .contains_key(&(channel_id, user_id)))
+        }
+
+        async fn add_channel_member(
+            &self,
+            channel_id: ChannelId,
+            user_id: UserId,
+            is_admin: bool,
+        ) -> Result<()> {
+            self.background.simulate_random_delay().await;
+            if !self.channels.lock().contains_key(&channel_id) {
+                return Err(anyhow!("channel does not exist"));
+            }
+            if !self.users.lock().contains_key(&user_id) {
+                return Err(anyhow!("user does not exist"));
+            }
+
+            self.channel_memberships
+                .lock()
+                .entry((channel_id, user_id))
+                .or_insert(is_admin);
+            Ok(())
+        }
+
+        async fn create_channel_message(
+            &self,
+            channel_id: ChannelId,
+            sender_id: UserId,
+            body: &str,
+            timestamp: OffsetDateTime,
+            nonce: u128,
+        ) -> Result<MessageId> {
+            self.background.simulate_random_delay().await;
+            if !self.channels.lock().contains_key(&channel_id) {
+                return Err(anyhow!("channel does not exist"));
+            }
+            if !self.users.lock().contains_key(&sender_id) {
+                return Err(anyhow!("user does not exist"));
+            }
+
+            let mut messages = self.channel_messages.lock();
+            if let Some(message) = messages
+                .values()
+                .find(|message| message.nonce.as_u128() == nonce)
+            {
+                Ok(message.id)
+            } else {
+                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
+                messages.insert(
+                    message_id,
+                    ChannelMessage {
+                        id: message_id,
+                        channel_id,
+                        sender_id,
+                        body: body.to_string(),
+                        sent_at: timestamp,
+                        nonce: Uuid::from_u128(nonce),
+                    },
+                );
+                Ok(message_id)
+            }
+        }
+
+        async fn get_channel_messages(
+            &self,
+            channel_id: ChannelId,
+            count: usize,
+            before_id: Option<MessageId>,
+        ) -> Result<Vec<ChannelMessage>> {
+            let mut messages = self
+                .channel_messages
+                .lock()
+                .values()
+                .rev()
+                .filter(|message| {
+                    message.channel_id == channel_id
+                        && message.id < before_id.unwrap_or(MessageId::MAX)
+                })
+                .take(count)
+                .cloned()
+                .collect::<Vec<_>>();
+            dbg!(count, before_id, &messages);
+            messages.sort_unstable_by_key(|message| message.id);
+            Ok(messages)
+        }
+
+        async fn teardown(&self, _name: &str, _url: &str) {}
+    }
 }

crates/server/src/main.rs 🔗

@@ -20,7 +20,7 @@ use anyhow::Result;
 use async_std::net::TcpListener;
 use async_trait::async_trait;
 use auth::RequestExt as _;
-use db::Db;
+use db::{Db, PostgresDb};
 use handlebars::{Handlebars, TemplateRenderError};
 use parking_lot::RwLock;
 use rust_embed::RustEmbed;
@@ -49,7 +49,7 @@ pub struct Config {
 }
 
 pub struct AppState {
-    db: Db,
+    db: Arc<dyn Db>,
     handlebars: RwLock<Handlebars<'static>>,
     auth_client: auth::Client,
     github_client: Arc<github::AppClient>,
@@ -59,7 +59,7 @@ pub struct AppState {
 
 impl AppState {
     async fn new(config: Config) -> tide::Result<Arc<Self>> {
-        let db = Db::new(&config.database_url, 5).await?;
+        let db = PostgresDb::new(&config.database_url, 5).await?;
         let github_client =
             github::AppClient::new(config.github_app_id, config.github_private_key.clone());
         let repo_client = github_client
@@ -68,7 +68,7 @@ impl AppState {
             .context("failed to initialize github client")?;
 
         let this = Self {
-            db,
+            db: Arc::new(db),
             handlebars: Default::default(),
             auth_client: auth::build_client(&config.github_client_id, &config.github_client_secret),
             github_client,
@@ -112,7 +112,7 @@ impl AppState {
 #[async_trait]
 trait RequestExt {
     async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>>;
-    fn db(&self) -> &Db;
+    fn db(&self) -> &Arc<dyn Db>;
 }
 
 #[async_trait]
@@ -126,7 +126,7 @@ impl RequestExt for Request {
         Ok(self.ext::<Arc<LayoutData>>().unwrap().clone())
     }
 
-    fn db(&self) -> &Db {
+    fn db(&self) -> &Arc<dyn Db> {
         &self.state().db
     }
 }

crates/server/src/rpc.rs 🔗

@@ -785,7 +785,12 @@ impl Server {
         self: Arc<Server>,
         request: TypedEnvelope<proto::GetUsers>,
     ) -> tide::Result<proto::GetUsersResponse> {
-        let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
+        let user_ids = request
+            .payload
+            .user_ids
+            .into_iter()
+            .map(UserId::from_proto)
+            .collect();
         let users = self
             .app_state
             .db
@@ -1139,18 +1144,14 @@ mod tests {
     }
 
     #[gpui::test(iterations = 10)]
-    async fn test_share_project(
-        mut cx_a: TestAppContext,
-        mut cx_b: TestAppContext,
-        last_iteration: bool,
-    ) {
+    async fn test_share_project(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
         let (window_b, _) = cx_b.add_window(|_| EmptyView);
         let lang_registry = Arc::new(LanguageRegistry::new());
         let fs = Arc::new(FakeFs::new(cx_a.background()));
         cx_a.foreground().forbid_parking();
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
 
@@ -1282,17 +1283,13 @@ mod tests {
     }
 
     #[gpui::test(iterations = 10)]
-    async fn test_unshare_project(
-        mut cx_a: TestAppContext,
-        mut cx_b: TestAppContext,
-        last_iteration: bool,
-    ) {
+    async fn test_unshare_project(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
         let lang_registry = Arc::new(LanguageRegistry::new());
         let fs = Arc::new(FakeFs::new(cx_a.background()));
         cx_a.foreground().forbid_parking();
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
 
@@ -1387,14 +1384,13 @@ mod tests {
         mut cx_a: TestAppContext,
         mut cx_b: TestAppContext,
         mut cx_c: TestAppContext,
-        last_iteration: bool,
     ) {
         let lang_registry = Arc::new(LanguageRegistry::new());
         let fs = Arc::new(FakeFs::new(cx_a.background()));
         cx_a.foreground().forbid_parking();
 
         // Connect to a server as 3 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
         let client_c = server.create_client(&mut cx_c, "user_c").await;
@@ -1566,17 +1562,13 @@ mod tests {
     }
 
     #[gpui::test(iterations = 10)]
-    async fn test_buffer_conflict_after_save(
-        mut cx_a: TestAppContext,
-        mut cx_b: TestAppContext,
-        last_iteration: bool,
-    ) {
+    async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
         cx_a.foreground().forbid_parking();
         let lang_registry = Arc::new(LanguageRegistry::new());
         let fs = Arc::new(FakeFs::new(cx_a.background()));
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
 
@@ -1658,17 +1650,13 @@ mod tests {
     }
 
     #[gpui::test(iterations = 10)]
-    async fn test_buffer_reloading(
-        mut cx_a: TestAppContext,
-        mut cx_b: TestAppContext,
-        last_iteration: bool,
-    ) {
+    async fn test_buffer_reloading(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
         cx_a.foreground().forbid_parking();
         let lang_registry = Arc::new(LanguageRegistry::new());
         let fs = Arc::new(FakeFs::new(cx_a.background()));
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
 
@@ -1747,14 +1735,13 @@ mod tests {
     async fn test_editing_while_guest_opens_buffer(
         mut cx_a: TestAppContext,
         mut cx_b: TestAppContext,
-        last_iteration: bool,
     ) {
         cx_a.foreground().forbid_parking();
         let lang_registry = Arc::new(LanguageRegistry::new());
         let fs = Arc::new(FakeFs::new(cx_a.background()));
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
 
@@ -1830,14 +1817,13 @@ mod tests {
     async fn test_leaving_worktree_while_opening_buffer(
         mut cx_a: TestAppContext,
         mut cx_b: TestAppContext,
-        last_iteration: bool,
     ) {
         cx_a.foreground().forbid_parking();
         let lang_registry = Arc::new(LanguageRegistry::new());
         let fs = Arc::new(FakeFs::new(cx_a.background()));
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
 
@@ -1906,17 +1892,13 @@ mod tests {
     }
 
     #[gpui::test(iterations = 10)]
-    async fn test_peer_disconnection(
-        mut cx_a: TestAppContext,
-        mut cx_b: TestAppContext,
-        last_iteration: bool,
-    ) {
+    async fn test_peer_disconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
         cx_a.foreground().forbid_parking();
         let lang_registry = Arc::new(LanguageRegistry::new());
         let fs = Arc::new(FakeFs::new(cx_a.background()));
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
 
@@ -1984,7 +1966,6 @@ mod tests {
     async fn test_collaborating_with_diagnostics(
         mut cx_a: TestAppContext,
         mut cx_b: TestAppContext,
-        last_iteration: bool,
     ) {
         cx_a.foreground().forbid_parking();
         let mut lang_registry = Arc::new(LanguageRegistry::new());
@@ -2005,7 +1986,7 @@ mod tests {
             )));
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
 
@@ -2209,7 +2190,6 @@ mod tests {
     async fn test_collaborating_with_completion(
         mut cx_a: TestAppContext,
         mut cx_b: TestAppContext,
-        last_iteration: bool,
     ) {
         cx_a.foreground().forbid_parking();
         let mut lang_registry = Arc::new(LanguageRegistry::new());
@@ -2237,7 +2217,7 @@ mod tests {
             )));
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
 
@@ -2419,11 +2399,7 @@ mod tests {
     }
 
     #[gpui::test(iterations = 10)]
-    async fn test_formatting_buffer(
-        mut cx_a: TestAppContext,
-        mut cx_b: TestAppContext,
-        last_iteration: bool,
-    ) {
+    async fn test_formatting_buffer(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
         cx_a.foreground().forbid_parking();
         let mut lang_registry = Arc::new(LanguageRegistry::new());
         let fs = Arc::new(FakeFs::new(cx_a.background()));
@@ -2443,7 +2419,7 @@ mod tests {
             )));
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
 
@@ -2525,11 +2501,7 @@ mod tests {
     }
 
     #[gpui::test(iterations = 10)]
-    async fn test_definition(
-        mut cx_a: TestAppContext,
-        mut cx_b: TestAppContext,
-        last_iteration: bool,
-    ) {
+    async fn test_definition(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
         cx_a.foreground().forbid_parking();
         let mut lang_registry = Arc::new(LanguageRegistry::new());
         let fs = Arc::new(FakeFs::new(cx_a.background()));
@@ -2564,7 +2536,7 @@ mod tests {
             )));
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
 
@@ -2682,7 +2654,6 @@ mod tests {
         mut cx_a: TestAppContext,
         mut cx_b: TestAppContext,
         mut rng: StdRng,
-        last_iteration: bool,
     ) {
         cx_a.foreground().forbid_parking();
         let mut lang_registry = Arc::new(LanguageRegistry::new());
@@ -2713,7 +2684,7 @@ mod tests {
             )));
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
 
@@ -2792,7 +2763,6 @@ mod tests {
     async fn test_collaborating_with_code_actions(
         mut cx_a: TestAppContext,
         mut cx_b: TestAppContext,
-        last_iteration: bool,
     ) {
         cx_a.foreground().forbid_parking();
         let mut lang_registry = Arc::new(LanguageRegistry::new());
@@ -2815,7 +2785,7 @@ mod tests {
             )));
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
 
@@ -3032,15 +3002,11 @@ mod tests {
     }
 
     #[gpui::test(iterations = 10)]
-    async fn test_basic_chat(
-        mut cx_a: TestAppContext,
-        mut cx_b: TestAppContext,
-        last_iteration: bool,
-    ) {
+    async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
         cx_a.foreground().forbid_parking();
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
 
@@ -3176,10 +3142,10 @@ mod tests {
     }
 
     #[gpui::test(iterations = 10)]
-    async fn test_chat_message_validation(mut cx_a: TestAppContext, last_iteration: bool) {
+    async fn test_chat_message_validation(mut cx_a: TestAppContext) {
         cx_a.foreground().forbid_parking();
 
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
 
         let db = &server.app_state.db;
@@ -3236,15 +3202,11 @@ mod tests {
     }
 
     #[gpui::test(iterations = 10)]
-    async fn test_chat_reconnection(
-        mut cx_a: TestAppContext,
-        mut cx_b: TestAppContext,
-        last_iteration: bool,
-    ) {
+    async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
         cx_a.foreground().forbid_parking();
 
         // Connect to a server as 2 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
         let mut status_b = client_b.status();
@@ -3456,14 +3418,13 @@ mod tests {
         mut cx_a: TestAppContext,
         mut cx_b: TestAppContext,
         mut cx_c: TestAppContext,
-        last_iteration: bool,
     ) {
         cx_a.foreground().forbid_parking();
         let lang_registry = Arc::new(LanguageRegistry::new());
         let fs = Arc::new(FakeFs::new(cx_a.background()));
 
         // Connect to a server as 3 clients.
-        let mut server = TestServer::start(cx_a.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
         let client_a = server.create_client(&mut cx_a, "user_a").await;
         let client_b = server.create_client(&mut cx_b, "user_b").await;
         let client_c = server.create_client(&mut cx_c, "user_c").await;
@@ -3595,7 +3556,7 @@ mod tests {
     }
 
     #[gpui::test(iterations = 100)]
-    async fn test_random_collaboration(cx: TestAppContext, rng: StdRng, last_iteration: bool) {
+    async fn test_random_collaboration(cx: TestAppContext, rng: StdRng) {
         cx.foreground().forbid_parking();
         let max_peers = env::var("MAX_PEERS")
             .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
@@ -3654,7 +3615,7 @@ mod tests {
         .await;
 
         let operations = Rc::new(Cell::new(0));
-        let mut server = TestServer::start(cx.foreground(), last_iteration).await;
+        let mut server = TestServer::start(cx.foreground(), cx.background()).await;
         let mut clients = Vec::new();
 
         let mut next_entity_id = 100000;
@@ -3849,9 +3810,11 @@ mod tests {
     }
 
     impl TestServer {
-        async fn start(foreground: Rc<executor::Foreground>, clean_db_pool_on_drop: bool) -> Self {
-            let mut test_db = TestDb::new();
-            test_db.set_clean_pool_on_drop(clean_db_pool_on_drop);
+        async fn start(
+            foreground: Rc<executor::Foreground>,
+            background: Arc<executor::Background>,
+        ) -> Self {
+            let test_db = TestDb::fake(background);
             let app_state = Self::build_app_state(&test_db).await;
             let peer = Peer::new();
             let notifications = mpsc::unbounded();