tests.rs

  1mod billing_subscription_tests;
  2mod buffer_tests;
  3mod channel_tests;
  4mod contributor_tests;
  5mod db_tests;
  6// we only run postgres tests on macos right now
  7#[cfg(target_os = "macos")]
  8mod embedding_tests;
  9mod extension_tests;
 10mod feature_flag_tests;
 11mod message_tests;
 12
 13use super::*;
 14use gpui::BackgroundExecutor;
 15use parking_lot::Mutex;
 16use sea_orm::ConnectionTrait;
 17use sqlx::migrate::MigrateDatabase;
 18use std::sync::{
 19    atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
 20    Arc,
 21};
 22
 23pub struct TestDb {
 24    pub db: Option<Arc<Database>>,
 25    pub connection: Option<sqlx::AnyConnection>,
 26}
 27
 28impl TestDb {
 29    pub fn sqlite(background: BackgroundExecutor) -> Self {
 30        let url = "sqlite::memory:";
 31        let runtime = tokio::runtime::Builder::new_current_thread()
 32            .enable_io()
 33            .enable_time()
 34            .build()
 35            .unwrap();
 36
 37        let mut db = runtime.block_on(async {
 38            let mut options = ConnectOptions::new(url);
 39            options.max_connections(5);
 40            let mut db = Database::new(options, Executor::Deterministic(background))
 41                .await
 42                .unwrap();
 43            let sql = include_str!(concat!(
 44                env!("CARGO_MANIFEST_DIR"),
 45                "/migrations.sqlite/20221109000000_test_schema.sql"
 46            ));
 47            db.pool
 48                .execute(sea_orm::Statement::from_string(
 49                    db.pool.get_database_backend(),
 50                    sql,
 51                ))
 52                .await
 53                .unwrap();
 54            db.initialize_notification_kinds().await.unwrap();
 55            db
 56        });
 57
 58        db.runtime = Some(runtime);
 59
 60        Self {
 61            db: Some(Arc::new(db)),
 62            connection: None,
 63        }
 64    }
 65
 66    pub fn postgres(background: BackgroundExecutor) -> Self {
 67        static LOCK: Mutex<()> = Mutex::new(());
 68
 69        let _guard = LOCK.lock();
 70        let mut rng = StdRng::from_entropy();
 71        let url = format!(
 72            "postgres://postgres@localhost/zed-test-{}",
 73            rng.gen::<u128>()
 74        );
 75        let runtime = tokio::runtime::Builder::new_current_thread()
 76            .enable_io()
 77            .enable_time()
 78            .build()
 79            .unwrap();
 80
 81        let mut db = runtime.block_on(async {
 82            sqlx::Postgres::create_database(&url)
 83                .await
 84                .expect("failed to create test db");
 85            let mut options = ConnectOptions::new(url);
 86            options
 87                .max_connections(5)
 88                .idle_timeout(Duration::from_secs(0));
 89            let mut db = Database::new(options, Executor::Deterministic(background))
 90                .await
 91                .unwrap();
 92            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 93            db.migrate(Path::new(migrations_path), false).await.unwrap();
 94            db.initialize_notification_kinds().await.unwrap();
 95            db
 96        });
 97
 98        db.runtime = Some(runtime);
 99
100        Self {
101            db: Some(Arc::new(db)),
102            connection: None,
103        }
104    }
105
106    pub fn db(&self) -> &Arc<Database> {
107        self.db.as_ref().unwrap()
108    }
109}
110
/// Defines two `#[gpui::test]` functions that run the same test body against both backends.
///
/// `$test_name` must be a function that accepts `&Arc<Database>` and returns a future.
/// `$postgres_test_name` runs it against [`TestDb::postgres`] and is compiled only on
/// macOS. `$sqlite_test_name` runs it against [`TestDb::sqlite`] on every platform.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            // The harness is kept alive until the body completes, so its `Drop`
            // cleanup runs only after the test has finished.
            let harness = $crate::db::TestDb::postgres(cx.executor().clone());
            $test_name(harness.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let harness = $crate::db::TestDb::sqlite(cx.executor().clone());
            $test_name(harness.db()).await;
        }
    };
}
128
129impl Drop for TestDb {
130    fn drop(&mut self) {
131        let db = self.db.take().unwrap();
132        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
133            db.runtime.as_ref().unwrap().block_on(async {
134                use util::ResultExt;
135                let query = "
136                        SELECT pg_terminate_backend(pg_stat_activity.pid)
137                        FROM pg_stat_activity
138                        WHERE
139                            pg_stat_activity.datname = current_database() AND
140                            pid <> pg_backend_pid();
141                    ";
142                db.pool
143                    .execute(sea_orm::Statement::from_string(
144                        db.pool.get_database_backend(),
145                        query,
146                    ))
147                    .await
148                    .log_err();
149                sqlx::Postgres::drop_database(db.options.get_url())
150                    .await
151                    .log_err();
152            })
153        }
154    }
155}
156
157fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
158    channels
159        .iter()
160        .map(|(id, parent_path, name)| Channel {
161            id: *id,
162            name: name.to_string(),
163            visibility: ChannelVisibility::Members,
164            parent_path: parent_path.to_vec(),
165        })
166        .collect()
167}
168
169static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
170
171async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
172    db.create_user(
173        email,
174        false,
175        NewUserParams {
176            github_login: email[0..email.find('@').unwrap()].to_string(),
177            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
178        },
179    )
180    .await
181    .unwrap()
182    .user_id
183}
184
185static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
186fn new_test_connection(server: ServerId) -> ConnectionId {
187    ConnectionId {
188        id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
189        owner_id: server.0 as u32,
190    }
191}