tests.rs

  1mod billing_subscription_tests;
  2mod buffer_tests;
  3mod channel_tests;
  4mod contributor_tests;
  5mod db_tests;
  6// we only run postgres tests on macos right now
  7#[cfg(target_os = "macos")]
  8mod embedding_tests;
  9mod extension_tests;
 10mod feature_flag_tests;
 11mod message_tests;
 12mod processed_stripe_event_tests;
 13mod user_tests;
 14
 15use crate::migrations::run_database_migrations;
 16
 17use super::*;
 18use gpui::BackgroundExecutor;
 19use parking_lot::Mutex;
 20use sea_orm::ConnectionTrait;
 21use sqlx::migrate::MigrateDatabase;
 22use std::sync::{
 23    Arc,
 24    atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
 25};
 26
 27pub struct TestDb {
 28    pub db: Option<Arc<Database>>,
 29    pub connection: Option<sqlx::AnyConnection>,
 30}
 31
 32impl TestDb {
 33    pub fn sqlite(executor: BackgroundExecutor) -> Self {
 34        let url = "sqlite::memory:";
 35        let runtime = tokio::runtime::Builder::new_current_thread()
 36            .enable_io()
 37            .enable_time()
 38            .build()
 39            .unwrap();
 40
 41        let mut db = runtime.block_on(async {
 42            let mut options = ConnectOptions::new(url);
 43            options.max_connections(5);
 44            let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
 45                .await
 46                .unwrap();
 47            let sql = include_str!(concat!(
 48                env!("CARGO_MANIFEST_DIR"),
 49                "/migrations.sqlite/20221109000000_test_schema.sql"
 50            ));
 51            db.pool
 52                .execute(sea_orm::Statement::from_string(
 53                    db.pool.get_database_backend(),
 54                    sql,
 55                ))
 56                .await
 57                .unwrap();
 58            db.initialize_notification_kinds().await.unwrap();
 59            db
 60        });
 61
 62        db.test_options = Some(DatabaseTestOptions {
 63            runtime,
 64            query_failure_probability: parking_lot::Mutex::new(0.0),
 65        });
 66
 67        Self {
 68            db: Some(Arc::new(db)),
 69            connection: None,
 70        }
 71    }
 72
 73    pub fn postgres(executor: BackgroundExecutor) -> Self {
 74        static LOCK: Mutex<()> = Mutex::new(());
 75
 76        let _guard = LOCK.lock();
 77        let mut rng = StdRng::from_entropy();
 78        let url = format!(
 79            "postgres://postgres@localhost/zed-test-{}",
 80            rng.r#gen::<u128>()
 81        );
 82        let runtime = tokio::runtime::Builder::new_current_thread()
 83            .enable_io()
 84            .enable_time()
 85            .build()
 86            .unwrap();
 87
 88        let mut db = runtime.block_on(async {
 89            sqlx::Postgres::create_database(&url)
 90                .await
 91                .expect("failed to create test db");
 92            let mut options = ConnectOptions::new(url);
 93            options
 94                .max_connections(5)
 95                .idle_timeout(Duration::from_secs(0));
 96            let mut db = Database::new(options, Executor::Deterministic(executor.clone()))
 97                .await
 98                .unwrap();
 99            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
100            run_database_migrations(db.options(), migrations_path)
101                .await
102                .unwrap();
103            db.initialize_notification_kinds().await.unwrap();
104            db
105        });
106
107        db.test_options = Some(DatabaseTestOptions {
108            runtime,
109            query_failure_probability: parking_lot::Mutex::new(0.0),
110        });
111
112        Self {
113            db: Some(Arc::new(db)),
114            connection: None,
115        }
116    }
117
118    pub fn db(&self) -> &Arc<Database> {
119        self.db.as_ref().unwrap()
120    }
121
122    pub fn set_query_failure_probability(&self, probability: f64) {
123        let database = self.db.as_ref().unwrap();
124        let test_options = database.test_options.as_ref().unwrap();
125        *test_options.query_failure_probability.lock() = probability;
126    }
127}
128
/// Generates two `#[gpui::test]` functions that run the async test function
/// `$test_name` (called with the `&Arc<Database>` returned by `TestDb::db`)
/// against each backend:
///
/// - `$sqlite_test_name` uses an in-memory SQLite database;
/// - `$postgres_test_name` uses a fresh Postgres database and is only
///   compiled on macOS.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let harness = $crate::db::TestDb::sqlite(executor);
            $test_name(harness.db()).await;
        }

        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let harness = $crate::db::TestDb::postgres(executor);
            $test_name(harness.db()).await;
        }
    };
}
146
147impl Drop for TestDb {
148    fn drop(&mut self) {
149        let db = self.db.take().unwrap();
150        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
151            db.test_options.as_ref().unwrap().runtime.block_on(async {
152                use util::ResultExt;
153                let query = "
154                        SELECT pg_terminate_backend(pg_stat_activity.pid)
155                        FROM pg_stat_activity
156                        WHERE
157                            pg_stat_activity.datname = current_database() AND
158                            pid <> pg_backend_pid();
159                    ";
160                db.pool
161                    .execute(sea_orm::Statement::from_string(
162                        db.pool.get_database_backend(),
163                        query,
164                    ))
165                    .await
166                    .log_err();
167                sqlx::Postgres::drop_database(db.options.get_url())
168                    .await
169                    .log_err();
170            })
171        }
172    }
173}
174
175fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
176    channels
177        .iter()
178        .map(|(id, parent_path, name)| Channel {
179            id: *id,
180            name: name.to_string(),
181            visibility: ChannelVisibility::Members,
182            parent_path: parent_path.to_vec(),
183        })
184        .collect()
185}
186
187static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
188
189async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
190    db.create_user(
191        email,
192        None,
193        false,
194        NewUserParams {
195            github_login: email[0..email.find('@').unwrap()].to_string(),
196            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
197        },
198    )
199    .await
200    .unwrap()
201    .user_id
202}
203
204static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
205fn new_test_connection(server: ServerId) -> ConnectionId {
206    ConnectionId {
207        id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
208        owner_id: server.0 as u32,
209    }
210}