tests.rs

  1mod buffer_tests;
  2mod channel_tests;
  3mod contributor_tests;
  4mod db_tests;
  5mod extension_tests;
  6
  7use crate::migrations::run_database_migrations;
  8
  9use super::*;
 10use gpui::BackgroundExecutor;
 11use parking_lot::Mutex;
 12use rand::prelude::*;
 13use sea_orm::ConnectionTrait;
 14use sqlx::migrate::MigrateDatabase;
 15use std::{
 16    sync::{
 17        Arc,
 18        atomic::{AtomicI32, Ordering::SeqCst},
 19    },
 20    time::Duration,
 21};
 22
 23pub struct TestDb {
 24    pub db: Option<Arc<Database>>,
 25    pub connection: Option<sqlx::AnyConnection>,
 26}
 27
 28impl TestDb {
 29    pub fn sqlite(executor: BackgroundExecutor) -> Self {
 30        let url = "sqlite::memory:";
 31        let runtime = tokio::runtime::Builder::new_current_thread()
 32            .enable_io()
 33            .enable_time()
 34            .build()
 35            .unwrap();
 36
 37        let mut db = runtime.block_on(async {
 38            let mut options = ConnectOptions::new(url);
 39            options.max_connections(5);
 40            let mut db = Database::new(options).await.unwrap();
 41            let sql = include_str!(concat!(
 42                env!("CARGO_MANIFEST_DIR"),
 43                "/migrations.sqlite/20221109000000_test_schema.sql"
 44            ));
 45            db.pool
 46                .execute(sea_orm::Statement::from_string(
 47                    db.pool.get_database_backend(),
 48                    sql,
 49                ))
 50                .await
 51                .unwrap();
 52            db.initialize_notification_kinds().await.unwrap();
 53            db
 54        });
 55
 56        db.test_options = Some(DatabaseTestOptions {
 57            executor,
 58            runtime,
 59            query_failure_probability: parking_lot::Mutex::new(0.0),
 60        });
 61
 62        Self {
 63            db: Some(Arc::new(db)),
 64            connection: None,
 65        }
 66    }
 67
 68    pub fn postgres(executor: BackgroundExecutor) -> Self {
 69        static LOCK: Mutex<()> = Mutex::new(());
 70
 71        let _guard = LOCK.lock();
 72        let mut rng = StdRng::from_os_rng();
 73        let url = format!(
 74            "postgres://postgres@localhost/zed-test-{}",
 75            rng.random::<u128>()
 76        );
 77        let runtime = tokio::runtime::Builder::new_current_thread()
 78            .enable_io()
 79            .enable_time()
 80            .build()
 81            .unwrap();
 82
 83        let mut db = runtime.block_on(async {
 84            sqlx::Postgres::create_database(&url)
 85                .await
 86                .expect("failed to create test db");
 87            let mut options = ConnectOptions::new(url);
 88            options
 89                .max_connections(5)
 90                .idle_timeout(Duration::from_secs(0));
 91            let mut db = Database::new(options).await.unwrap();
 92            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 93            run_database_migrations(db.options(), migrations_path)
 94                .await
 95                .unwrap();
 96            db.initialize_notification_kinds().await.unwrap();
 97            db
 98        });
 99
100        db.test_options = Some(DatabaseTestOptions {
101            executor,
102            runtime,
103            query_failure_probability: parking_lot::Mutex::new(0.0),
104        });
105
106        Self {
107            db: Some(Arc::new(db)),
108            connection: None,
109        }
110    }
111
112    pub fn db(&self) -> &Arc<Database> {
113        self.db.as_ref().unwrap()
114    }
115
116    pub fn set_query_failure_probability(&self, probability: f64) {
117        let database = self.db.as_ref().unwrap();
118        let test_options = database.test_options.as_ref().unwrap();
119        *test_options.query_failure_probability.lock() = probability;
120    }
121}
122
/// Generates two `#[gpui::test]` functions that run the async test function given as the
/// first argument against a freshly created database.
///
/// - The third identifier names the SQLite variant (built with `TestDb::sqlite`).
/// - The second identifier names the PostgreSQL variant (built with `TestDb::postgres`). It is
///   compiled only when `target_os = "macos"`.
///
/// The test function is called with the `&Arc<Database>` returned by `TestDb::db` and awaited
/// while the `TestDb` is still alive.
#[macro_export]
macro_rules! test_both_dbs {
    ($body_fn:ident, $pg_fn:ident, $lite_fn:ident) => {
        #[gpui::test]
        async fn $lite_fn(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let test_db = $crate::db::TestDb::sqlite(executor);
            $body_fn(test_db.db()).await;
        }

        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $pg_fn(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let test_db = $crate::db::TestDb::postgres(executor);
            $body_fn(test_db.db()).await;
        }
    };
}
140
141impl Drop for TestDb {
142    fn drop(&mut self) {
143        let db = self.db.take().unwrap();
144        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
145            db.test_options.as_ref().unwrap().runtime.block_on(async {
146                use util::ResultExt;
147                let query = "
148                        SELECT pg_terminate_backend(pg_stat_activity.pid)
149                        FROM pg_stat_activity
150                        WHERE
151                            pg_stat_activity.datname = current_database() AND
152                            pid <> pg_backend_pid();
153                    ";
154                db.pool
155                    .execute(sea_orm::Statement::from_string(
156                        db.pool.get_database_backend(),
157                        query,
158                    ))
159                    .await
160                    .log_err();
161                sqlx::Postgres::drop_database(db.options.get_url())
162                    .await
163                    .log_err();
164            })
165        }
166    }
167}
168
169#[track_caller]
170fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
171    let expected_channels = expected.into_iter().collect::<HashSet<_>>();
172    let actual_channels = actual.into_iter().collect::<HashSet<_>>();
173    pretty_assertions::assert_eq!(expected_channels, actual_channels);
174}
175
176fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
177    use std::collections::HashMap;
178
179    let mut result = Vec::new();
180    let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
181
182    for (id, parent_path, name) in channels {
183        let parent_key = parent_path.to_vec();
184        let order = if parent_key.is_empty() {
185            1
186        } else {
187            *order_by_parent
188                .entry(parent_key.clone())
189                .and_modify(|e| *e += 1)
190                .or_insert(1)
191        };
192
193        result.push(Channel {
194            id: *id,
195            name: (*name).to_owned(),
196            visibility: ChannelVisibility::Members,
197            parent_path: parent_key,
198            channel_order: order,
199        });
200    }
201
202    result
203}
204
205static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
206
207async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
208    db.create_user(
209        email,
210        None,
211        false,
212        NewUserParams {
213            github_login: email[0..email.find('@').unwrap()].to_string(),
214            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
215        },
216    )
217    .await
218    .unwrap()
219    .user_id
220}