tests.rs

  1mod buffer_tests;
  2mod channel_tests;
  3mod db_tests;
  4mod feature_flag_tests;
  5mod message_tests;
  6
  7use super::*;
  8use gpui::executor::Background;
  9use parking_lot::Mutex;
 10use sea_orm::ConnectionTrait;
 11use sqlx::migrate::MigrateDatabase;
 12use std::sync::{
 13    atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
 14    Arc,
 15};
 16
 17const TEST_RELEASE_CHANNEL: &'static str = "test";
 18
 19pub struct TestDb {
 20    pub db: Option<Arc<Database>>,
 21    pub connection: Option<sqlx::AnyConnection>,
 22}
 23
 24impl TestDb {
 25    pub fn sqlite(background: Arc<Background>) -> Self {
 26        let url = format!("sqlite::memory:");
 27        let runtime = tokio::runtime::Builder::new_current_thread()
 28            .enable_io()
 29            .enable_time()
 30            .build()
 31            .unwrap();
 32
 33        let mut db = runtime.block_on(async {
 34            let mut options = ConnectOptions::new(url);
 35            options.max_connections(5);
 36            let mut db = Database::new(options, Executor::Deterministic(background))
 37                .await
 38                .unwrap();
 39            let sql = include_str!(concat!(
 40                env!("CARGO_MANIFEST_DIR"),
 41                "/migrations.sqlite/20221109000000_test_schema.sql"
 42            ));
 43            db.pool
 44                .execute(sea_orm::Statement::from_string(
 45                    db.pool.get_database_backend(),
 46                    sql,
 47                ))
 48                .await
 49                .unwrap();
 50            db.initialize_notification_kinds().await.unwrap();
 51            db
 52        });
 53
 54        db.runtime = Some(runtime);
 55
 56        Self {
 57            db: Some(Arc::new(db)),
 58            connection: None,
 59        }
 60    }
 61
 62    pub fn postgres(background: Arc<Background>) -> Self {
 63        static LOCK: Mutex<()> = Mutex::new(());
 64
 65        let _guard = LOCK.lock();
 66        let mut rng = StdRng::from_entropy();
 67        let url = format!(
 68            "postgres://postgres@localhost/zed-test-{}",
 69            rng.gen::<u128>()
 70        );
 71        let runtime = tokio::runtime::Builder::new_current_thread()
 72            .enable_io()
 73            .enable_time()
 74            .build()
 75            .unwrap();
 76
 77        let mut db = runtime.block_on(async {
 78            sqlx::Postgres::create_database(&url)
 79                .await
 80                .expect("failed to create test db");
 81            let mut options = ConnectOptions::new(url);
 82            options
 83                .max_connections(5)
 84                .idle_timeout(Duration::from_secs(0));
 85            let mut db = Database::new(options, Executor::Deterministic(background))
 86                .await
 87                .unwrap();
 88            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 89            db.migrate(Path::new(migrations_path), false).await.unwrap();
 90            db.initialize_notification_kinds().await.unwrap();
 91            db
 92        });
 93
 94        db.runtime = Some(runtime);
 95
 96        Self {
 97            db: Some(Arc::new(db)),
 98            connection: None,
 99        }
100    }
101
102    pub fn db(&self) -> &Arc<Database> {
103        self.db.as_ref().unwrap()
104    }
105}
106
/// Generates two `#[gpui::test]` async tests that run the same test body
/// against both database backends.
///
/// * `$test_name` — async function to run; it is called with the
///   `&Arc<Database>` returned by `TestDb::db` and its result is awaited.
/// * `$postgres_test_name` — name of the generated test that uses
///   `TestDb::postgres`.
/// * `$sqlite_test_name` — name of the generated test that uses
///   `TestDb::sqlite`.
///
/// Each generated test builds its background executor from
/// `gpui::executor::Deterministic::new(0)`.
///
/// `TestDb` is referenced through `$crate` so the path always resolves in the
/// crate that defines this macro, as is conventional for exported macros.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $postgres_test_name() {
            let test_db = $crate::db::TestDb::postgres(
                gpui::executor::Deterministic::new(0).build_background(),
            );
            $test_name(test_db.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name() {
            let test_db = $crate::db::TestDb::sqlite(
                gpui::executor::Deterministic::new(0).build_background(),
            );
            $test_name(test_db.db()).await;
        }
    };
}
126
127impl Drop for TestDb {
128    fn drop(&mut self) {
129        let db = self.db.take().unwrap();
130        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
131            db.runtime.as_ref().unwrap().block_on(async {
132                use util::ResultExt;
133                let query = "
134                        SELECT pg_terminate_backend(pg_stat_activity.pid)
135                        FROM pg_stat_activity
136                        WHERE
137                            pg_stat_activity.datname = current_database() AND
138                            pid <> pg_backend_pid();
139                    ";
140                db.pool
141                    .execute(sea_orm::Statement::from_string(
142                        db.pool.get_database_backend(),
143                        query,
144                    ))
145                    .await
146                    .log_err();
147                sqlx::Postgres::drop_database(db.options.get_url())
148                    .await
149                    .log_err();
150            })
151        }
152    }
153}
154
155fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str, ChannelRole)]) -> Vec<Channel> {
156    channels
157        .iter()
158        .map(|(id, parent_path, name, role)| Channel {
159            id: *id,
160            name: name.to_string(),
161            visibility: ChannelVisibility::Members,
162            role: *role,
163            parent_path: parent_path.to_vec(),
164        })
165        .collect()
166}
167
168static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
169
170async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
171    db.create_user(
172        email,
173        false,
174        NewUserParams {
175            github_login: email[0..email.find("@").unwrap()].to_string(),
176            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
177        },
178    )
179    .await
180    .unwrap()
181    .user_id
182}
183
184static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
185fn new_test_connection(server: ServerId) -> ConnectionId {
186    ConnectionId {
187        id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
188        owner_id: server.0 as u32,
189    }
190}