tests.rs

  1mod buffer_tests;
  2mod channel_tests;
  3mod contributor_tests;
  4mod db_tests;
  5mod feature_flag_tests;
  6mod message_tests;
  7
  8use super::*;
  9use gpui::BackgroundExecutor;
 10use parking_lot::Mutex;
 11use sea_orm::ConnectionTrait;
 12use sqlx::migrate::MigrateDatabase;
 13use std::sync::{
 14    atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
 15    Arc,
 16};
 17
 18pub struct TestDb {
 19    pub db: Option<Arc<Database>>,
 20    pub connection: Option<sqlx::AnyConnection>,
 21}
 22
 23impl TestDb {
 24    pub fn sqlite(background: BackgroundExecutor) -> Self {
 25        let url = format!("sqlite::memory:");
 26        let runtime = tokio::runtime::Builder::new_current_thread()
 27            .enable_io()
 28            .enable_time()
 29            .build()
 30            .unwrap();
 31
 32        let mut db = runtime.block_on(async {
 33            let mut options = ConnectOptions::new(url);
 34            options.max_connections(5);
 35            let mut db = Database::new(options, Executor::Deterministic(background))
 36                .await
 37                .unwrap();
 38            let sql = include_str!(concat!(
 39                env!("CARGO_MANIFEST_DIR"),
 40                "/migrations.sqlite/20221109000000_test_schema.sql"
 41            ));
 42            db.pool
 43                .execute(sea_orm::Statement::from_string(
 44                    db.pool.get_database_backend(),
 45                    sql,
 46                ))
 47                .await
 48                .unwrap();
 49            db.initialize_notification_kinds().await.unwrap();
 50            db
 51        });
 52
 53        db.runtime = Some(runtime);
 54
 55        Self {
 56            db: Some(Arc::new(db)),
 57            connection: None,
 58        }
 59    }
 60
 61    pub fn postgres(background: BackgroundExecutor) -> Self {
 62        static LOCK: Mutex<()> = Mutex::new(());
 63
 64        let _guard = LOCK.lock();
 65        let mut rng = StdRng::from_entropy();
 66        let url = format!(
 67            "postgres://postgres@localhost/zed-test-{}",
 68            rng.gen::<u128>()
 69        );
 70        let runtime = tokio::runtime::Builder::new_current_thread()
 71            .enable_io()
 72            .enable_time()
 73            .build()
 74            .unwrap();
 75
 76        let mut db = runtime.block_on(async {
 77            sqlx::Postgres::create_database(&url)
 78                .await
 79                .expect("failed to create test db");
 80            let mut options = ConnectOptions::new(url);
 81            options
 82                .max_connections(5)
 83                .idle_timeout(Duration::from_secs(0));
 84            let mut db = Database::new(options, Executor::Deterministic(background))
 85                .await
 86                .unwrap();
 87            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 88            db.migrate(Path::new(migrations_path), false).await.unwrap();
 89            db.initialize_notification_kinds().await.unwrap();
 90            db
 91        });
 92
 93        db.runtime = Some(runtime);
 94
 95        Self {
 96            db: Some(Arc::new(db)),
 97            connection: None,
 98        }
 99    }
100
101    pub fn db(&self) -> &Arc<Database> {
102        self.db.as_ref().unwrap()
103    }
104}
105
/// Defines two `#[gpui::test]` functions that run the async function
/// `$test_name` (called with `&Arc<Database>` from `TestDb::db`) once against a
/// fresh Postgres database (`$postgres_test_name`) and once against a fresh
/// SQLite database (`$sqlite_test_name`).
///
/// `$crate` is used so the `TestDb` path always resolves to the crate that
/// defines this macro, wherever it is expanded.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let test_db = $crate::db::TestDb::postgres(cx.executor().clone());
            $test_name(test_db.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let test_db = $crate::db::TestDb::sqlite(cx.executor().clone());
            $test_name(test_db.db()).await;
        }
    };
}
122
123impl Drop for TestDb {
124    fn drop(&mut self) {
125        let db = self.db.take().unwrap();
126        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
127            db.runtime.as_ref().unwrap().block_on(async {
128                use util::ResultExt;
129                let query = "
130                        SELECT pg_terminate_backend(pg_stat_activity.pid)
131                        FROM pg_stat_activity
132                        WHERE
133                            pg_stat_activity.datname = current_database() AND
134                            pid <> pg_backend_pid();
135                    ";
136                db.pool
137                    .execute(sea_orm::Statement::from_string(
138                        db.pool.get_database_backend(),
139                        query,
140                    ))
141                    .await
142                    .log_err();
143                sqlx::Postgres::drop_database(db.options.get_url())
144                    .await
145                    .log_err();
146            })
147        }
148    }
149}
150
151fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
152    channels
153        .iter()
154        .map(|(id, parent_path, name)| Channel {
155            id: *id,
156            name: name.to_string(),
157            visibility: ChannelVisibility::Members,
158            parent_path: parent_path.to_vec(),
159        })
160        .collect()
161}
162
163static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
164
165async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
166    db.create_user(
167        email,
168        false,
169        NewUserParams {
170            github_login: email[0..email.find("@").unwrap()].to_string(),
171            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
172        },
173    )
174    .await
175    .unwrap()
176    .user_id
177}
178
179static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
180fn new_test_connection(server: ServerId) -> ConnectionId {
181    ConnectionId {
182        id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
183        owner_id: server.0 as u32,
184    }
185}