tests.rs

  1mod buffer_tests;
  2mod channel_tests;
  3mod db_tests;
  4mod extension_tests;
  5mod migrations;
  6
  7use std::sync::Arc;
  8use std::sync::atomic::{AtomicI32, Ordering::SeqCst};
  9use std::time::Duration;
 10
 11use gpui::BackgroundExecutor;
 12use parking_lot::Mutex;
 13use rand::prelude::*;
 14use sea_orm::ConnectionTrait;
 15use sqlx::migrate::MigrateDatabase;
 16
 17use self::migrations::run_database_migrations;
 18
 19use super::*;
 20
 21pub struct TestDb {
 22    pub db: Option<Arc<Database>>,
 23    pub connection: Option<sqlx::AnyConnection>,
 24}
 25
 26impl TestDb {
 27    pub fn sqlite(executor: BackgroundExecutor) -> Self {
 28        let url = "sqlite::memory:";
 29        let runtime = tokio::runtime::Builder::new_current_thread()
 30            .enable_io()
 31            .enable_time()
 32            .build()
 33            .unwrap();
 34
 35        let mut db = runtime.block_on(async {
 36            let mut options = ConnectOptions::new(url);
 37            options.max_connections(5);
 38            let mut db = Database::new(options).await.unwrap();
 39            let sql = include_str!(concat!(
 40                env!("CARGO_MANIFEST_DIR"),
 41                "/migrations.sqlite/20221109000000_test_schema.sql"
 42            ));
 43            db.pool
 44                .execute(sea_orm::Statement::from_string(
 45                    db.pool.get_database_backend(),
 46                    sql,
 47                ))
 48                .await
 49                .unwrap();
 50            db.initialize_notification_kinds().await.unwrap();
 51            db
 52        });
 53
 54        db.test_options = Some(DatabaseTestOptions {
 55            executor,
 56            runtime,
 57            query_failure_probability: parking_lot::Mutex::new(0.0),
 58        });
 59
 60        Self {
 61            db: Some(Arc::new(db)),
 62            connection: None,
 63        }
 64    }
 65
 66    pub fn postgres(executor: BackgroundExecutor) -> Self {
 67        static LOCK: Mutex<()> = Mutex::new(());
 68
 69        let _guard = LOCK.lock();
 70        let mut rng = StdRng::from_os_rng();
 71        let url = format!(
 72            "postgres://postgres@localhost/zed-test-{}",
 73            rng.random::<u128>()
 74        );
 75        let runtime = tokio::runtime::Builder::new_current_thread()
 76            .enable_io()
 77            .enable_time()
 78            .build()
 79            .unwrap();
 80
 81        let mut db = runtime.block_on(async {
 82            sqlx::Postgres::create_database(&url)
 83                .await
 84                .expect("failed to create test db");
 85            let mut options = ConnectOptions::new(url);
 86            options
 87                .max_connections(5)
 88                .idle_timeout(Duration::from_secs(0));
 89            let mut db = Database::new(options).await.unwrap();
 90            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 91            run_database_migrations(db.options(), migrations_path)
 92                .await
 93                .unwrap();
 94            db.initialize_notification_kinds().await.unwrap();
 95            db
 96        });
 97
 98        db.test_options = Some(DatabaseTestOptions {
 99            executor,
100            runtime,
101            query_failure_probability: parking_lot::Mutex::new(0.0),
102        });
103
104        Self {
105            db: Some(Arc::new(db)),
106            connection: None,
107        }
108    }
109
110    pub fn db(&self) -> &Arc<Database> {
111        self.db.as_ref().unwrap()
112    }
113
114    pub fn set_query_failure_probability(&self, probability: f64) {
115        let database = self.db.as_ref().unwrap();
116        let test_options = database.test_options.as_ref().unwrap();
117        *test_options.query_failure_probability.lock() = probability;
118    }
119}
120
/// Defines two `#[gpui::test]` tests that run the same async test body against
/// both database backends.
///
/// The first name is the async test function. It is called with the
/// `&Arc<Database>` from `TestDb::db` and awaited. The second name is the
/// generated Postgres test, which is compiled only on macOS. The third name is
/// the generated SQLite test.
#[macro_export]
macro_rules! test_both_dbs {
    ($run:ident, $pg_name:ident, $sqlite_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $pg_name(cx: &mut gpui::TestAppContext) {
            // Kept alive for the whole test; its `Drop` cleans up the database.
            let harness = $crate::db::TestDb::postgres(cx.executor().clone());
            $run(harness.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_name(cx: &mut gpui::TestAppContext) {
            let harness = $crate::db::TestDb::sqlite(cx.executor().clone());
            $run(harness.db()).await;
        }
    };
}
138
139impl Drop for TestDb {
140    fn drop(&mut self) {
141        let db = self.db.take().unwrap();
142        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
143            db.test_options.as_ref().unwrap().runtime.block_on(async {
144                use util::ResultExt;
145                let query = "
146                        SELECT pg_terminate_backend(pg_stat_activity.pid)
147                        FROM pg_stat_activity
148                        WHERE
149                            pg_stat_activity.datname = current_database() AND
150                            pid <> pg_backend_pid();
151                    ";
152                db.pool
153                    .execute(sea_orm::Statement::from_string(
154                        db.pool.get_database_backend(),
155                        query,
156                    ))
157                    .await
158                    .log_err();
159                sqlx::Postgres::drop_database(db.options.get_url())
160                    .await
161                    .log_err();
162            })
163        }
164    }
165}
166
167#[track_caller]
168fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
169    let expected_channels = expected.into_iter().collect::<HashSet<_>>();
170    let actual_channels = actual.into_iter().collect::<HashSet<_>>();
171    pretty_assertions::assert_eq!(expected_channels, actual_channels);
172}
173
174fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
175    use std::collections::HashMap;
176
177    let mut result = Vec::new();
178    let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
179
180    for (id, parent_path, name) in channels {
181        let parent_key = parent_path.to_vec();
182        let order = if parent_key.is_empty() {
183            1
184        } else {
185            *order_by_parent
186                .entry(parent_key.clone())
187                .and_modify(|e| *e += 1)
188                .or_insert(1)
189        };
190
191        result.push(Channel {
192            id: *id,
193            name: (*name).to_owned(),
194            visibility: ChannelVisibility::Members,
195            parent_path: parent_key,
196            channel_order: order,
197        });
198    }
199
200    result
201}
202
203static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
204
205async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
206    db.create_user(
207        email,
208        None,
209        false,
210        NewUserParams {
211            github_login: email[0..email.find('@').unwrap()].to_string(),
212            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
213        },
214    )
215    .await
216    .unwrap()
217    .user_id
218}