tests.rs

  1mod billing_subscription_tests;
  2mod buffer_tests;
  3mod channel_tests;
  4mod contributor_tests;
  5mod db_tests;
  6// we only run postgres tests on macos right now
  7#[cfg(target_os = "macos")]
  8mod embedding_tests;
  9mod extension_tests;
 10mod feature_flag_tests;
 11mod message_tests;
 12mod processed_stripe_event_tests;
 13
 14use crate::migrations::run_database_migrations;
 15
 16use super::*;
 17use gpui::BackgroundExecutor;
 18use parking_lot::Mutex;
 19use sea_orm::ConnectionTrait;
 20use sqlx::migrate::MigrateDatabase;
 21use std::sync::{
 22    atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
 23    Arc,
 24};
 25
 26pub struct TestDb {
 27    pub db: Option<Arc<Database>>,
 28    pub connection: Option<sqlx::AnyConnection>,
 29}
 30
 31impl TestDb {
 32    pub fn sqlite(background: BackgroundExecutor) -> Self {
 33        let url = "sqlite::memory:";
 34        let runtime = tokio::runtime::Builder::new_current_thread()
 35            .enable_io()
 36            .enable_time()
 37            .build()
 38            .unwrap();
 39
 40        let mut db = runtime.block_on(async {
 41            let mut options = ConnectOptions::new(url);
 42            options.max_connections(5);
 43            let mut db = Database::new(options, Executor::Deterministic(background))
 44                .await
 45                .unwrap();
 46            let sql = include_str!(concat!(
 47                env!("CARGO_MANIFEST_DIR"),
 48                "/migrations.sqlite/20221109000000_test_schema.sql"
 49            ));
 50            db.pool
 51                .execute(sea_orm::Statement::from_string(
 52                    db.pool.get_database_backend(),
 53                    sql,
 54                ))
 55                .await
 56                .unwrap();
 57            db.initialize_notification_kinds().await.unwrap();
 58            db
 59        });
 60
 61        db.runtime = Some(runtime);
 62
 63        Self {
 64            db: Some(Arc::new(db)),
 65            connection: None,
 66        }
 67    }
 68
 69    pub fn postgres(background: BackgroundExecutor) -> Self {
 70        static LOCK: Mutex<()> = Mutex::new(());
 71
 72        let _guard = LOCK.lock();
 73        let mut rng = StdRng::from_entropy();
 74        let url = format!(
 75            "postgres://postgres@localhost/zed-test-{}",
 76            rng.gen::<u128>()
 77        );
 78        let runtime = tokio::runtime::Builder::new_current_thread()
 79            .enable_io()
 80            .enable_time()
 81            .build()
 82            .unwrap();
 83
 84        let mut db = runtime.block_on(async {
 85            sqlx::Postgres::create_database(&url)
 86                .await
 87                .expect("failed to create test db");
 88            let mut options = ConnectOptions::new(url);
 89            options
 90                .max_connections(5)
 91                .idle_timeout(Duration::from_secs(0));
 92            let mut db = Database::new(options, Executor::Deterministic(background))
 93                .await
 94                .unwrap();
 95            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 96            run_database_migrations(db.options(), migrations_path)
 97                .await
 98                .unwrap();
 99            db.initialize_notification_kinds().await.unwrap();
100            db
101        });
102
103        db.runtime = Some(runtime);
104
105        Self {
106            db: Some(Arc::new(db)),
107            connection: None,
108        }
109    }
110
111    pub fn db(&self) -> &Arc<Database> {
112        self.db.as_ref().unwrap()
113    }
114}
115
/// Expands to two `#[gpui::test]` functions that run the same async test body
/// against both database backends.
///
/// Arguments, in order:
/// 1. the test body: anything callable with `&Arc<Database>` that returns a
///    future;
/// 2. the name of the generated Postgres test, which is compiled only on macOS;
/// 3. the name of the generated in-memory SQLite test.
#[macro_export]
macro_rules! test_both_dbs {
    ($body:ident, $pg_case:ident, $sqlite_case:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $pg_case(cx: &mut gpui::TestAppContext) {
            let fixture = $crate::db::TestDb::postgres(cx.executor().clone());
            $body(fixture.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_case(cx: &mut gpui::TestAppContext) {
            let fixture = $crate::db::TestDb::sqlite(cx.executor().clone());
            $body(fixture.db()).await;
        }
    };
}
133
134impl Drop for TestDb {
135    fn drop(&mut self) {
136        let db = self.db.take().unwrap();
137        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
138            db.runtime.as_ref().unwrap().block_on(async {
139                use util::ResultExt;
140                let query = "
141                        SELECT pg_terminate_backend(pg_stat_activity.pid)
142                        FROM pg_stat_activity
143                        WHERE
144                            pg_stat_activity.datname = current_database() AND
145                            pid <> pg_backend_pid();
146                    ";
147                db.pool
148                    .execute(sea_orm::Statement::from_string(
149                        db.pool.get_database_backend(),
150                        query,
151                    ))
152                    .await
153                    .log_err();
154                sqlx::Postgres::drop_database(db.options.get_url())
155                    .await
156                    .log_err();
157            })
158        }
159    }
160}
161
162fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
163    channels
164        .iter()
165        .map(|(id, parent_path, name)| Channel {
166            id: *id,
167            name: name.to_string(),
168            visibility: ChannelVisibility::Members,
169            parent_path: parent_path.to_vec(),
170        })
171        .collect()
172}
173
174static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
175
176async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
177    db.create_user(
178        email,
179        false,
180        NewUserParams {
181            github_login: email[0..email.find('@').unwrap()].to_string(),
182            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
183        },
184    )
185    .await
186    .unwrap()
187    .user_id
188}
189
190static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
191fn new_test_connection(server: ServerId) -> ConnectionId {
192    ConnectionId {
193        id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
194        owner_id: server.0 as u32,
195    }
196}