Make database interactions deterministic in test

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

gpui/src/app.rs      |  20 -
gpui/src/executor.rs |  14 
server/src/db.rs     | 468 +++++++++++++++++++++++++++------------------
server/src/main.rs   |  11 
server/src/rpc.rs    | 112 ++++------
5 files changed, 339 insertions(+), 286 deletions(-)

Detailed changes

gpui/src/app.rs 🔗

@@ -14,7 +14,7 @@ use keymap::MatchResult;
 use parking_lot::{Mutex, RwLock};
 use pathfinder_geometry::{rect::RectF, vector::vec2f};
 use platform::Event;
-use postage::{mpsc, oneshot, sink::Sink as _, stream::Stream as _};
+use postage::{mpsc, sink::Sink as _, stream::Stream as _};
 use smol::prelude::*;
 use std::{
     any::{type_name, Any, TypeId},
@@ -2310,24 +2310,6 @@ impl<T: Entity> ModelHandle<T> {
         cx.update_model(self, update)
     }
 
-    pub fn next_notification(&self, cx: &TestAppContext) -> impl Future<Output = ()> {
-        let (tx, mut rx) = oneshot::channel();
-        let mut tx = Some(tx);
-
-        let mut cx = cx.cx.borrow_mut();
-        self.update(&mut *cx, |_, cx| {
-            cx.observe(self, move |_, _, _| {
-                if let Some(mut tx) = tx.take() {
-                    tx.blocking_send(()).ok();
-                }
-            });
-        });
-
-        async move {
-            rx.recv().await;
-        }
-    }
-
     pub fn condition(
         &self,
         cx: &TestAppContext,

gpui/src/executor.rs 🔗

@@ -122,9 +122,14 @@ impl Deterministic {
         smol::pin!(future);
 
         let unparker = self.parker.lock().unparker();
-        let waker = waker_fn(move || {
-            unparker.unpark();
-        });
+        let woken = Arc::new(AtomicBool::new(false));
+        let waker = {
+            let woken = woken.clone();
+            waker_fn(move || {
+                woken.store(true, SeqCst);
+                unparker.unpark();
+            })
+        };
 
         let mut cx = Context::from_waker(&waker);
         let mut trace = Trace::default();
@@ -166,10 +171,11 @@ impl Deterministic {
                     && state.scheduled_from_background.is_empty()
                     && state.spawned_from_foreground.is_empty()
                 {
-                    if state.forbid_parking {
+                    if state.forbid_parking && !woken.load(SeqCst) {
                         panic!("deterministic executor parked after a call to forbid_parking");
                     }
                     drop(state);
+                    woken.store(false, SeqCst);
                     self.parker.lock().park();
                 }
 

server/src/db.rs 🔗

@@ -1,3 +1,5 @@
+use anyhow::Context;
+use async_std::task::{block_on, yield_now};
 use serde::Serialize;
 use sqlx::{FromRow, Result};
 use time::OffsetDateTime;
@@ -5,7 +7,24 @@ use time::OffsetDateTime;
 pub use async_sqlx_session::PostgresSessionStore as SessionStore;
 pub use sqlx::postgres::PgPoolOptions as DbOptions;
 
-pub struct Db(pub sqlx::PgPool);
+macro_rules! test_support {
+    ($self:ident, { $($token:tt)* }) => {{
+        let body = async {
+            $($token)*
+        };
+        if $self.test_mode {
+            yield_now().await;
+            block_on(body)
+        } else {
+            body.await
+        }
+    }};
+}
+
+pub struct Db {
+    db: sqlx::PgPool,
+    test_mode: bool,
+}
 
 #[derive(Debug, FromRow, Serialize)]
 pub struct User {
@@ -37,6 +56,33 @@ pub struct ChannelMessage {
 }
 
 impl Db {
+    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
+        let db = DbOptions::new()
+            .max_connections(max_connections)
+            .connect(url)
+            .await
+            .context("failed to connect to postgres database")?;
+        Ok(Self {
+            db,
+            test_mode: false,
+        })
+    }
+
+    #[cfg(test)]
+    pub fn test(url: &str, max_connections: u32) -> Self {
+        let mut db = block_on(Self::new(url, max_connections)).unwrap();
+        db.test_mode = true;
+        db
+    }
+
+    #[cfg(test)]
+    pub fn migrate(&self, path: &std::path::Path) {
+        block_on(async {
+            let migrator = sqlx::migrate::Migrator::new(path).await.unwrap();
+            migrator.run(&self.db).await.unwrap();
+        });
+    }
+
     // signups
 
     pub async fn create_signup(
@@ -45,53 +91,63 @@ impl Db {
         email_address: &str,
         about: &str,
     ) -> Result<SignupId> {
-        let query = "
-            INSERT INTO signups (github_login, email_address, about)
-            VALUES ($1, $2, $3)
-            RETURNING id
-        ";
-        sqlx::query_scalar(query)
-            .bind(github_login)
-            .bind(email_address)
-            .bind(about)
-            .fetch_one(&self.0)
-            .await
-            .map(SignupId)
+        test_support!(self, {
+            let query = "
+                INSERT INTO signups (github_login, email_address, about)
+                VALUES ($1, $2, $3)
+                RETURNING id
+            ";
+            sqlx::query_scalar(query)
+                .bind(github_login)
+                .bind(email_address)
+                .bind(about)
+                .fetch_one(&self.db)
+                .await
+                .map(SignupId)
+        })
     }
 
     pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
-        let query = "SELECT * FROM users ORDER BY github_login ASC";
-        sqlx::query_as(query).fetch_all(&self.0).await
+        test_support!(self, {
+            let query = "SELECT * FROM users ORDER BY github_login ASC";
+            sqlx::query_as(query).fetch_all(&self.db).await
+        })
     }
 
     pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
-        let query = "DELETE FROM signups WHERE id = $1";
-        sqlx::query(query)
-            .bind(id.0)
-            .execute(&self.0)
-            .await
-            .map(drop)
+        test_support!(self, {
+            let query = "DELETE FROM signups WHERE id = $1";
+            sqlx::query(query)
+                .bind(id.0)
+                .execute(&self.db)
+                .await
+                .map(drop)
+        })
     }
 
     // users
 
     pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
-        let query = "
-            INSERT INTO users (github_login, admin)
-            VALUES ($1, $2)
-            RETURNING id
-        ";
-        sqlx::query_scalar(query)
-            .bind(github_login)
-            .bind(admin)
-            .fetch_one(&self.0)
-            .await
-            .map(UserId)
+        test_support!(self, {
+            let query = "
+                INSERT INTO users (github_login, admin)
+                VALUES ($1, $2)
+                RETURNING id
+            ";
+            sqlx::query_scalar(query)
+                .bind(github_login)
+                .bind(admin)
+                .fetch_one(&self.db)
+                .await
+                .map(UserId)
+        })
     }
 
     pub async fn get_all_users(&self) -> Result<Vec<User>> {
-        let query = "SELECT * FROM users ORDER BY github_login ASC";
-        sqlx::query_as(query).fetch_all(&self.0).await
+        test_support!(self, {
+            let query = "SELECT * FROM users ORDER BY github_login ASC";
+            sqlx::query_as(query).fetch_all(&self.db).await
+        })
     }
 
     pub async fn get_users_by_ids(
@@ -99,53 +155,61 @@ impl Db {
         requester_id: UserId,
         ids: impl Iterator<Item = UserId>,
     ) -> Result<Vec<User>> {
-        // Only return users that are in a common channel with the requesting user.
-        let query = "
-            SELECT users.*
-            FROM
-                users, channel_memberships
-            WHERE
-                users.id IN $1 AND
-                channel_memberships.user_id = users.id AND
-                channel_memberships.channel_id IN (
-                    SELECT channel_id
-                    FROM channel_memberships
-                    WHERE channel_memberships.user_id = $2
-                )
-        ";
-
-        sqlx::query_as(query)
-            .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
-            .bind(requester_id)
-            .fetch_all(&self.0)
-            .await
+        test_support!(self, {
+            // Only return users that are in a common channel with the requesting user.
+            let query = "
+                SELECT users.*
+                FROM
+                    users, channel_memberships
+                WHERE
+                    users.id IN $1 AND
+                    channel_memberships.user_id = users.id AND
+                    channel_memberships.channel_id IN (
+                        SELECT channel_id
+                        FROM channel_memberships
+                        WHERE channel_memberships.user_id = $2
+                    )
+            ";
+
+            sqlx::query_as(query)
+                .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
+                .bind(requester_id)
+                .fetch_all(&self.db)
+                .await
+        })
     }
 
     pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
-        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
-        sqlx::query_as(query)
-            .bind(github_login)
-            .fetch_optional(&self.0)
-            .await
+        test_support!(self, {
+            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
+            sqlx::query_as(query)
+                .bind(github_login)
+                .fetch_optional(&self.db)
+                .await
+        })
     }
 
     pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
-        let query = "UPDATE users SET admin = $1 WHERE id = $2";
-        sqlx::query(query)
-            .bind(is_admin)
-            .bind(id.0)
-            .execute(&self.0)
-            .await
-            .map(drop)
+        test_support!(self, {
+            let query = "UPDATE users SET admin = $1 WHERE id = $2";
+            sqlx::query(query)
+                .bind(is_admin)
+                .bind(id.0)
+                .execute(&self.db)
+                .await
+                .map(drop)
+        })
     }
 
     pub async fn delete_user(&self, id: UserId) -> Result<()> {
-        let query = "DELETE FROM users WHERE id = $1;";
-        sqlx::query(query)
-            .bind(id.0)
-            .execute(&self.0)
-            .await
-            .map(drop)
+        test_support!(self, {
+            let query = "DELETE FROM users WHERE id = $1;";
+            sqlx::query(query)
+                .bind(id.0)
+                .execute(&self.db)
+                .await
+                .map(drop)
+        })
     }
 
     // access tokens
@@ -155,41 +219,47 @@ impl Db {
         user_id: UserId,
         access_token_hash: String,
     ) -> Result<()> {
-        let query = "
+        test_support!(self, {
+            let query = "
             INSERT INTO access_tokens (user_id, hash)
             VALUES ($1, $2)
         ";
-        sqlx::query(query)
-            .bind(user_id.0)
-            .bind(access_token_hash)
-            .execute(&self.0)
-            .await
-            .map(drop)
+            sqlx::query(query)
+                .bind(user_id.0)
+                .bind(access_token_hash)
+                .execute(&self.db)
+                .await
+                .map(drop)
+        })
     }
 
     pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
-        let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
-        sqlx::query_scalar(query)
-            .bind(user_id.0)
-            .fetch_all(&self.0)
-            .await
+        test_support!(self, {
+            let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
+            sqlx::query_scalar(query)
+                .bind(user_id.0)
+                .fetch_all(&self.db)
+                .await
+        })
     }
 
     // orgs
 
     #[cfg(test)]
     pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
-        let query = "
-            INSERT INTO orgs (name, slug)
-            VALUES ($1, $2)
-            RETURNING id
-        ";
-        sqlx::query_scalar(query)
-            .bind(name)
-            .bind(slug)
-            .fetch_one(&self.0)
-            .await
-            .map(OrgId)
+        test_support!(self, {
+            let query = "
+                INSERT INTO orgs (name, slug)
+                VALUES ($1, $2)
+                RETURNING id
+            ";
+            sqlx::query_scalar(query)
+                .bind(name)
+                .bind(slug)
+                .fetch_one(&self.db)
+                .await
+                .map(OrgId)
+        })
     }
 
     #[cfg(test)]
@@ -199,50 +269,56 @@ impl Db {
         user_id: UserId,
         is_admin: bool,
     ) -> Result<()> {
-        let query = "
-            INSERT INTO org_memberships (org_id, user_id, admin)
-            VALUES ($1, $2, $3)
-        ";
-        sqlx::query(query)
-            .bind(org_id.0)
-            .bind(user_id.0)
-            .bind(is_admin)
-            .execute(&self.0)
-            .await
-            .map(drop)
+        test_support!(self, {
+            let query = "
+                INSERT INTO org_memberships (org_id, user_id, admin)
+                VALUES ($1, $2, $3)
+            ";
+            sqlx::query(query)
+                .bind(org_id.0)
+                .bind(user_id.0)
+                .bind(is_admin)
+                .execute(&self.db)
+                .await
+                .map(drop)
+        })
     }
 
     // channels
 
     #[cfg(test)]
     pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
-        let query = "
-            INSERT INTO channels (owner_id, owner_is_user, name)
-            VALUES ($1, false, $2)
-            RETURNING id
-        ";
-        sqlx::query_scalar(query)
-            .bind(org_id.0)
-            .bind(name)
-            .fetch_one(&self.0)
-            .await
-            .map(ChannelId)
+        test_support!(self, {
+            let query = "
+                INSERT INTO channels (owner_id, owner_is_user, name)
+                VALUES ($1, false, $2)
+                RETURNING id
+            ";
+            sqlx::query_scalar(query)
+                .bind(org_id.0)
+                .bind(name)
+                .fetch_one(&self.db)
+                .await
+                .map(ChannelId)
+        })
     }
 
     pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
