tests.rs

  1mod buffer_tests;
  2mod channel_tests;
  3mod contributor_tests;
  4mod db_tests;
  5// we only run postgres tests on macos right now
  6#[cfg(target_os = "macos")]
  7mod embedding_tests;
  8mod extension_tests;
  9mod user_tests;
 10
 11use crate::migrations::run_database_migrations;
 12
 13use super::*;
 14use gpui::BackgroundExecutor;
 15use parking_lot::Mutex;
 16use rand::prelude::*;
 17use sea_orm::ConnectionTrait;
 18use sqlx::migrate::MigrateDatabase;
 19use std::{
 20    sync::{
 21        Arc,
 22        atomic::{AtomicI32, Ordering::SeqCst},
 23    },
 24    time::Duration,
 25};
 26
 27pub struct TestDb {
 28    pub db: Option<Arc<Database>>,
 29    pub connection: Option<sqlx::AnyConnection>,
 30}
 31
 32impl TestDb {
 33    pub fn sqlite(executor: BackgroundExecutor) -> Self {
 34        let url = "sqlite::memory:";
 35        let runtime = tokio::runtime::Builder::new_current_thread()
 36            .enable_io()
 37            .enable_time()
 38            .build()
 39            .unwrap();
 40
 41        let mut db = runtime.block_on(async {
 42            let mut options = ConnectOptions::new(url);
 43            options.max_connections(5);
 44            let mut db = Database::new(options).await.unwrap();
 45            let sql = include_str!(concat!(
 46                env!("CARGO_MANIFEST_DIR"),
 47                "/migrations.sqlite/20221109000000_test_schema.sql"
 48            ));
 49            db.pool
 50                .execute(sea_orm::Statement::from_string(
 51                    db.pool.get_database_backend(),
 52                    sql,
 53                ))
 54                .await
 55                .unwrap();
 56            db.initialize_notification_kinds().await.unwrap();
 57            db
 58        });
 59
 60        db.test_options = Some(DatabaseTestOptions {
 61            executor,
 62            runtime,
 63            query_failure_probability: parking_lot::Mutex::new(0.0),
 64        });
 65
 66        Self {
 67            db: Some(Arc::new(db)),
 68            connection: None,
 69        }
 70    }
 71
 72    pub fn postgres(executor: BackgroundExecutor) -> Self {
 73        static LOCK: Mutex<()> = Mutex::new(());
 74
 75        let _guard = LOCK.lock();
 76        let mut rng = StdRng::from_os_rng();
 77        let url = format!(
 78            "postgres://postgres@localhost/zed-test-{}",
 79            rng.random::<u128>()
 80        );
 81        let runtime = tokio::runtime::Builder::new_current_thread()
 82            .enable_io()
 83            .enable_time()
 84            .build()
 85            .unwrap();
 86
 87        let mut db = runtime.block_on(async {
 88            sqlx::Postgres::create_database(&url)
 89                .await
 90                .expect("failed to create test db");
 91            let mut options = ConnectOptions::new(url);
 92            options
 93                .max_connections(5)
 94                .idle_timeout(Duration::from_secs(0));
 95            let mut db = Database::new(options).await.unwrap();
 96            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 97            run_database_migrations(db.options(), migrations_path)
 98                .await
 99                .unwrap();
100            db.initialize_notification_kinds().await.unwrap();
101            db
102        });
103
104        db.test_options = Some(DatabaseTestOptions {
105            executor,
106            runtime,
107            query_failure_probability: parking_lot::Mutex::new(0.0),
108        });
109
110        Self {
111            db: Some(Arc::new(db)),
112            connection: None,
113        }
114    }
115
116    pub fn db(&self) -> &Arc<Database> {
117        self.db.as_ref().unwrap()
118    }
119
120    pub fn set_query_failure_probability(&self, probability: f64) {
121        let database = self.db.as_ref().unwrap();
122        let test_options = database.test_options.as_ref().unwrap();
123        *test_options.query_failure_probability.lock() = probability;
124    }
125}
126
/// Expands to two `#[gpui::test]` tests that each call the async function
/// `$test_name` with the `&Arc<Database>` of a fresh `TestDb`:
///
/// * `$postgres_test_name` uses [`TestDb::postgres`] and is compiled only on macOS;
/// * `$sqlite_test_name` uses [`TestDb::sqlite`].
///
/// The `TestDb` lives until the end of the generated test, after `$test_name` completes.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let postgres_db = $crate::db::TestDb::postgres(executor);
            $test_name(postgres_db.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let sqlite_db = $crate::db::TestDb::sqlite(executor);
            $test_name(sqlite_db.db()).await;
        }
    };
}
144
145impl Drop for TestDb {
146    fn drop(&mut self) {
147        let db = self.db.take().unwrap();
148        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
149            db.test_options.as_ref().unwrap().runtime.block_on(async {
150                use util::ResultExt;
151                let query = "
152                        SELECT pg_terminate_backend(pg_stat_activity.pid)
153                        FROM pg_stat_activity
154                        WHERE
155                            pg_stat_activity.datname = current_database() AND
156                            pid <> pg_backend_pid();
157                    ";
158                db.pool
159                    .execute(sea_orm::Statement::from_string(
160                        db.pool.get_database_backend(),
161                        query,
162                    ))
163                    .await
164                    .log_err();
165                sqlx::Postgres::drop_database(db.options.get_url())
166                    .await
167                    .log_err();
168            })
169        }
170    }
171}
172
173#[track_caller]
174fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
175    let expected_channels = expected.into_iter().collect::<HashSet<_>>();
176    let actual_channels = actual.into_iter().collect::<HashSet<_>>();
177    pretty_assertions::assert_eq!(expected_channels, actual_channels);
178}
179
180fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
181    use std::collections::HashMap;
182
183    let mut result = Vec::new();
184    let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
185
186    for (id, parent_path, name) in channels {
187        let parent_key = parent_path.to_vec();
188        let order = if parent_key.is_empty() {
189            1
190        } else {
191            *order_by_parent
192                .entry(parent_key.clone())
193                .and_modify(|e| *e += 1)
194                .or_insert(1)
195        };
196
197        result.push(Channel {
198            id: *id,
199            name: name.to_string(),
200            visibility: ChannelVisibility::Members,
201            parent_path: parent_key,
202            channel_order: order,
203        });
204    }
205
206    result
207}
208
209static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
210
211async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
212    db.create_user(
213        email,
214        None,
215        false,
216        NewUserParams {
217            github_login: email[0..email.find('@').unwrap()].to_string(),
218            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
219        },
220    )
221    .await
222    .unwrap()
223    .user_id
224}