tests.rs

  1mod buffer_tests;
  2mod channel_tests;
  3mod contributor_tests;
  4mod db_tests;
  5// we only run postgres tests on macos right now
  6#[cfg(target_os = "macos")]
  7mod embedding_tests;
  8mod extension_tests;
  9mod feature_flag_tests;
 10mod message_tests;
 11
 12use super::*;
 13use gpui::BackgroundExecutor;
 14use parking_lot::Mutex;
 15use sea_orm::ConnectionTrait;
 16use sqlx::migrate::MigrateDatabase;
 17use std::sync::{
 18    atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
 19    Arc,
 20};
 21
/// A throwaway database used by the `db` tests.
///
/// Built with [`TestDb::sqlite`] or [`TestDb::postgres`]. When a Postgres-backed
/// value is dropped, its database is torn down (see the `Drop` impl below).
pub struct TestDb {
    /// The database handle. Every constructor sets this to `Some`; `Drop` takes it out.
    pub db: Option<Arc<Database>>,
    /// Neither constructor in this file populates this (it is always `None` there).
    /// NOTE(review): appears unused in this module — confirm whether any caller reads it.
    pub connection: Option<sqlx::AnyConnection>,
}
 26
 27impl TestDb {
 28    pub fn sqlite(background: BackgroundExecutor) -> Self {
 29        let url = "sqlite::memory:";
 30        let runtime = tokio::runtime::Builder::new_current_thread()
 31            .enable_io()
 32            .enable_time()
 33            .build()
 34            .unwrap();
 35
 36        let mut db = runtime.block_on(async {
 37            let mut options = ConnectOptions::new(url);
 38            options.max_connections(5);
 39            let mut db = Database::new(options, Executor::Deterministic(background))
 40                .await
 41                .unwrap();
 42            let sql = include_str!(concat!(
 43                env!("CARGO_MANIFEST_DIR"),
 44                "/migrations.sqlite/20221109000000_test_schema.sql"
 45            ));
 46            db.pool
 47                .execute(sea_orm::Statement::from_string(
 48                    db.pool.get_database_backend(),
 49                    sql,
 50                ))
 51                .await
 52                .unwrap();
 53            db.initialize_notification_kinds().await.unwrap();
 54            db
 55        });
 56
 57        db.runtime = Some(runtime);
 58
 59        Self {
 60            db: Some(Arc::new(db)),
 61            connection: None,
 62        }
 63    }
 64
 65    pub fn postgres(background: BackgroundExecutor) -> Self {
 66        static LOCK: Mutex<()> = Mutex::new(());
 67
 68        let _guard = LOCK.lock();
 69        let mut rng = StdRng::from_entropy();
 70        let url = format!(
 71            "postgres://postgres@localhost/zed-test-{}",
 72            rng.gen::<u128>()
 73        );
 74        let runtime = tokio::runtime::Builder::new_current_thread()
 75            .enable_io()
 76            .enable_time()
 77            .build()
 78            .unwrap();
 79
 80        let mut db = runtime.block_on(async {
 81            sqlx::Postgres::create_database(&url)
 82                .await
 83                .expect("failed to create test db");
 84            let mut options = ConnectOptions::new(url);
 85            options
 86                .max_connections(5)
 87                .idle_timeout(Duration::from_secs(0));
 88            let mut db = Database::new(options, Executor::Deterministic(background))
 89                .await
 90                .unwrap();
 91            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 92            db.migrate(Path::new(migrations_path), false).await.unwrap();
 93            db.initialize_notification_kinds().await.unwrap();
 94            db
 95        });
 96
 97        db.runtime = Some(runtime);
 98
 99        Self {
100            db: Some(Arc::new(db)),
101            connection: None,
102        }
103    }
104
105    pub fn db(&self) -> &Arc<Database> {
106        self.db.as_ref().unwrap()
107    }
108}
109
/// Expands to two `#[gpui::test]` functions that each run the async test
/// function `$test_name` with a `&Arc<Database>`.
///
/// `$postgres_test_name` runs against a fresh Postgres `TestDb` and is compiled
/// only on macOS. `$sqlite_test_name` runs against an in-memory SQLite `TestDb`
/// on every platform.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        // Postgres-backed variant (macOS only).
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let harness = $crate::db::TestDb::postgres(executor);
            $test_name(harness.db()).await;
        }

        // SQLite-backed variant (all platforms).
        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let harness = $crate::db::TestDb::sqlite(executor);
            $test_name(harness.db()).await;
        }
    };
}
127
128impl Drop for TestDb {
129    fn drop(&mut self) {
130        let db = self.db.take().unwrap();
131        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
132            db.runtime.as_ref().unwrap().block_on(async {
133                use util::ResultExt;
134                let query = "
135                        SELECT pg_terminate_backend(pg_stat_activity.pid)
136                        FROM pg_stat_activity
137                        WHERE
138                            pg_stat_activity.datname = current_database() AND
139                            pid <> pg_backend_pid();
140                    ";
141                db.pool
142                    .execute(sea_orm::Statement::from_string(
143                        db.pool.get_database_backend(),
144                        query,
145                    ))
146                    .await
147                    .log_err();
148                sqlx::Postgres::drop_database(db.options.get_url())
149                    .await
150                    .log_err();
151            })
152        }
153    }
154}
155
156fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
157    channels
158        .iter()
159        .map(|(id, parent_path, name)| Channel {
160            id: *id,
161            name: name.to_string(),
162            visibility: ChannelVisibility::Members,
163            parent_path: parent_path.to_vec(),
164        })
165        .collect()
166}
167
168static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
169
170async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
171    db.create_user(
172        email,
173        false,
174        NewUserParams {
175            github_login: email[0..email.find('@').unwrap()].to_string(),
176            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
177        },
178    )
179    .await
180    .unwrap()
181    .user_id
182}
183
184static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
185fn new_test_connection(server: ServerId) -> ConnectionId {
186    ConnectionId {
187        id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
188        owner_id: server.0 as u32,
189    }
190}