tests.rs

  1mod buffer_tests;
  2mod channel_tests;
  3mod contributor_tests;
  4mod db_tests;
  5// we only run postgres tests on macos right now
  6#[cfg(target_os = "macos")]
  7mod embedding_tests;
  8mod extension_tests;
  9mod feature_flag_tests;
 10mod user_tests;
 11
 12use crate::migrations::run_database_migrations;
 13
 14use super::*;
 15use gpui::BackgroundExecutor;
 16use parking_lot::Mutex;
 17use rand::prelude::*;
 18use sea_orm::ConnectionTrait;
 19use sqlx::migrate::MigrateDatabase;
 20use std::{
 21    sync::{
 22        Arc,
 23        atomic::{AtomicI32, Ordering::SeqCst},
 24    },
 25    time::Duration,
 26};
 27
 28pub struct TestDb {
 29    pub db: Option<Arc<Database>>,
 30    pub connection: Option<sqlx::AnyConnection>,
 31}
 32
 33impl TestDb {
 34    pub fn sqlite(executor: BackgroundExecutor) -> Self {
 35        let url = "sqlite::memory:";
 36        let runtime = tokio::runtime::Builder::new_current_thread()
 37            .enable_io()
 38            .enable_time()
 39            .build()
 40            .unwrap();
 41
 42        let mut db = runtime.block_on(async {
 43            let mut options = ConnectOptions::new(url);
 44            options.max_connections(5);
 45            let mut db = Database::new(options).await.unwrap();
 46            let sql = include_str!(concat!(
 47                env!("CARGO_MANIFEST_DIR"),
 48                "/migrations.sqlite/20221109000000_test_schema.sql"
 49            ));
 50            db.pool
 51                .execute(sea_orm::Statement::from_string(
 52                    db.pool.get_database_backend(),
 53                    sql,
 54                ))
 55                .await
 56                .unwrap();
 57            db.initialize_notification_kinds().await.unwrap();
 58            db
 59        });
 60
 61        db.test_options = Some(DatabaseTestOptions {
 62            executor,
 63            runtime,
 64            query_failure_probability: parking_lot::Mutex::new(0.0),
 65        });
 66
 67        Self {
 68            db: Some(Arc::new(db)),
 69            connection: None,
 70        }
 71    }
 72
 73    pub fn postgres(executor: BackgroundExecutor) -> Self {
 74        static LOCK: Mutex<()> = Mutex::new(());
 75
 76        let _guard = LOCK.lock();
 77        let mut rng = StdRng::from_os_rng();
 78        let url = format!(
 79            "postgres://postgres@localhost/zed-test-{}",
 80            rng.random::<u128>()
 81        );
 82        let runtime = tokio::runtime::Builder::new_current_thread()
 83            .enable_io()
 84            .enable_time()
 85            .build()
 86            .unwrap();
 87
 88        let mut db = runtime.block_on(async {
 89            sqlx::Postgres::create_database(&url)
 90                .await
 91                .expect("failed to create test db");
 92            let mut options = ConnectOptions::new(url);
 93            options
 94                .max_connections(5)
 95                .idle_timeout(Duration::from_secs(0));
 96            let mut db = Database::new(options).await.unwrap();
 97            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 98            run_database_migrations(db.options(), migrations_path)
 99                .await
100                .unwrap();
101            db.initialize_notification_kinds().await.unwrap();
102            db
103        });
104
105        db.test_options = Some(DatabaseTestOptions {
106            executor,
107            runtime,
108            query_failure_probability: parking_lot::Mutex::new(0.0),
109        });
110
111        Self {
112            db: Some(Arc::new(db)),
113            connection: None,
114        }
115    }
116
117    pub fn db(&self) -> &Arc<Database> {
118        self.db.as_ref().unwrap()
119    }
120
121    pub fn set_query_failure_probability(&self, probability: f64) {
122        let database = self.db.as_ref().unwrap();
123        let test_options = database.test_options.as_ref().unwrap();
124        *test_options.query_failure_probability.lock() = probability;
125    }
126}
127
/// Expands to two `#[gpui::test]` functions that each run the async test
/// `$test_name`, passing it the `&Arc<Database>` of a fresh `TestDb`:
///
/// - `$postgres_test_name` uses `TestDb::postgres` and is compiled only on macOS;
/// - `$sqlite_test_name` uses `TestDb::sqlite`.
///
/// The `TestDb` stays bound until the generated test returns, so its teardown runs
/// after `$test_name` completes.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let harness = $crate::db::TestDb::postgres(executor);
            $test_name(harness.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let harness = $crate::db::TestDb::sqlite(executor);
            $test_name(harness.db()).await;
        }
    };
}
145
146impl Drop for TestDb {
147    fn drop(&mut self) {
148        let db = self.db.take().unwrap();
149        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
150            db.test_options.as_ref().unwrap().runtime.block_on(async {
151                use util::ResultExt;
152                let query = "
153                        SELECT pg_terminate_backend(pg_stat_activity.pid)
154                        FROM pg_stat_activity
155                        WHERE
156                            pg_stat_activity.datname = current_database() AND
157                            pid <> pg_backend_pid();
158                    ";
159                db.pool
160                    .execute(sea_orm::Statement::from_string(
161                        db.pool.get_database_backend(),
162                        query,
163                    ))
164                    .await
165                    .log_err();
166                sqlx::Postgres::drop_database(db.options.get_url())
167                    .await
168                    .log_err();
169            })
170        }
171    }
172}
173
174#[track_caller]
175fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
176    let expected_channels = expected.into_iter().collect::<HashSet<_>>();
177    let actual_channels = actual.into_iter().collect::<HashSet<_>>();
178    pretty_assertions::assert_eq!(expected_channels, actual_channels);
179}
180
181fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
182    use std::collections::HashMap;
183
184    let mut result = Vec::new();
185    let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
186
187    for (id, parent_path, name) in channels {
188        let parent_key = parent_path.to_vec();
189        let order = if parent_key.is_empty() {
190            1
191        } else {
192            *order_by_parent
193                .entry(parent_key.clone())
194                .and_modify(|e| *e += 1)
195                .or_insert(1)
196        };
197
198        result.push(Channel {
199            id: *id,
200            name: name.to_string(),
201            visibility: ChannelVisibility::Members,
202            parent_path: parent_key,
203            channel_order: order,
204        });
205    }
206
207    result
208}
209
210static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
211
212async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
213    db.create_user(
214        email,
215        None,
216        false,
217        NewUserParams {
218            github_login: email[0..email.find('@').unwrap()].to_string(),
219            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
220        },
221    )
222    .await
223    .unwrap()
224    .user_id
225}