Simplify macro for running a test with both databases

Max Brunsfeld created

Change summary

crates/collab/src/db/db_tests.rs | 1014 +++++++++++++++++----------------
crates/collab/src/db/test_db.rs  |   17 
2 files changed, 526 insertions(+), 505 deletions(-)

Detailed changes

crates/collab/src/db/db_tests.rs 🔗

@@ -1,242 +1,234 @@
 use super::*;
+use crate::test_both_dbs;
 use gpui::executor::{Background, Deterministic};
 use pretty_assertions::{assert_eq, assert_ne};
 use std::sync::Arc;
 use test_db::TestDb;
 
-macro_rules! test_both_dbs {
-    ($postgres_test_name:ident, $sqlite_test_name:ident, $db:ident, $body:block) => {
-        #[gpui::test]
-        async fn $postgres_test_name() {
-            let test_db = TestDb::postgres(Deterministic::new(0).build_background());
-            let $db = test_db.db();
-            $body
-        }
-
-        #[gpui::test]
-        async fn $sqlite_test_name() {
-            let test_db = TestDb::sqlite(Deterministic::new(0).build_background());
-            let $db = test_db.db();
-            $body
-        }
-    };
-}
-
 test_both_dbs!(
+    test_get_users,
     test_get_users_by_ids_postgres,
-    test_get_users_by_ids_sqlite,
-    db,
-    {
-        let mut user_ids = Vec::new();
-        let mut user_metric_ids = Vec::new();
-        for i in 1..=4 {
-            let user = db
-                .create_user(
-                    &format!("user{i}@example.com"),
-                    false,
-                    NewUserParams {
-                        github_login: format!("user{i}"),
-                        github_user_id: i,
-                        invite_count: 0,
-                    },
-                )
-                .await
-                .unwrap();
-            user_ids.push(user.user_id);
-            user_metric_ids.push(user.metrics_id);
-        }
-
-        assert_eq!(
-            db.get_users_by_ids(user_ids.clone()).await.unwrap(),
-            vec![
-                User {
-                    id: user_ids[0],
-                    github_login: "user1".to_string(),
-                    github_user_id: Some(1),
-                    email_address: Some("user1@example.com".to_string()),
-                    admin: false,
-                    metrics_id: user_metric_ids[0].parse().unwrap(),
-                    ..Default::default()
-                },
-                User {
-                    id: user_ids[1],
-                    github_login: "user2".to_string(),
-                    github_user_id: Some(2),
-                    email_address: Some("user2@example.com".to_string()),
-                    admin: false,
-                    metrics_id: user_metric_ids[1].parse().unwrap(),
-                    ..Default::default()
-                },
-                User {
-                    id: user_ids[2],
-                    github_login: "user3".to_string(),
-                    github_user_id: Some(3),
-                    email_address: Some("user3@example.com".to_string()),
-                    admin: false,
-                    metrics_id: user_metric_ids[2].parse().unwrap(),
-                    ..Default::default()
-                },
-                User {
-                    id: user_ids[3],
-                    github_login: "user4".to_string(),
-                    github_user_id: Some(4),
-                    email_address: Some("user4@example.com".to_string()),
-                    admin: false,
-                    metrics_id: user_metric_ids[3].parse().unwrap(),
-                    ..Default::default()
-                }
-            ]
-        );
-    }
+    test_get_users_by_ids_sqlite
 );
 
