tests.rs

mod billing_subscription_tests;
mod buffer_tests;
mod channel_tests;
mod contributor_tests;
mod db_tests;
// we only run postgres tests on macos right now
#[cfg(target_os = "macos")]
mod embedding_tests;
mod extension_tests;
mod feature_flag_tests;
mod message_tests;
mod processed_stripe_event_tests;
mod user_tests;

use crate::migrations::run_database_migrations;

use super::*;
use gpui::BackgroundExecutor;
use parking_lot::Mutex;
use rand::prelude::*;
use sea_orm::ConnectionTrait;
use sqlx::migrate::MigrateDatabase;
use std::{
    sync::{
        Arc,
        atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
    },
    time::Duration,
};

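/// A disposable test database: either an in-memory SQLite database or a
/// uniquely named Postgres database that is cleaned up when the `TestDb`
/// is dropped.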
pub struct TestDb {
    pub db: Option<Arc<Database>>,
    pub connection: Option<sqlx::AnyConnection>,
}

impl TestDb {
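    /// Creates an in-memory SQLite test database, applies the bundled test
    /// schema, and initializes notification kinds.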
    pub fn sqlite(executor: BackgroundExecutor) -> Self {
        let url = "sqlite::memory:";
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap();

        let mut db = runtime.block_on(async {
            let mut options = ConnectOptions::new(url);
            options.max_connections(5);
            let mut db = Database::new(options).await.unwrap();
            let sql = include_str!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/migrations.sqlite/20221109000000_test_schema.sql"
            ));
            db.pool
                .execute(sea_orm::Statement::from_string(
                    db.pool.get_database_backend(),
                    sql,
                ))
                .await
                .unwrap();
            db.initialize_notification_kinds().await.unwrap();
            db
        });

        db.test_options = Some(DatabaseTestOptions {
            executor,
            runtime,
            query_failure_probability: parking_lot::Mutex::new(0.0),
        });

        Self {
            db: Some(Arc::new(db)),
            connection: None,
        }
    }

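    /// Creates a randomly named Postgres test database, runs the migrations in
    /// `migrations/`, and initializes notification kinds. Database creation is
    /// serialized behind a global lock.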
    pub fn postgres(executor: BackgroundExecutor) -> Self {
        static LOCK: Mutex<()> = Mutex::new(());

        let _guard = LOCK.lock();
        let mut rng = StdRng::from_entropy();
        let url = format!(
            "postgres://postgres@localhost/zed-test-{}",
            rng.r#gen::<u128>()
        );
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap();

        let mut db = runtime.block_on(async {
            sqlx::Postgres::create_database(&url)
                .await
                .expect("failed to create test db");
            let mut options = ConnectOptions::new(url);
            options
                .max_connections(5)
                .idle_timeout(Duration::from_secs(0));
            let mut db = Database::new(options).await.unwrap();
            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
            run_database_migrations(db.options(), migrations_path)
                .await
                .unwrap();
            db.initialize_notification_kinds().await.unwrap();
            db
        });

        db.test_options = Some(DatabaseTestOptions {
            executor,
            runtime,
            query_failure_probability: parking_lot::Mutex::new(0.0),
        });

        Self {
            db: Some(Arc::new(db)),
            connection: None,
        }
    }

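    /// Returns the wrapped database handle.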
    pub fn db(&self) -> &Arc<Database> {
        self.db.as_ref().unwrap()
    }

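    /// Sets the query failure probability stored in the database's test
    /// options, which tests use to simulate failing queries.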
    pub fn set_query_failure_probability(&self, probability: f64) {
        let database = self.db.as_ref().unwrap();
        let test_options = database.test_options.as_ref().unwrap();
        *test_options.query_failure_probability.lock() = probability;
    }
}

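/// Defines a pair of tests that run `$test_name` against both backends: a
/// Postgres variant (macOS only) and a SQLite variant, each with a fresh
/// `TestDb`.
///
/// Usage, with hypothetical test names:
/// `test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite);`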
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let test_db = $crate::db::TestDb::postgres(cx.executor().clone());
            $test_name(test_db.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let test_db = $crate::db::TestDb::sqlite(cx.executor().clone());
            $test_name(test_db.db()).await;
        }
    };
}

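// Dropping a Postgres-backed `TestDb` terminates any other connections to the
// test database and then drops it; in-memory SQLite databases need no cleanup.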
impl Drop for TestDb {
    fn drop(&mut self) {
        let db = self.db.take().unwrap();
        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
            db.test_options.as_ref().unwrap().runtime.block_on(async {
                use util::ResultExt;
                let query = "
                        SELECT pg_terminate_backend(pg_stat_activity.pid)
                        FROM pg_stat_activity
                        WHERE
                            pg_stat_activity.datname = current_database() AND
                            pid <> pg_backend_pid();
                    ";
                db.pool
                    .execute(sea_orm::Statement::from_string(
                        db.pool.get_database_backend(),
                        query,
                    ))
                    .await
                    .log_err();
                sqlx::Postgres::drop_database(db.options.get_url())
                    .await
                    .log_err();
            })
        }
    }
}

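/// Asserts that `actual` and `expected` contain the same channels, ignoring order.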
#[track_caller]
fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
    let expected_channels = expected.into_iter().collect::<HashSet<_>>();
    let actual_channels = actual.into_iter().collect::<HashSet<_>>();
    pretty_assertions::assert_eq!(expected_channels, actual_channels);
}

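/// Builds `Channel` values from `(id, parent_path, name)` tuples with
/// `Members` visibility, numbering siblings under the same parent path
/// sequentially (root channels all receive order 1).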
fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
    use std::collections::HashMap;

    let mut result = Vec::new();
    let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();

    for (id, parent_path, name) in channels {
        let parent_key = parent_path.to_vec();
        let order = if parent_key.is_empty() {
            1
        } else {
            *order_by_parent
                .entry(parent_key.clone())
                .and_modify(|e| *e += 1)
                .or_insert(1)
        };

        result.push(Channel {
            id: *id,
            name: name.to_string(),
            visibility: ChannelVisibility::Members,
            parent_path: parent_key,
            channel_order: order,
        });
    }

    result
}

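/// Source of unique fake GitHub user ids for test users.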
static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);

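/// Creates a user whose `github_login` is the local part of `email`, using a
/// fresh id from `GITHUB_USER_ID`, and returns the new `UserId`.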
async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
    db.create_user(
        email,
        None,
        false,
        NewUserParams {
            github_login: email[0..email.find('@').unwrap()].to_string(),
            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
        },
    )
    .await
    .unwrap()
    .user_id
}

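/// Source of unique ids for test connections.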
static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
fn new_test_connection(server: ServerId) -> ConnectionId {
    ConnectionId {
        id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
        owner_id: server.0 as u32,
    }
}