tests.rs

  1mod buffer_tests;
  2mod channel_tests;
  3mod db_tests;
  4mod feature_flag_tests;
  5mod message_tests;
  6
  7use super::*;
  8use gpui::executor::Background;
  9use parking_lot::Mutex;
 10use rpc::proto::ChannelEdge;
 11use sea_orm::ConnectionTrait;
 12use sqlx::migrate::MigrateDatabase;
 13use std::sync::{
 14    atomic::{AtomicI32, Ordering::SeqCst},
 15    Arc,
 16};
 17
 18const TEST_RELEASE_CHANNEL: &'static str = "test";
 19
 20pub struct TestDb {
 21    pub db: Option<Arc<Database>>,
 22    pub connection: Option<sqlx::AnyConnection>,
 23}
 24
 25impl TestDb {
 26    pub fn sqlite(background: Arc<Background>) -> Self {
 27        let url = format!("sqlite::memory:");
 28        let runtime = tokio::runtime::Builder::new_current_thread()
 29            .enable_io()
 30            .enable_time()
 31            .build()
 32            .unwrap();
 33
 34        let mut db = runtime.block_on(async {
 35            let mut options = ConnectOptions::new(url);
 36            options.max_connections(5);
 37            let mut db = Database::new(options, Executor::Deterministic(background))
 38                .await
 39                .unwrap();
 40            let sql = include_str!(concat!(
 41                env!("CARGO_MANIFEST_DIR"),
 42                "/migrations.sqlite/20221109000000_test_schema.sql"
 43            ));
 44            db.pool
 45                .execute(sea_orm::Statement::from_string(
 46                    db.pool.get_database_backend(),
 47                    sql,
 48                ))
 49                .await
 50                .unwrap();
 51            db.initialize_notification_kinds().await.unwrap();
 52            db
 53        });
 54
 55        db.runtime = Some(runtime);
 56
 57        Self {
 58            db: Some(Arc::new(db)),
 59            connection: None,
 60        }
 61    }
 62
 63    pub fn postgres(background: Arc<Background>) -> Self {
 64        static LOCK: Mutex<()> = Mutex::new(());
 65
 66        let _guard = LOCK.lock();
 67        let mut rng = StdRng::from_entropy();
 68        let url = format!(
 69            "postgres://postgres@localhost/zed-test-{}",
 70            rng.gen::<u128>()
 71        );
 72        let runtime = tokio::runtime::Builder::new_current_thread()
 73            .enable_io()
 74            .enable_time()
 75            .build()
 76            .unwrap();
 77
 78        let mut db = runtime.block_on(async {
 79            sqlx::Postgres::create_database(&url)
 80                .await
 81                .expect("failed to create test db");
 82            let mut options = ConnectOptions::new(url);
 83            options
 84                .max_connections(5)
 85                .idle_timeout(Duration::from_secs(0));
 86            let mut db = Database::new(options, Executor::Deterministic(background))
 87                .await
 88                .unwrap();
 89            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 90            db.migrate(Path::new(migrations_path), false).await.unwrap();
 91            db.initialize_notification_kinds().await.unwrap();
 92            db
 93        });
 94
 95        db.runtime = Some(runtime);
 96
 97        Self {
 98            db: Some(Arc::new(db)),
 99            connection: None,
100        }
101    }
102
103    pub fn db(&self) -> &Arc<Database> {
104        self.db.as_ref().unwrap()
105    }
106}
107
/// Defines two `#[gpui::test]` functions that run the async test `$test_name`.
///
/// `$postgres_test_name` runs it against a fresh Postgres database
/// (`TestDb::postgres`). `$sqlite_test_name` runs it against in-memory SQLite
/// (`TestDb::sqlite`). `$test_name` is called with the `&Arc<Database>` from
/// `TestDb::db` and its result is awaited.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $postgres_test_name() {
            // `$crate` resolves to the defining crate wherever the exported
            // macro is expanded; plain `crate::` resolves at the call site.
            let test_db = $crate::db::TestDb::postgres(
                gpui::executor::Deterministic::new(0).build_background(),
            );
            $test_name(test_db.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name() {
            let test_db = $crate::db::TestDb::sqlite(
                gpui::executor::Deterministic::new(0).build_background(),
            );
            $test_name(test_db.db()).await;
        }
    };
}
127
128impl Drop for TestDb {
129    fn drop(&mut self) {
130        let db = self.db.take().unwrap();
131        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
132            db.runtime.as_ref().unwrap().block_on(async {
133                use util::ResultExt;
134                let query = "
135                        SELECT pg_terminate_backend(pg_stat_activity.pid)
136                        FROM pg_stat_activity
137                        WHERE
138                            pg_stat_activity.datname = current_database() AND
139                            pid <> pg_backend_pid();
140                    ";
141                db.pool
142                    .execute(sea_orm::Statement::from_string(
143                        db.pool.get_database_backend(),
144                        query,
145                    ))
146                    .await
147                    .log_err();
148                sqlx::Postgres::drop_database(db.options.get_url())
149                    .await
150                    .log_err();
151            })
152        }
153    }
154}
155
156/// The second tuples are (channel_id, parent)
157fn graph(channels: &[(ChannelId, &'static str)], edges: &[(ChannelId, ChannelId)]) -> ChannelGraph {
158    let mut graph = ChannelGraph {
159        channels: vec![],
160        edges: vec![],
161    };
162
163    for (id, name) in channels {
164        graph.channels.push(Channel {
165            id: *id,
166            name: name.to_string(),
167            visibility: ChannelVisibility::Members,
168        })
169    }
170
171    for (channel, parent) in edges {
172        graph.edges.push(ChannelEdge {
173            channel_id: channel.to_proto(),
174            parent_id: parent.to_proto(),
175        })
176    }
177
178    graph
179}
180
181static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
182
183async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
184    db.create_user(
185        email,
186        false,
187        NewUserParams {
188            github_login: email[0..email.find("@").unwrap()].to_string(),
189            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
190        },
191    )
192    .await
193    .unwrap()
194    .user_id
195}