tests.rs

  1mod buffer_tests;
  2mod channel_tests;
  3mod contributor_tests;
  4mod db_tests;
  5// we only run postgres tests on macos right now
  6#[cfg(target_os = "macos")]
  7mod embedding_tests;
  8mod extension_tests;
  9
 10use crate::migrations::run_database_migrations;
 11
 12use super::*;
 13use gpui::BackgroundExecutor;
 14use parking_lot::Mutex;
 15use rand::prelude::*;
 16use sea_orm::ConnectionTrait;
 17use sqlx::migrate::MigrateDatabase;
 18use std::{
 19    sync::{
 20        Arc,
 21        atomic::{AtomicI32, Ordering::SeqCst},
 22    },
 23    time::Duration,
 24};
 25
 26pub struct TestDb {
 27    pub db: Option<Arc<Database>>,
 28    pub connection: Option<sqlx::AnyConnection>,
 29}
 30
 31impl TestDb {
 32    pub fn sqlite(executor: BackgroundExecutor) -> Self {
 33        let url = "sqlite::memory:";
 34        let runtime = tokio::runtime::Builder::new_current_thread()
 35            .enable_io()
 36            .enable_time()
 37            .build()
 38            .unwrap();
 39
 40        let mut db = runtime.block_on(async {
 41            let mut options = ConnectOptions::new(url);
 42            options.max_connections(5);
 43            let mut db = Database::new(options).await.unwrap();
 44            let sql = include_str!(concat!(
 45                env!("CARGO_MANIFEST_DIR"),
 46                "/migrations.sqlite/20221109000000_test_schema.sql"
 47            ));
 48            db.pool
 49                .execute(sea_orm::Statement::from_string(
 50                    db.pool.get_database_backend(),
 51                    sql,
 52                ))
 53                .await
 54                .unwrap();
 55            db.initialize_notification_kinds().await.unwrap();
 56            db
 57        });
 58
 59        db.test_options = Some(DatabaseTestOptions {
 60            executor,
 61            runtime,
 62            query_failure_probability: parking_lot::Mutex::new(0.0),
 63        });
 64
 65        Self {
 66            db: Some(Arc::new(db)),
 67            connection: None,
 68        }
 69    }
 70
 71    pub fn postgres(executor: BackgroundExecutor) -> Self {
 72        static LOCK: Mutex<()> = Mutex::new(());
 73
 74        let _guard = LOCK.lock();
 75        let mut rng = StdRng::from_os_rng();
 76        let url = format!(
 77            "postgres://postgres@localhost/zed-test-{}",
 78            rng.random::<u128>()
 79        );
 80        let runtime = tokio::runtime::Builder::new_current_thread()
 81            .enable_io()
 82            .enable_time()
 83            .build()
 84            .unwrap();
 85
 86        let mut db = runtime.block_on(async {
 87            sqlx::Postgres::create_database(&url)
 88                .await
 89                .expect("failed to create test db");
 90            let mut options = ConnectOptions::new(url);
 91            options
 92                .max_connections(5)
 93                .idle_timeout(Duration::from_secs(0));
 94            let mut db = Database::new(options).await.unwrap();
 95            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 96            run_database_migrations(db.options(), migrations_path)
 97                .await
 98                .unwrap();
 99            db.initialize_notification_kinds().await.unwrap();
100            db
101        });
102
103        db.test_options = Some(DatabaseTestOptions {
104            executor,
105            runtime,
106            query_failure_probability: parking_lot::Mutex::new(0.0),
107        });
108
109        Self {
110            db: Some(Arc::new(db)),
111            connection: None,
112        }
113    }
114
115    pub fn db(&self) -> &Arc<Database> {
116        self.db.as_ref().unwrap()
117    }
118
119    pub fn set_query_failure_probability(&self, probability: f64) {
120        let database = self.db.as_ref().unwrap();
121        let test_options = database.test_options.as_ref().unwrap();
122        *test_options.query_failure_probability.lock() = probability;
123    }
124}
125
/// Generates two `#[gpui::test]` functions that each run the async function
/// `$test_name` against a fresh [`TestDb`](crate::db::TestDb).
///
/// * `$sqlite_test_name` uses `TestDb::sqlite`.
/// * `$postgres_test_name` uses `TestDb::postgres` and is only compiled when
///   `target_os = "macos"`.
///
/// `$test_name` is called with the result of `TestDb::db()` and awaited.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            // Bound to a local so the `TestDb` stays alive (and is only
            // dropped) after the awaited test body has finished.
            let sqlite_db = $crate::db::TestDb::sqlite(cx.executor().clone());
            $test_name(sqlite_db.db()).await;
        }

        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            // Same lifetime consideration as the SQLite variant above.
            let postgres_db = $crate::db::TestDb::postgres(cx.executor().clone());
            $test_name(postgres_db.db()).await;
        }
    };
}
143
144impl Drop for TestDb {
145    fn drop(&mut self) {
146        let db = self.db.take().unwrap();
147        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
148            db.test_options.as_ref().unwrap().runtime.block_on(async {
149                use util::ResultExt;
150                let query = "
151                        SELECT pg_terminate_backend(pg_stat_activity.pid)
152                        FROM pg_stat_activity
153                        WHERE
154                            pg_stat_activity.datname = current_database() AND
155                            pid <> pg_backend_pid();
156                    ";
157                db.pool
158                    .execute(sea_orm::Statement::from_string(
159                        db.pool.get_database_backend(),
160                        query,
161                    ))
162                    .await
163                    .log_err();
164                sqlx::Postgres::drop_database(db.options.get_url())
165                    .await
166                    .log_err();
167            })
168        }
169    }
170}
171
172#[track_caller]
173fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
174    let expected_channels = expected.into_iter().collect::<HashSet<_>>();
175    let actual_channels = actual.into_iter().collect::<HashSet<_>>();
176    pretty_assertions::assert_eq!(expected_channels, actual_channels);
177}
178
179fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
180    use std::collections::HashMap;
181
182    let mut result = Vec::new();
183    let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
184
185    for (id, parent_path, name) in channels {
186        let parent_key = parent_path.to_vec();
187        let order = if parent_key.is_empty() {
188            1
189        } else {
190            *order_by_parent
191                .entry(parent_key.clone())
192                .and_modify(|e| *e += 1)
193                .or_insert(1)
194        };
195
196        result.push(Channel {
197            id: *id,
198            name: (*name).to_owned(),
199            visibility: ChannelVisibility::Members,
200            parent_path: parent_key,
201            channel_order: order,
202        });
203    }
204
205    result
206}
207
208static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
209
210async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
211    db.create_user(
212        email,
213        None,
214        false,
215        NewUserParams {
216            github_login: email[0..email.find('@').unwrap()].to_string(),
217            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
218        },
219    )
220    .await
221    .unwrap()
222    .user_id
223}