-test_both_dbs!(
-    test_get_or_create_user_by_github_account_postgres,
-    test_get_or_create_user_by_github_account_sqlite,
-    db,
-    {
-        let user_id1 = db
-            .create_user(
-                "user1@example.com",
-                false,
-                NewUserParams {
-                    github_login: "login1".into(),
-                    github_user_id: 101,
-                    invite_count: 0,
-                },
-            )
-            .await
-            .unwrap()
-            .user_id;
-        let user_id2 = db
+async fn test_get_users(db: &Arc<Database>) {
+    let mut user_ids = Vec::new();
+    let mut user_metric_ids = Vec::new();
+    for i in 1..=4 {
+        let user = db
             .create_user(
-                "user2@example.com",
+                &format!("user{i}@example.com"),
                 false,
                 NewUserParams {
-                    github_login: "login2".into(),
-                    github_user_id: 102,
+                    github_login: format!("user{i}"),
+                    github_user_id: i,
                     invite_count: 0,
                 },
             )
             .await
-            .unwrap()
-            .user_id;
-
-        let user = db
-            .get_or_create_user_by_github_account("login1", None, None)
-            .await
-            .unwrap()
             .unwrap();
-        assert_eq!(user.id, user_id1);
-        assert_eq!(&user.github_login, "login1");
-        assert_eq!(user.github_user_id, Some(101));
-
-        assert!(db
-            .get_or_create_user_by_github_account("non-existent-login", None, None)
-            .await
-            .unwrap()
-            .is_none());
+        user_ids.push(user.user_id);
+        user_metric_ids.push(user.metrics_id);
+    }
 
-        let user = db
-            .get_or_create_user_by_github_account("the-new-login2", Some(102), None)
-            .await
-            .unwrap()
-            .unwrap();
-        assert_eq!(user.id, user_id2);
-        assert_eq!(&user.github_login, "the-new-login2");
-        assert_eq!(user.github_user_id, Some(102));
+    assert_eq!(
+        db.get_users_by_ids(user_ids.clone()).await.unwrap(),
+        vec![
+            User {
+                id: user_ids[0],
+                github_login: "user1".to_string(),
+                github_user_id: Some(1),
+                email_address: Some("user1@example.com".to_string()),
+                admin: false,
+                metrics_id: user_metric_ids[0].parse().unwrap(),
+                ..Default::default()
+            },
+            User {
+                id: user_ids[1],
+                github_login: "user2".to_string(),
+                github_user_id: Some(2),
+                email_address: Some("user2@example.com".to_string()),
+                admin: false,
+                metrics_id: user_metric_ids[1].parse().unwrap(),
+                ..Default::default()
+            },
+            User {
+                id: user_ids[2],
+                github_login: "user3".to_string(),
+                github_user_id: Some(3),
+                email_address: Some("user3@example.com".to_string()),
+                admin: false,
+                metrics_id: user_metric_ids[2].parse().unwrap(),
+                ..Default::default()
+            },
+            User {
+                id: user_ids[3],
+                github_login: "user4".to_string(),
+                github_user_id: Some(4),
+                email_address: Some("user4@example.com".to_string()),
+                admin: false,
+                metrics_id: user_metric_ids[3].parse().unwrap(),
+                ..Default::default()
+            }
+        ]
+    );
+}
 
-        let user = db
-            .get_or_create_user_by_github_account("login3", Some(103), Some("user3@example.com"))
-            .await
-            .unwrap()
-            .unwrap();
-        assert_eq!(&user.github_login, "login3");
-        assert_eq!(user.github_user_id, Some(103));
-        assert_eq!(user.email_address, Some("user3@example.com".into()));
-    }
+test_both_dbs!(
+    test_get_or_create_user_by_github_account,
+    test_get_or_create_user_by_github_account_postgres,
+    test_get_or_create_user_by_github_account_sqlite
 );
 
+async fn test_get_or_create_user_by_github_account(db: &Arc<Database>) {
+    let user_id1 = db
+        .create_user(
+            "user1@example.com",
+            false,
+            NewUserParams {
+                github_login: "login1".into(),
+                github_user_id: 101,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
+    let user_id2 = db
+        .create_user(
+            "user2@example.com",
+            false,
+            NewUserParams {
+                github_login: "login2".into(),
+                github_user_id: 102,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
+
+    let user = db
+        .get_or_create_user_by_github_account("login1", None, None)
+        .await
+        .unwrap()
+        .unwrap();
+    assert_eq!(user.id, user_id1);
+    assert_eq!(&user.github_login, "login1");
+    assert_eq!(user.github_user_id, Some(101));
+
+    assert!(db
+        .get_or_create_user_by_github_account("non-existent-login", None, None)
+        .await
+        .unwrap()
+        .is_none());
+
+    let user = db
+        .get_or_create_user_by_github_account("the-new-login2", Some(102), None)
+        .await
+        .unwrap()
+        .unwrap();
+    assert_eq!(user.id, user_id2);
+    assert_eq!(&user.github_login, "the-new-login2");
+    assert_eq!(user.github_user_id, Some(102));
+
+    let user = db
+        .get_or_create_user_by_github_account("login3", Some(103), Some("user3@example.com"))
+        .await
+        .unwrap()
+        .unwrap();
+    assert_eq!(&user.github_login, "login3");
+    assert_eq!(user.github_user_id, Some(103));
+    assert_eq!(user.email_address, Some("user3@example.com".into()));
+}
+
 test_both_dbs!(
+    test_create_access_tokens,
     test_create_access_tokens_postgres,
-    test_create_access_tokens_sqlite,
-    db,
-    {
-        let user = db
-            .create_user(
-                "u1@example.com",
-                false,
-                NewUserParams {
-                    github_login: "u1".into(),
-                    github_user_id: 1,
-                    invite_count: 0,
-                },
-            )
-            .await
-            .unwrap()
-            .user_id;
-
-        let token_1 = db.create_access_token(user, "h1", 2).await.unwrap();
-        let token_2 = db.create_access_token(user, "h2", 2).await.unwrap();
-        assert_eq!(
-            db.get_access_token(token_1).await.unwrap(),
-            access_token::Model {
-                id: token_1,
-                user_id: user,
-                hash: "h1".into(),
-            }
-        );
-        assert_eq!(
-            db.get_access_token(token_2).await.unwrap(),
-            access_token::Model {
-                id: token_2,
-                user_id: user,
-                hash: "h2".into()
-            }
-        );
+    test_create_access_tokens_sqlite
+);
 
-        let token_3 = db.create_access_token(user, "h3", 2).await.unwrap();
-        assert_eq!(
-            db.get_access_token(token_3).await.unwrap(),
-            access_token::Model {
-                id: token_3,
-                user_id: user,
-                hash: "h3".into()
-            }
-        );
-        assert_eq!(
-            db.get_access_token(token_2).await.unwrap(),
-            access_token::Model {
-                id: token_2,
-                user_id: user,
-                hash: "h2".into()
-            }
-        );
-        assert!(db.get_access_token(token_1).await.is_err());
-
-        let token_4 = db.create_access_token(user, "h4", 2).await.unwrap();
-        assert_eq!(
-            db.get_access_token(token_4).await.unwrap(),
-            access_token::Model {
-                id: token_4,
-                user_id: user,
-                hash: "h4".into()
-            }
-        );
-        assert_eq!(
-            db.get_access_token(token_3).await.unwrap(),
-            access_token::Model {
-                id: token_3,
-                user_id: user,
-                hash: "h3".into()
-            }
-        );
-        assert!(db.get_access_token(token_2).await.is_err());
-        assert!(db.get_access_token(token_1).await.is_err());
-    }
+async fn test_create_access_tokens(db: &Arc<Database>) {
+    let user = db
+        .create_user(
+            "u1@example.com",
+            false,
+            NewUserParams {
+                github_login: "u1".into(),
+                github_user_id: 1,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
+
+    let token_1 = db.create_access_token(user, "h1", 2).await.unwrap();
+    let token_2 = db.create_access_token(user, "h2", 2).await.unwrap();
+    assert_eq!(
+        db.get_access_token(token_1).await.unwrap(),
+        access_token::Model {
+            id: token_1,
+            user_id: user,
+            hash: "h1".into(),
+        }
+    );
+    assert_eq!(
+        db.get_access_token(token_2).await.unwrap(),
+        access_token::Model {
+            id: token_2,
+            user_id: user,
+            hash: "h2".into()
+        }
+    );
+
+    let token_3 = db.create_access_token(user, "h3", 2).await.unwrap();
+    assert_eq!(
+        db.get_access_token(token_3).await.unwrap(),
+        access_token::Model {
+            id: token_3,
+            user_id: user,
+            hash: "h3".into()
+        }
+    );
+    assert_eq!(
+        db.get_access_token(token_2).await.unwrap(),
+        access_token::Model {
+            id: token_2,
+            user_id: user,
+            hash: "h2".into()
+        }
+    );
+    assert!(db.get_access_token(token_1).await.is_err());
+
+    let token_4 = db.create_access_token(user, "h4", 2).await.unwrap();
+    assert_eq!(
+        db.get_access_token(token_4).await.unwrap(),
+        access_token::Model {
+            id: token_4,
+            user_id: user,
+            hash: "h4".into()
+        }
+    );
+    assert_eq!(
+        db.get_access_token(token_3).await.unwrap(),
+        access_token::Model {
+            id: token_3,
+            user_id: user,
+            hash: "h3".into()
+        }
+    );
+    assert!(db.get_access_token(token_2).await.is_err());
+    assert!(db.get_access_token(token_1).await.is_err());
+}
+
+test_both_dbs!(
+    test_add_contacts,
+    test_add_contacts_postgres,
+    test_add_contacts_sqlite
 );
 
-test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, {
+async fn test_add_contacts(db: &Arc<Database>) {
     let mut user_ids = Vec::new();
     for i in 0..3 {
         user_ids.push(
@@ -403,9 +395,15 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, {
             busy: false,
         }],
     );
-});
+}
 
-test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, {
+test_both_dbs!(
+    test_metrics_id,
+    test_metrics_id_postgres,
+    test_metrics_id_sqlite
+);
+
+async fn test_metrics_id(db: &Arc<Database>) {
     let NewUserResult {
         user_id: user1,
         metrics_id: metrics_id1,
@@ -444,82 +442,83 @@ test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, {
     assert_eq!(metrics_id1.len(), 36);
     assert_eq!(metrics_id2.len(), 36);
     assert_ne!(metrics_id1, metrics_id2);
-});
+}
 
 test_both_dbs!(
+    test_project_count,
     test_project_count_postgres,
-    test_project_count_sqlite,
-    db,
-    {
-        let owner_id = db.create_server("test").await.unwrap().0 as u32;
+    test_project_count_sqlite
+);
 
-        let user1 = db
-            .create_user(
-                &format!("admin@example.com"),
-                true,
-                NewUserParams {
-                    github_login: "admin".into(),
-                    github_user_id: 0,
-                    invite_count: 0,
-                },
-            )
-            .await
-            .unwrap();
-        let user2 = db
-            .create_user(
-                &format!("user@example.com"),
-                false,
-                NewUserParams {
-                    github_login: "user".into(),
-                    github_user_id: 1,
-                    invite_count: 0,
-                },
-            )
-            .await
-            .unwrap();
+async fn test_project_count(db: &Arc<Database>) {
+    let owner_id = db.create_server("test").await.unwrap().0 as u32;
 
-        let room_id = RoomId::from_proto(
-            db.create_room(user1.user_id, ConnectionId { owner_id, id: 0 }, "")
-                .await
-                .unwrap()
-                .id,
-        );
-        db.call(
-            room_id,
-            user1.user_id,
-            ConnectionId { owner_id, id: 0 },
-            user2.user_id,
-            None,
+    let user1 = db
+        .create_user(
+            &format!("admin@example.com"),
+            true,
+            NewUserParams {
+                github_login: "admin".into(),
+                github_user_id: 0,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap();
+    let user2 = db
+        .create_user(
+            &format!("user@example.com"),
+            false,
+            NewUserParams {
+                github_login: "user".into(),
+                github_user_id: 1,
+                invite_count: 0,
+            },
         )
         .await
         .unwrap();
-        db.join_room(room_id, user2.user_id, ConnectionId { owner_id, id: 1 })
-            .await
-            .unwrap();
-        assert_eq!(db.project_count_excluding_admins().await.unwrap(), 0);
 
-        db.share_project(room_id, ConnectionId { owner_id, id: 1 }, &[])
+    let room_id = RoomId::from_proto(
+        db.create_room(user1.user_id, ConnectionId { owner_id, id: 0 }, "")
             .await
-            .unwrap();
-        assert_eq!(db.project_count_excluding_admins().await.unwrap(), 1);
+            .unwrap()
+            .id,
+    );
+    db.call(
+        room_id,
+        user1.user_id,
+        ConnectionId { owner_id, id: 0 },
+        user2.user_id,
+        None,
+    )
+    .await
+    .unwrap();
+    db.join_room(room_id, user2.user_id, ConnectionId { owner_id, id: 1 })
+        .await
+        .unwrap();
+    assert_eq!(db.project_count_excluding_admins().await.unwrap(), 0);
 
-        db.share_project(room_id, ConnectionId { owner_id, id: 1 }, &[])
-            .await
-            .unwrap();
-        assert_eq!(db.project_count_excluding_admins().await.unwrap(), 2);
+    db.share_project(room_id, ConnectionId { owner_id, id: 1 }, &[])
+        .await
+        .unwrap();
+    assert_eq!(db.project_count_excluding_admins().await.unwrap(), 1);
 
-        // Projects shared by admins aren't counted.
-        db.share_project(room_id, ConnectionId { owner_id, id: 0 }, &[])
-            .await
-            .unwrap();
-        assert_eq!(db.project_count_excluding_admins().await.unwrap(), 2);
+    db.share_project(room_id, ConnectionId { owner_id, id: 1 }, &[])
+        .await
+        .unwrap();
+    assert_eq!(db.project_count_excluding_admins().await.unwrap(), 2);
 
-        db.leave_room(ConnectionId { owner_id, id: 1 })
-            .await
-            .unwrap();
-        assert_eq!(db.project_count_excluding_admins().await.unwrap(), 0);
-    }
-);
+    // Projects shared by admins aren't counted.
+    db.share_project(room_id, ConnectionId { owner_id, id: 0 }, &[])
+        .await
+        .unwrap();
+    assert_eq!(db.project_count_excluding_admins().await.unwrap(), 2);
+
+    db.leave_room(ConnectionId { owner_id, id: 1 })
+        .await
+        .unwrap();
+    assert_eq!(db.project_count_excluding_admins().await.unwrap(), 0);
+}
 
 #[test]
 fn test_fuzzy_like_string() {
@@ -878,7 +877,9 @@ async fn test_invite_codes() {
     assert!(db.has_contact(user5, user1).await.unwrap());
 }
 
-test_both_dbs!(test_channels_postgres, test_channels_sqlite, db, {
+test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite);
+
+async fn test_channels(db: &Arc<Database>) {
     let a_id = db
         .create_user(
             "user1@example.com",
@@ -1063,267 +1064,270 @@ test_both_dbs!(test_channels_postgres, test_channels_sqlite, db, {
     assert!(db.get_channel(rust_id, a_id).await.unwrap().is_none());
     assert!(db.get_channel(cargo_id, a_id).await.unwrap().is_none());
     assert!(db.get_channel(cargo_ra_id, a_id).await.unwrap().is_none());
-});
+}
 
 test_both_dbs!(
+    test_joining_channels,
     test_joining_channels_postgres,
-    test_joining_channels_sqlite,
-    db,
-    {
-        let owner_id = db.create_server("test").await.unwrap().0 as u32;
+    test_joining_channels_sqlite
+);
 
-        let user_1 = db
-            .create_user(
-                "user1@example.com",
-                false,
-                NewUserParams {
-                    github_login: "user1".into(),
-                    github_user_id: 5,
-                    invite_count: 0,
-                },
-            )
-            .await
-            .unwrap()
-            .user_id;
-        let user_2 = db
-            .create_user(
-                "user2@example.com",
-                false,
-                NewUserParams {
-                    github_login: "user2".into(),
-                    github_user_id: 6,
-                    invite_count: 0,
-                },
-            )
-            .await
-            .unwrap()
-            .user_id;
+async fn test_joining_channels(db: &Arc<Database>) {
+    let owner_id = db.create_server("test").await.unwrap().0 as u32;
 
-        let channel_1 = db
-            .create_root_channel("channel_1", "1", user_1)
-            .await
-            .unwrap();
-        let room_1 = db.room_id_for_channel(channel_1).await.unwrap();
+    let user_1 = db
+        .create_user(
+            "user1@example.com",
+            false,
+            NewUserParams {
+                github_login: "user1".into(),
+                github_user_id: 5,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
+    let user_2 = db
+        .create_user(
+            "user2@example.com",
+            false,
+            NewUserParams {
+                github_login: "user2".into(),
+                github_user_id: 6,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
 
-        // can join a room with membership to its channel
-        let joined_room = db
-            .join_room(room_1, user_1, ConnectionId { owner_id, id: 1 })
-            .await
-            .unwrap();
-        assert_eq!(joined_room.room.participants.len(), 1);
+    let channel_1 = db
+        .create_root_channel("channel_1", "1", user_1)
+        .await
+        .unwrap();
+    let room_1 = db.room_id_for_channel(channel_1).await.unwrap();
 
-        drop(joined_room);
-        // cannot join a room without membership to its channel
-        assert!(db
-            .join_room(room_1, user_2, ConnectionId { owner_id, id: 1 })
-            .await
-            .is_err());
-    }
-);
+    // can join a room with membership to its channel
+    let joined_room = db
+        .join_room(room_1, user_1, ConnectionId { owner_id, id: 1 })
+        .await
+        .unwrap();
+    assert_eq!(joined_room.room.participants.len(), 1);
+
+    drop(joined_room);
+    // cannot join a room without membership to its channel
+    assert!(db
+        .join_room(room_1, user_2, ConnectionId { owner_id, id: 1 })
+        .await
+        .is_err());
+}
 
 test_both_dbs!(
+    test_channel_invites,
     test_channel_invites_postgres,
-    test_channel_invites_sqlite,
-    db,
-    {
-        db.create_server("test").await.unwrap();
+    test_channel_invites_sqlite
+);
 
-        let user_1 = db
-            .create_user(
-                "user1@example.com",
-                false,
-                NewUserParams {
-                    github_login: "user1".into(),
-                    github_user_id: 5,
-                    invite_count: 0,
-                },
-            )
-            .await
-            .unwrap()
-            .user_id;
-        let user_2 = db
-            .create_user(
-                "user2@example.com",
-                false,
-                NewUserParams {
-                    github_login: "user2".into(),
-                    github_user_id: 6,
-                    invite_count: 0,
-                },
-            )
-            .await
-            .unwrap()
-            .user_id;
+async fn test_channel_invites(db: &Arc<Database>) {
+    db.create_server("test").await.unwrap();
 
-        let user_3 = db
-            .create_user(
-                "user3@example.com",
-                false,
-                NewUserParams {
-                    github_login: "user3".into(),
-                    github_user_id: 7,
-                    invite_count: 0,
-                },
-            )
-            .await
-            .unwrap()
-            .user_id;
+    let user_1 = db
+        .create_user(
+            "user1@example.com",
+            false,
+            NewUserParams {
+                github_login: "user1".into(),
+                github_user_id: 5,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
+    let user_2 = db
+        .create_user(
+            "user2@example.com",
+            false,
+            NewUserParams {
+                github_login: "user2".into(),
+                github_user_id: 6,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
 
-        let channel_1_1 = db
-            .create_root_channel("channel_1", "1", user_1)
-            .await
-            .unwrap();
+    let user_3 = db
+        .create_user(
+            "user3@example.com",
+            false,
+            NewUserParams {
+                github_login: "user3".into(),
+                github_user_id: 7,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
 
-        let channel_1_2 = db
-            .create_root_channel("channel_2", "2", user_1)
-            .await
-            .unwrap();
+    let channel_1_1 = db
+        .create_root_channel("channel_1", "1", user_1)
+        .await
+        .unwrap();
 
-        db.invite_channel_member(channel_1_1, user_2, user_1, false)
-            .await
-            .unwrap();
-        db.invite_channel_member(channel_1_2, user_2, user_1, false)
-            .await
-            .unwrap();
-        db.invite_channel_member(channel_1_1, user_3, user_1, true)
-            .await
-            .unwrap();
+    let channel_1_2 = db
+        .create_root_channel("channel_2", "2", user_1)
+        .await
+        .unwrap();
 
-        let user_2_invites = db
-            .get_channel_invites_for_user(user_2) // -> [channel_1_1, channel_1_2]
-            .await
-            .unwrap()
-            .into_iter()
-            .map(|channel| channel.id)
-            .collect::<Vec<_>>();
+    db.invite_channel_member(channel_1_1, user_2, user_1, false)
+        .await
+        .unwrap();
+    db.invite_channel_member(channel_1_2, user_2, user_1, false)
+        .await
+        .unwrap();
+    db.invite_channel_member(channel_1_1, user_3, user_1, true)
+        .await
+        .unwrap();
 
-        assert_eq!(user_2_invites, &[channel_1_1, channel_1_2]);
+    let user_2_invites = db
+        .get_channel_invites_for_user(user_2) // -> [channel_1_1, channel_1_2]
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|channel| channel.id)
+        .collect::<Vec<_>>();
 
-        let user_3_invites = db
-            .get_channel_invites_for_user(user_3) // -> [channel_1_1]
-            .await
-            .unwrap()
-            .into_iter()
-            .map(|channel| channel.id)
-            .collect::<Vec<_>>();
+    assert_eq!(user_2_invites, &[channel_1_1, channel_1_2]);
 
-        assert_eq!(user_3_invites, &[channel_1_1]);
+    let user_3_invites = db
+        .get_channel_invites_for_user(user_3) // -> [channel_1_1]
+        .await
+        .unwrap()
+        .into_iter()
+        .map(|channel| channel.id)
+        .collect::<Vec<_>>();
 
-        let members = db
-            .get_channel_member_details(channel_1_1, user_1)
-            .await
-            .unwrap();
-        assert_eq!(
-            members,
-            &[
-                proto::ChannelMember {
-                    user_id: user_1.to_proto(),
-                    kind: proto::channel_member::Kind::Member.into(),
-                    admin: true,
-                },
-                proto::ChannelMember {
-                    user_id: user_2.to_proto(),
-                    kind: proto::channel_member::Kind::Invitee.into(),
-                    admin: false,
-                },
-                proto::ChannelMember {
-                    user_id: user_3.to_proto(),
-                    kind: proto::channel_member::Kind::Invitee.into(),
-                    admin: true,
-                },
-            ]
-        );
+    assert_eq!(user_3_invites, &[channel_1_1]);
 
-        db.respond_to_channel_invite(channel_1_1, user_2, true)
-            .await
-            .unwrap();
+    let members = db
+        .get_channel_member_details(channel_1_1, user_1)
+        .await
+        .unwrap();
+    assert_eq!(
+        members,
+        &[
+            proto::ChannelMember {
+                user_id: user_1.to_proto(),
+                kind: proto::channel_member::Kind::Member.into(),
+                admin: true,
+            },
+            proto::ChannelMember {
+                user_id: user_2.to_proto(),
+                kind: proto::channel_member::Kind::Invitee.into(),
+                admin: false,
+            },
+            proto::ChannelMember {
+                user_id: user_3.to_proto(),
+                kind: proto::channel_member::Kind::Invitee.into(),
+                admin: true,
+            },
+        ]
+    );
 
-        let channel_1_3 = db
-            .create_channel("channel_3", Some(channel_1_1), "1", user_1)
-            .await
-            .unwrap();
+    db.respond_to_channel_invite(channel_1_1, user_2, true)
+        .await
+        .unwrap();
 
-        let members = db
-            .get_channel_member_details(channel_1_3, user_1)
-            .await
-            .unwrap();
-        assert_eq!(
-            members,
-            &[
-                proto::ChannelMember {
-                    user_id: user_1.to_proto(),
-                    kind: proto::channel_member::Kind::Member.into(),
-                    admin: true,
-                },
-                proto::ChannelMember {
-                    user_id: user_2.to_proto(),
-                    kind: proto::channel_member::Kind::AncestorMember.into(),
-                    admin: false,
-                },
-            ]
-        );
-    }
-);
+    let channel_1_3 = db
+        .create_channel("channel_3", Some(channel_1_1), "1", user_1)
+        .await
+        .unwrap();
+
+    let members = db
+        .get_channel_member_details(channel_1_3, user_1)
+        .await
+        .unwrap();
+    assert_eq!(
+        members,
+        &[
+            proto::ChannelMember {
+                user_id: user_1.to_proto(),
+                kind: proto::channel_member::Kind::Member.into(),
+                admin: true,
+            },
+            proto::ChannelMember {
+                user_id: user_2.to_proto(),
+                kind: proto::channel_member::Kind::AncestorMember.into(),
+                admin: false,
+            },
+        ]
+    );
+}
 
 test_both_dbs!(
+    test_channel_renames,
     test_channel_renames_postgres,
-    test_channel_renames_sqlite,
-    db,
-    {
-        db.create_server("test").await.unwrap();
+    test_channel_renames_sqlite
+);
 
-        let user_1 = db
-            .create_user(
-                "user1@example.com",
-                false,
-                NewUserParams {
-                    github_login: "user1".into(),
-                    github_user_id: 5,
-                    invite_count: 0,
-                },
-            )
-            .await
-            .unwrap()
-            .user_id;
+async fn test_channel_renames(db: &Arc<Database>) {
+    db.create_server("test").await.unwrap();
 
-        let user_2 = db
-            .create_user(
-                "user2@example.com",
-                false,
-                NewUserParams {
-                    github_login: "user2".into(),
-                    github_user_id: 6,
-                    invite_count: 0,
-                },
-            )
-            .await
-            .unwrap()
-            .user_id;
+    let user_1 = db
+        .create_user(
+            "user1@example.com",
+            false,
+            NewUserParams {
+                github_login: "user1".into(),
+                github_user_id: 5,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
 
-        let zed_id = db.create_root_channel("zed", "1", user_1).await.unwrap();
+    let user_2 = db
+        .create_user(
+            "user2@example.com",
+            false,
+            NewUserParams {
+                github_login: "user2".into(),
+                github_user_id: 6,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
 
-        db.rename_channel(zed_id, user_1, "#zed-archive")
-            .await
-            .unwrap();
+    let zed_id = db.create_root_channel("zed", "1", user_1).await.unwrap();
 
-        let zed_archive_id = zed_id;
+    db.rename_channel(zed_id, user_1, "#zed-archive")
+        .await
+        .unwrap();
 
-        let (channel, _) = db
-            .get_channel(zed_archive_id, user_1)
-            .await
-            .unwrap()
-            .unwrap();
-        assert_eq!(channel.name, "zed-archive");
+    let zed_archive_id = zed_id;
 
-        let non_permissioned_rename = db
-            .rename_channel(zed_archive_id, user_2, "hacked-lol")
-            .await;
-        assert!(non_permissioned_rename.is_err());
+    let (channel, _) = db
+        .get_channel(zed_archive_id, user_1)
+        .await
+        .unwrap()
+        .unwrap();
+    assert_eq!(channel.name, "zed-archive");
 
-        let bad_name_rename = db.rename_channel(zed_id, user_1, "#").await;
-        assert!(bad_name_rename.is_err())
-    }
-);
+    let non_permissioned_rename = db
+        .rename_channel(zed_archive_id, user_2, "hacked-lol")
+        .await;
+    assert!(non_permissioned_rename.is_err());
+
+    let bad_name_rename = db.rename_channel(zed_id, user_1, "#").await;
+    assert!(bad_name_rename.is_err())
+}
 
 #[gpui::test]
 async fn test_multiple_signup_overwrite() {

crates/collab/src/db/test_db.rs 🔗

@@ -91,6 +91,23 @@ impl TestDb {
     }
 }
 
+#[macro_export]
+macro_rules! test_both_dbs {
+    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
+        #[gpui::test]
+        async fn $postgres_test_name() {
+            let test_db = TestDb::postgres(Deterministic::new(0).build_background());
+            $test_name(test_db.db()).await;
+        }
+
+        #[gpui::test]
+        async fn $sqlite_test_name() {
+            let test_db = TestDb::sqlite(Deterministic::new(0).build_background());
+            $test_name(test_db.db()).await;
+        }
+    };
+}
+
 impl Drop for TestDb {
     fn drop(&mut self) {
         let db = self.db.take().unwrap();