tests.rs

  1mod billing_subscription_tests;
  2mod buffer_tests;
  3mod channel_tests;
  4mod contributor_tests;
  5mod db_tests;
  6// we only run postgres tests on macos right now
  7#[cfg(target_os = "macos")]
  8mod embedding_tests;
  9mod extension_tests;
 10mod feature_flag_tests;
 11mod message_tests;
 12mod processed_stripe_event_tests;
 13
 14use super::*;
 15use gpui::BackgroundExecutor;
 16use parking_lot::Mutex;
 17use sea_orm::ConnectionTrait;
 18use sqlx::migrate::MigrateDatabase;
 19use std::sync::{
 20    atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
 21    Arc,
 22};
 23
 24pub struct TestDb {
 25    pub db: Option<Arc<Database>>,
 26    pub connection: Option<sqlx::AnyConnection>,
 27}
 28
 29impl TestDb {
 30    pub fn sqlite(background: BackgroundExecutor) -> Self {
 31        let url = "sqlite::memory:";
 32        let runtime = tokio::runtime::Builder::new_current_thread()
 33            .enable_io()
 34            .enable_time()
 35            .build()
 36            .unwrap();
 37
 38        let mut db = runtime.block_on(async {
 39            let mut options = ConnectOptions::new(url);
 40            options.max_connections(5);
 41            let mut db = Database::new(options, Executor::Deterministic(background))
 42                .await
 43                .unwrap();
 44            let sql = include_str!(concat!(
 45                env!("CARGO_MANIFEST_DIR"),
 46                "/migrations.sqlite/20221109000000_test_schema.sql"
 47            ));
 48            db.pool
 49                .execute(sea_orm::Statement::from_string(
 50                    db.pool.get_database_backend(),
 51                    sql,
 52                ))
 53                .await
 54                .unwrap();
 55            db.initialize_notification_kinds().await.unwrap();
 56            db
 57        });
 58
 59        db.runtime = Some(runtime);
 60
 61        Self {
 62            db: Some(Arc::new(db)),
 63            connection: None,
 64        }
 65    }
 66
 67    pub fn postgres(background: BackgroundExecutor) -> Self {
 68        static LOCK: Mutex<()> = Mutex::new(());
 69
 70        let _guard = LOCK.lock();
 71        let mut rng = StdRng::from_entropy();
 72        let url = format!(
 73            "postgres://postgres@localhost/zed-test-{}",
 74            rng.gen::<u128>()
 75        );
 76        let runtime = tokio::runtime::Builder::new_current_thread()
 77            .enable_io()
 78            .enable_time()
 79            .build()
 80            .unwrap();
 81
 82        let mut db = runtime.block_on(async {
 83            sqlx::Postgres::create_database(&url)
 84                .await
 85                .expect("failed to create test db");
 86            let mut options = ConnectOptions::new(url);
 87            options
 88                .max_connections(5)
 89                .idle_timeout(Duration::from_secs(0));
 90            let mut db = Database::new(options, Executor::Deterministic(background))
 91                .await
 92                .unwrap();
 93            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 94            db.migrate(Path::new(migrations_path), false).await.unwrap();
 95            db.initialize_notification_kinds().await.unwrap();
 96            db
 97        });
 98
 99        db.runtime = Some(runtime);
100
101        Self {
102            db: Some(Arc::new(db)),
103            connection: None,
104        }
105    }
106
107    pub fn db(&self) -> &Arc<Database> {
108        self.db.as_ref().unwrap()
109    }
110}
111
/// Expands to two `#[gpui::test]` functions that run the async test function `$test_fn`
/// against each database backend:
///
/// - `$pg_fn` runs it on a fresh Postgres `TestDb` and is only compiled on macOS;
/// - `$sqlite_fn` runs it on an in-memory SQLite `TestDb`.
///
/// `$test_fn` receives the `&Arc<Database>` returned by `TestDb::db`.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_fn:ident, $pg_fn:ident, $sqlite_fn:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $pg_fn(cx: &mut gpui::TestAppContext) {
            // The harness must stay alive until the test finishes: dropping it tears the
            // database down.
            let harness = $crate::db::TestDb::postgres(cx.executor().clone());
            $test_fn(harness.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_fn(cx: &mut gpui::TestAppContext) {
            let harness = $crate::db::TestDb::sqlite(cx.executor().clone());
            $test_fn(harness.db()).await;
        }
    };
}
129
130impl Drop for TestDb {
131    fn drop(&mut self) {
132        let db = self.db.take().unwrap();
133        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
134            db.runtime.as_ref().unwrap().block_on(async {
135                use util::ResultExt;
136                let query = "
137                        SELECT pg_terminate_backend(pg_stat_activity.pid)
138                        FROM pg_stat_activity
139                        WHERE
140                            pg_stat_activity.datname = current_database() AND
141                            pid <> pg_backend_pid();
142                    ";
143                db.pool
144                    .execute(sea_orm::Statement::from_string(
145                        db.pool.get_database_backend(),
146                        query,
147                    ))
148                    .await
149                    .log_err();
150                sqlx::Postgres::drop_database(db.options.get_url())
151                    .await
152                    .log_err();
153            })
154        }
155    }
156}
157
158fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
159    channels
160        .iter()
161        .map(|(id, parent_path, name)| Channel {
162            id: *id,
163            name: name.to_string(),
164            visibility: ChannelVisibility::Members,
165            parent_path: parent_path.to_vec(),
166        })
167        .collect()
168}
169
170static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
171
172async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
173    db.create_user(
174        email,
175        false,
176        NewUserParams {
177            github_login: email[0..email.find('@').unwrap()].to_string(),
178            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
179        },
180    )
181    .await
182    .unwrap()
183    .user_id
184}
185
186static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
187fn new_test_connection(server: ServerId) -> ConnectionId {
188    ConnectionId {
189        id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
190        owner_id: server.0 as u32,
191    }
192}