tests.rs

  1mod buffer_tests;
  2mod channel_tests;
  3mod contributor_tests;
  4mod db_tests;
  5// we only run postgres tests on macos right now
  6#[cfg(target_os = "macos")]
  7mod embedding_tests;
  8mod extension_tests;
  9mod feature_flag_tests;
 10mod message_tests;
 11mod processed_stripe_event_tests;
 12mod user_tests;
 13
 14use crate::migrations::run_database_migrations;
 15
 16use super::*;
 17use gpui::BackgroundExecutor;
 18use parking_lot::Mutex;
 19use rand::prelude::*;
 20use sea_orm::ConnectionTrait;
 21use sqlx::migrate::MigrateDatabase;
 22use std::{
 23    sync::{
 24        Arc,
 25        atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
 26    },
 27    time::Duration,
 28};
 29
/// A database handle for tests, created by [`TestDb::sqlite`] or [`TestDb::postgres`].
///
/// Dropping a Postgres-backed `TestDb` attempts to drop the database it created.
pub struct TestDb {
    /// The database under test. Set to `Some` by both constructors and taken in `drop`.
    pub db: Option<Arc<Database>>,
    /// NOTE(review): both constructors in this file leave this `None`; confirm whether
    /// any caller still relies on it.
    pub connection: Option<sqlx::AnyConnection>,
}
 34
 35impl TestDb {
 36    pub fn sqlite(executor: BackgroundExecutor) -> Self {
 37        let url = "sqlite::memory:";
 38        let runtime = tokio::runtime::Builder::new_current_thread()
 39            .enable_io()
 40            .enable_time()
 41            .build()
 42            .unwrap();
 43
 44        let mut db = runtime.block_on(async {
 45            let mut options = ConnectOptions::new(url);
 46            options.max_connections(5);
 47            let mut db = Database::new(options).await.unwrap();
 48            let sql = include_str!(concat!(
 49                env!("CARGO_MANIFEST_DIR"),
 50                "/migrations.sqlite/20221109000000_test_schema.sql"
 51            ));
 52            db.pool
 53                .execute(sea_orm::Statement::from_string(
 54                    db.pool.get_database_backend(),
 55                    sql,
 56                ))
 57                .await
 58                .unwrap();
 59            db.initialize_notification_kinds().await.unwrap();
 60            db
 61        });
 62
 63        db.test_options = Some(DatabaseTestOptions {
 64            executor,
 65            runtime,
 66            query_failure_probability: parking_lot::Mutex::new(0.0),
 67        });
 68
 69        Self {
 70            db: Some(Arc::new(db)),
 71            connection: None,
 72        }
 73    }
 74
 75    pub fn postgres(executor: BackgroundExecutor) -> Self {
 76        static LOCK: Mutex<()> = Mutex::new(());
 77
 78        let _guard = LOCK.lock();
 79        let mut rng = StdRng::from_entropy();
 80        let url = format!(
 81            "postgres://postgres@localhost/zed-test-{}",
 82            rng.r#gen::<u128>()
 83        );
 84        let runtime = tokio::runtime::Builder::new_current_thread()
 85            .enable_io()
 86            .enable_time()
 87            .build()
 88            .unwrap();
 89
 90        let mut db = runtime.block_on(async {
 91            sqlx::Postgres::create_database(&url)
 92                .await
 93                .expect("failed to create test db");
 94            let mut options = ConnectOptions::new(url);
 95            options
 96                .max_connections(5)
 97                .idle_timeout(Duration::from_secs(0));
 98            let mut db = Database::new(options).await.unwrap();
 99            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
100            run_database_migrations(db.options(), migrations_path)
101                .await
102                .unwrap();
103            db.initialize_notification_kinds().await.unwrap();
104            db
105        });
106
107        db.test_options = Some(DatabaseTestOptions {
108            executor,
109            runtime,
110            query_failure_probability: parking_lot::Mutex::new(0.0),
111        });
112
113        Self {
114            db: Some(Arc::new(db)),
115            connection: None,
116        }
117    }
118
119    pub fn db(&self) -> &Arc<Database> {
120        self.db.as_ref().unwrap()
121    }
122
123    pub fn set_query_failure_probability(&self, probability: f64) {
124        let database = self.db.as_ref().unwrap();
125        let test_options = database.test_options.as_ref().unwrap();
126        *test_options.query_failure_probability.lock() = probability;
127    }
128}
129
/// Expands to two `#[gpui::test]` functions that await the same test body
/// (`$body`, called with the `&Arc<Database>` returned by `TestDb::db`):
/// `$pg_name` runs it against a fresh Postgres database (macOS only), and
/// `$sqlite_name` runs it against an in-memory SQLite database.
#[macro_export]
macro_rules! test_both_dbs {
    ($body:ident, $pg_name:ident, $sqlite_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $pg_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let test_db = $crate::db::TestDb::postgres(executor);
            $body(test_db.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_name(cx: &mut gpui::TestAppContext) {
            let executor = cx.executor().clone();
            let test_db = $crate::db::TestDb::sqlite(executor);
            $body(test_db.db()).await;
        }
    };
}
147
148impl Drop for TestDb {
149    fn drop(&mut self) {
150        let db = self.db.take().unwrap();
151        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
152            db.test_options.as_ref().unwrap().runtime.block_on(async {
153                use util::ResultExt;
154                let query = "
155                        SELECT pg_terminate_backend(pg_stat_activity.pid)
156                        FROM pg_stat_activity
157                        WHERE
158                            pg_stat_activity.datname = current_database() AND
159                            pid <> pg_backend_pid();
160                    ";
161                db.pool
162                    .execute(sea_orm::Statement::from_string(
163                        db.pool.get_database_backend(),
164                        query,
165                    ))
166                    .await
167                    .log_err();
168                sqlx::Postgres::drop_database(db.options.get_url())
169                    .await
170                    .log_err();
171            })
172        }
173    }
174}
175
176#[track_caller]
177fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
178    let expected_channels = expected.into_iter().collect::<HashSet<_>>();
179    let actual_channels = actual.into_iter().collect::<HashSet<_>>();
180    pretty_assertions::assert_eq!(expected_channels, actual_channels);
181}
182
183fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
184    use std::collections::HashMap;
185
186    let mut result = Vec::new();
187    let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
188
189    for (id, parent_path, name) in channels {
190        let parent_key = parent_path.to_vec();
191        let order = if parent_key.is_empty() {
192            1
193        } else {
194            *order_by_parent
195                .entry(parent_key.clone())
196                .and_modify(|e| *e += 1)
197                .or_insert(1)
198        };
199
200        result.push(Channel {
201            id: *id,
202            name: name.to_string(),
203            visibility: ChannelVisibility::Members,
204            parent_path: parent_key,
205            channel_order: order,
206        });
207    }
208
209    result
210}
211
212static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
213
214async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
215    db.create_user(
216        email,
217        None,
218        false,
219        NewUserParams {
220            github_login: email[0..email.find('@').unwrap()].to_string(),
221            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
222        },
223    )
224    .await
225    .unwrap()
226    .user_id
227}
228
229static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
230fn new_test_connection(server: ServerId) -> ConnectionId {
231    ConnectionId {
232        id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
233        owner_id: server.0 as u32,
234    }
235}