-        let query = "
-            SELECT
-                channels.id, channels.name
-            FROM
-                channel_memberships, channels
-            WHERE
-                channel_memberships.user_id = $1 AND
-                channel_memberships.channel_id = channels.id
-        ";
-        sqlx::query_as(query)
-            .bind(user_id.0)
-            .fetch_all(&self.0)
-            .await
+        test_support!(self, {
+            let query = "
+                SELECT
+                    channels.id, channels.name
+                FROM
+                    channel_memberships, channels
+                WHERE
+                    channel_memberships.user_id = $1 AND
+                    channel_memberships.channel_id = channels.id
+            ";
+            sqlx::query_as(query)
+                .bind(user_id.0)
+                .fetch_all(&self.db)
+                .await
+        })
     }
 
     pub async fn can_user_access_channel(
@@ -250,18 +326,20 @@ impl Db {
         user_id: UserId,
         channel_id: ChannelId,
     ) -> Result<bool> {
-        let query = "
-            SELECT id
-            FROM channel_memberships
-            WHERE user_id = $1 AND channel_id = $2
-            LIMIT 1
-        ";
-        sqlx::query_scalar::<_, i32>(query)
-            .bind(user_id.0)
-            .bind(channel_id.0)
-            .fetch_optional(&self.0)
-            .await
-            .map(|e| e.is_some())
+        test_support!(self, {
+            let query = "
+                SELECT id
+                FROM channel_memberships
+                WHERE user_id = $1 AND channel_id = $2
+                LIMIT 1
+            ";
+            sqlx::query_scalar::<_, i32>(query)
+                .bind(user_id.0)
+                .bind(channel_id.0)
+                .fetch_optional(&self.db)
+                .await
+                .map(|e| e.is_some())
+        })
     }
 
     #[cfg(test)]
@@ -271,17 +349,19 @@ impl Db {
         user_id: UserId,
         is_admin: bool,
     ) -> Result<()> {
-        let query = "
-            INSERT INTO channel_memberships (channel_id, user_id, admin)
-            VALUES ($1, $2, $3)
-        ";
-        sqlx::query(query)
-            .bind(channel_id.0)
-            .bind(user_id.0)
-            .bind(is_admin)
-            .execute(&self.0)
-            .await
-            .map(drop)
+        test_support!(self, {
+            let query = "
+                INSERT INTO channel_memberships (channel_id, user_id, admin)
+                VALUES ($1, $2, $3)
+            ";
+            sqlx::query(query)
+                .bind(channel_id.0)
+                .bind(user_id.0)
+                .bind(is_admin)
+                .execute(&self.db)
+                .await
+                .map(drop)
+        })
     }
 
     // messages
@@ -293,19 +373,21 @@ impl Db {
         body: &str,
         timestamp: OffsetDateTime,
     ) -> Result<MessageId> {
-        let query = "
-            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
-            VALUES ($1, $2, $3, $4)
-            RETURNING id
-        ";
-        sqlx::query_scalar(query)
-            .bind(channel_id.0)
-            .bind(sender_id.0)
-            .bind(body)
-            .bind(timestamp)
-            .fetch_one(&self.0)
-            .await
-            .map(MessageId)
+        test_support!(self, {
+            let query = "
+                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
+                VALUES ($1, $2, $3, $4)
+                RETURNING id
+            ";
+            sqlx::query_scalar(query)
+                .bind(channel_id.0)
+                .bind(sender_id.0)
+                .bind(body)
+                .bind(timestamp)
+                .fetch_one(&self.db)
+                .await
+                .map(MessageId)
+        })
     }
 
     pub async fn get_recent_channel_messages(
@@ -313,35 +395,39 @@ impl Db {
         channel_id: ChannelId,
         count: usize,
     ) -> Result<Vec<ChannelMessage>> {
-        let query = r#"
-            SELECT
-                id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
-            FROM
-                channel_messages
-            WHERE
-                channel_id = $1
-            LIMIT $2
-        "#;
-        sqlx::query_as(query)
-            .bind(channel_id.0)
-            .bind(count as i64)
-            .fetch_all(&self.0)
-            .await
+        test_support!(self, {
+            let query = r#"
+                SELECT
+                    id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
+                FROM
+                    channel_messages
+                WHERE
+                    channel_id = $1
+                LIMIT $2
+            "#;
+            sqlx::query_as(query)
+                .bind(channel_id.0)
+                .bind(count as i64)
+                .fetch_all(&self.db)
+                .await
+        })
     }
 
     #[cfg(test)]
     pub async fn close(&self, db_name: &str) {
-        let query = "
-            SELECT pg_terminate_backend(pg_stat_activity.pid)
-            FROM pg_stat_activity
-            WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
-        ";
-        sqlx::query(query)
-            .bind(db_name)
-            .execute(&self.0)
-            .await
-            .unwrap();
-        self.0.close().await;
+        test_support!(self, {
+            let query = "
+                SELECT pg_terminate_backend(pg_stat_activity.pid)
+                FROM pg_stat_activity
+                WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
+            ";
+            sqlx::query(query)
+                .bind(db_name)
+                .execute(&self.db)
+                .await
+                .unwrap();
+            self.db.close().await;
+        })
     }
 }
 

server/src/main.rs 🔗

@@ -11,11 +11,11 @@ mod rpc;
 mod team;
 
 use self::errors::TideResultExt as _;
-use anyhow::{Context, Result};
+use anyhow::Result;
 use async_std::net::TcpListener;
 use async_trait::async_trait;
 use auth::RequestExt as _;
-use db::{Db, DbOptions};
+use db::Db;
 use handlebars::{Handlebars, TemplateRenderError};
 use parking_lot::RwLock;
 use rust_embed::RustEmbed;
@@ -54,12 +54,7 @@ pub struct AppState {
 
 impl AppState {
     async fn new(config: Config) -> tide::Result<Arc<Self>> {
-        let db = Db(DbOptions::new()
-            .max_connections(5)
-            .connect(&config.database_url)
-            .await
-            .context("failed to connect to postgres database")?);
-
+        let db = Db::new(&config.database_url, 5).await?;
         let github_client =
             github::AppClient::new(config.github_app_id, config.github_private_key.clone());
         let repo_client = github_client

server/src/rpc.rs 🔗

@@ -922,16 +922,15 @@ mod tests {
         db::{self, UserId},
         github, AppState, Config,
     };
-    use async_std::{sync::RwLockReadGuard, task};
-    use gpui::{ModelHandle, TestAppContext};
+    use async_std::{
+        sync::RwLockReadGuard,
+        task::{self, block_on},
+    };
+    use gpui::TestAppContext;
     use postage::mpsc;
     use rand::prelude::*;
     use serde_json::json;
-    use sqlx::{
-        migrate::{MigrateDatabase, Migrator},
-        types::time::OffsetDateTime,
-        Postgres,
-    };
+    use sqlx::{migrate::MigrateDatabase, types::time::OffsetDateTime, Postgres};
     use std::{path::Path, sync::Arc, time::Duration};
     use zed::{
         channel::{Channel, ChannelDetails, ChannelList},
@@ -1400,6 +1399,8 @@ mod tests {
 
     #[gpui::test]
     async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
+        cx_a.foreground().forbid_parking();
+
         // Connect to a server as 2 clients.
         let mut server = TestServer::start().await;
         let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
@@ -1444,11 +1445,12 @@ mod tests {
             this.get_channel(channel_id.to_proto(), cx).unwrap()
         });
         channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
-        channel_a.next_notification(&cx_a).await;
-        assert_eq!(
-            channel_messages(&channel_a, &cx_a),
-            &[(user_id_b.to_proto(), "hello A, it's B.".to_string())]
-        );
+        channel_a
+            .condition(&cx_a, |channel, _| {
+                channel_messages(channel)
+                    == [(user_id_b.to_proto(), "hello A, it's B.".to_string())]
+            })
+            .await;
 
         let channels_b = ChannelList::new(client_b, &mut cx_b.to_async())
             .await
@@ -1462,15 +1464,17 @@ mod tests {
                 }]
             )
         });
+
         let channel_b = channels_b.update(&mut cx_b, |this, cx| {
             this.get_channel(channel_id.to_proto(), cx).unwrap()
         });
         channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
-        channel_b.next_notification(&cx_b).await;
-        assert_eq!(
-            channel_messages(&channel_b, &cx_b),
-            &[(user_id_b.to_proto(), "hello A, it's B.".to_string())]
-        );
+        channel_b
+            .condition(&cx_b, |channel, _| {
+                channel_messages(channel)
+                    == [(user_id_b.to_proto(), "hello A, it's B.".to_string())]
+            })
+            .await;
 
         channel_a.update(&mut cx_a, |channel, cx| {
             channel.send_message("oh, hi B.".to_string(), cx).unwrap();
@@ -1484,24 +1488,20 @@ mod tests {
                 &["oh, hi B.", "sup"]
             )
         });
-        channel_a.next_notification(&cx_a).await;
-        channel_a.read_with(&cx_a, |channel, _| {
-            assert_eq!(channel.pending_messages().len(), 1);
-        });
-        channel_a.next_notification(&cx_a).await;
-        channel_a.read_with(&cx_a, |channel, _| {
-            assert_eq!(channel.pending_messages().len(), 0);
-        });
 
-        channel_b.next_notification(&cx_b).await;
-        assert_eq!(
-            channel_messages(&channel_b, &cx_b),
-            &[
-                (user_id_b.to_proto(), "hello A, it's B.".to_string()),
-                (user_id_a.to_proto(), "oh, hi B.".to_string()),
-                (user_id_a.to_proto(), "sup".to_string()),
-            ]
-        );
+        channel_a
+            .condition(&cx_a, |channel, _| channel.pending_messages().is_empty())
+            .await;
+        channel_b
+            .condition(&cx_b, |channel, _| {
+                channel_messages(channel)
+                    == [
+                        (user_id_b.to_proto(), "hello A, it's B.".to_string()),
+                        (user_id_a.to_proto(), "oh, hi B.".to_string()),
+                        (user_id_a.to_proto(), "sup".to_string()),
+                    ]
+            })
+            .await;
 
         assert_eq!(
             server.state().await.channels[&channel_id]
@@ -1519,17 +1519,12 @@ mod tests {
             .condition(|state| !state.channels.contains_key(&channel_id))
             .await;
 
-        fn channel_messages(
-            channel: &ModelHandle<Channel>,
-            cx: &TestAppContext,
-        ) -> Vec<(u64, String)> {
-            channel.read_with(cx, |channel, _| {
-                channel
-                    .messages()
-                    .iter()
-                    .map(|m| (m.sender_id, m.body.clone()))
-                    .collect()
-            })
+        fn channel_messages(channel: &Channel) -> Vec<(u64, String)> {
+            channel
+                .messages()
+                .iter()
+                .map(|m| (m.sender_id, m.body.clone()))
+                .collect()
         }
     }
 
@@ -1584,21 +1579,12 @@ mod tests {
             config.session_secret = "a".repeat(32);
             config.database_url = format!("postgres://postgres@localhost/{}", db_name);
 
-            Self::create_db(&config.database_url).await;
-            let db = db::Db(
-                db::DbOptions::new()
-                    .max_connections(5)
-                    .connect(&config.database_url)
-                    .await
-                    .expect("failed to connect to postgres database"),
-            );
-            let migrator = Migrator::new(Path::new(concat!(
+            Self::create_db(&config.database_url);
+            let db = db::Db::test(&config.database_url, 5);
+            db.migrate(Path::new(concat!(
                 env!("CARGO_MANIFEST_DIR"),
                 "/migrations"
-            )))
-            .await
-            .unwrap();
-            migrator.run(&db.0).await.unwrap();
+            )));
 
             let github_client = github::AppClient::test();
             Arc::new(AppState {
@@ -1611,16 +1597,14 @@ mod tests {
             })
         }
 
-        async fn create_db(url: &str) {
+        fn create_db(url: &str) {
             // Enable tests to run in parallel by serializing the creation of each test database.
             lazy_static::lazy_static! {
-                static ref DB_CREATION: async_std::sync::Mutex<()> = async_std::sync::Mutex::new(());
+                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
             }
 
-            let _lock = DB_CREATION.lock().await;
-            Postgres::create_database(url)
-                .await
-                .expect("failed to create test database");
+            let _lock = DB_CREATION.lock();
+            block_on(Postgres::create_database(url)).expect("failed to create test database");
         }
 
         async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> {