tests.rs

  1mod buffer_tests;
  2mod channel_tests;
  3mod contributor_tests;
  4mod db_tests;
  5mod extension_tests;
  6mod migrations;
  7
  8use std::sync::Arc;
  9use std::sync::atomic::{AtomicI32, Ordering::SeqCst};
 10use std::time::Duration;
 11
 12use gpui::BackgroundExecutor;
 13use parking_lot::Mutex;
 14use rand::prelude::*;
 15use sea_orm::ConnectionTrait;
 16use sqlx::migrate::MigrateDatabase;
 17
 18use self::migrations::run_database_migrations;
 19
 20use super::*;
 21
 22pub struct TestDb {
 23    pub db: Option<Arc<Database>>,
 24    pub connection: Option<sqlx::AnyConnection>,
 25}
 26
 27impl TestDb {
 28    pub fn sqlite(executor: BackgroundExecutor) -> Self {
 29        let url = "sqlite::memory:";
 30        let runtime = tokio::runtime::Builder::new_current_thread()
 31            .enable_io()
 32            .enable_time()
 33            .build()
 34            .unwrap();
 35
 36        let mut db = runtime.block_on(async {
 37            let mut options = ConnectOptions::new(url);
 38            options.max_connections(5);
 39            let mut db = Database::new(options).await.unwrap();
 40            let sql = include_str!(concat!(
 41                env!("CARGO_MANIFEST_DIR"),
 42                "/migrations.sqlite/20221109000000_test_schema.sql"
 43            ));
 44            db.pool
 45                .execute(sea_orm::Statement::from_string(
 46                    db.pool.get_database_backend(),
 47                    sql,
 48                ))
 49                .await
 50                .unwrap();
 51            db.initialize_notification_kinds().await.unwrap();
 52            db
 53        });
 54
 55        db.test_options = Some(DatabaseTestOptions {
 56            executor,
 57            runtime,
 58            query_failure_probability: parking_lot::Mutex::new(0.0),
 59        });
 60
 61        Self {
 62            db: Some(Arc::new(db)),
 63            connection: None,
 64        }
 65    }
 66
 67    pub fn postgres(executor: BackgroundExecutor) -> Self {
 68        static LOCK: Mutex<()> = Mutex::new(());
 69
 70        let _guard = LOCK.lock();
 71        let mut rng = StdRng::from_os_rng();
 72        let url = format!(
 73            "postgres://postgres@localhost/zed-test-{}",
 74            rng.random::<u128>()
 75        );
 76        let runtime = tokio::runtime::Builder::new_current_thread()
 77            .enable_io()
 78            .enable_time()
 79            .build()
 80            .unwrap();
 81
 82        let mut db = runtime.block_on(async {
 83            sqlx::Postgres::create_database(&url)
 84                .await
 85                .expect("failed to create test db");
 86            let mut options = ConnectOptions::new(url);
 87            options
 88                .max_connections(5)
 89                .idle_timeout(Duration::from_secs(0));
 90            let mut db = Database::new(options).await.unwrap();
 91            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 92            run_database_migrations(db.options(), migrations_path)
 93                .await
 94                .unwrap();
 95            db.initialize_notification_kinds().await.unwrap();
 96            db
 97        });
 98
 99        db.test_options = Some(DatabaseTestOptions {
100            executor,
101            runtime,
102            query_failure_probability: parking_lot::Mutex::new(0.0),
103        });
104
105        Self {
106            db: Some(Arc::new(db)),
107            connection: None,
108        }
109    }
110
111    pub fn db(&self) -> &Arc<Database> {
112        self.db.as_ref().unwrap()
113    }
114
115    pub fn set_query_failure_probability(&self, probability: f64) {
116        let database = self.db.as_ref().unwrap();
117        let test_options = database.test_options.as_ref().unwrap();
118        *test_options.query_failure_probability.lock() = probability;
119    }
120}
121
/// Generates `#[gpui::test]` wrappers that run one async test function against both backends.
///
/// Takes three identifiers, in order: the shared test function, the name of
/// the generated Postgres test, and the name of the generated SQLite test.
/// Each wrapper builds a `TestDb`, passes `test_db.db()` to the shared
/// function, and awaits the result. The Postgres wrapper is only compiled
/// when `target_os = "macos"`.
#[macro_export]
macro_rules! test_both_dbs {
    ($shared_test:ident, $postgres_fn:ident, $sqlite_fn:ident) => {
        #[gpui::test]
        async fn $sqlite_fn(cx: &mut gpui::TestAppContext) {
            let harness = $crate::db::TestDb::sqlite(cx.executor().clone());
            $shared_test(harness.db()).await;
        }

        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_fn(cx: &mut gpui::TestAppContext) {
            let harness = $crate::db::TestDb::postgres(cx.executor().clone());
            $shared_test(harness.db()).await;
        }
    };
}
139
140impl Drop for TestDb {
141    fn drop(&mut self) {
142        let db = self.db.take().unwrap();
143        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
144            db.test_options.as_ref().unwrap().runtime.block_on(async {
145                use util::ResultExt;
146                let query = "
147                        SELECT pg_terminate_backend(pg_stat_activity.pid)
148                        FROM pg_stat_activity
149                        WHERE
150                            pg_stat_activity.datname = current_database() AND
151                            pid <> pg_backend_pid();
152                    ";
153                db.pool
154                    .execute(sea_orm::Statement::from_string(
155                        db.pool.get_database_backend(),
156                        query,
157                    ))
158                    .await
159                    .log_err();
160                sqlx::Postgres::drop_database(db.options.get_url())
161                    .await
162                    .log_err();
163            })
164        }
165    }
166}
167
168#[track_caller]
169fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
170    let expected_channels = expected.into_iter().collect::<HashSet<_>>();
171    let actual_channels = actual.into_iter().collect::<HashSet<_>>();
172    pretty_assertions::assert_eq!(expected_channels, actual_channels);
173}
174
175fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
176    use std::collections::HashMap;
177
178    let mut result = Vec::new();
179    let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
180
181    for (id, parent_path, name) in channels {
182        let parent_key = parent_path.to_vec();
183        let order = if parent_key.is_empty() {
184            1
185        } else {
186            *order_by_parent
187                .entry(parent_key.clone())
188                .and_modify(|e| *e += 1)
189                .or_insert(1)
190        };
191
192        result.push(Channel {
193            id: *id,
194            name: (*name).to_owned(),
195            visibility: ChannelVisibility::Members,
196            parent_path: parent_key,
197            channel_order: order,
198        });
199    }
200
201    result
202}
203
204static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
205
206async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
207    db.create_user(
208        email,
209        None,
210        false,
211        NewUserParams {
212            github_login: email[0..email.find('@').unwrap()].to_string(),
213            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
214        },
215    )
216    .await
217    .unwrap()
218    .user_id
219}