tests.rs

  1mod buffer_tests;
  2mod channel_tests;
  3mod contributor_tests;
  4mod db_tests;
  5// we only run postgres tests on macos right now
  6#[cfg(target_os = "macos")]
  7mod embedding_tests;
  8mod extension_tests;
  9mod feature_flag_tests;
 10mod message_tests;
 11mod user_tests;
 12
 13use crate::migrations::run_database_migrations;
 14
 15use super::*;
 16use gpui::BackgroundExecutor;
 17use parking_lot::Mutex;
 18use rand::prelude::*;
 19use sea_orm::ConnectionTrait;
 20use sqlx::migrate::MigrateDatabase;
 21use std::{
 22    sync::{
 23        Arc,
 24        atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
 25    },
 26    time::Duration,
 27};
 28
/// A database handle for tests, backed either by an in-memory SQLite database
/// (`TestDb::sqlite`) or by a freshly created Postgres database (`TestDb::postgres`).
pub struct TestDb {
    /// The database under test. Set to `Some` by both constructors and taken in `Drop`.
    pub db: Option<Arc<Database>>,
    /// Raw sqlx connection slot; both constructors in this file leave it as `None`.
    pub connection: Option<sqlx::AnyConnection>,
}
 33
 34impl TestDb {
 35    pub fn sqlite(executor: BackgroundExecutor) -> Self {
 36        let url = "sqlite::memory:";
 37        let runtime = tokio::runtime::Builder::new_current_thread()
 38            .enable_io()
 39            .enable_time()
 40            .build()
 41            .unwrap();
 42
 43        let mut db = runtime.block_on(async {
 44            let mut options = ConnectOptions::new(url);
 45            options.max_connections(5);
 46            let mut db = Database::new(options).await.unwrap();
 47            let sql = include_str!(concat!(
 48                env!("CARGO_MANIFEST_DIR"),
 49                "/migrations.sqlite/20221109000000_test_schema.sql"
 50            ));
 51            db.pool
 52                .execute(sea_orm::Statement::from_string(
 53                    db.pool.get_database_backend(),
 54                    sql,
 55                ))
 56                .await
 57                .unwrap();
 58            db.initialize_notification_kinds().await.unwrap();
 59            db
 60        });
 61
 62        db.test_options = Some(DatabaseTestOptions {
 63            executor,
 64            runtime,
 65            query_failure_probability: parking_lot::Mutex::new(0.0),
 66        });
 67
 68        Self {
 69            db: Some(Arc::new(db)),
 70            connection: None,
 71        }
 72    }
 73
 74    pub fn postgres(executor: BackgroundExecutor) -> Self {
 75        static LOCK: Mutex<()> = Mutex::new(());
 76
 77        let _guard = LOCK.lock();
 78        let mut rng = StdRng::from_os_rng();
 79        let url = format!(
 80            "postgres://postgres@localhost/zed-test-{}",
 81            rng.random::<u128>()
 82        );
 83        let runtime = tokio::runtime::Builder::new_current_thread()
 84            .enable_io()
 85            .enable_time()
 86            .build()
 87            .unwrap();
 88
 89        let mut db = runtime.block_on(async {
 90            sqlx::Postgres::create_database(&url)
 91                .await
 92                .expect("failed to create test db");
 93            let mut options = ConnectOptions::new(url);
 94            options
 95                .max_connections(5)
 96                .idle_timeout(Duration::from_secs(0));
 97            let mut db = Database::new(options).await.unwrap();
 98            let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
 99            run_database_migrations(db.options(), migrations_path)
100                .await
101                .unwrap();
102            db.initialize_notification_kinds().await.unwrap();
103            db
104        });
105
106        db.test_options = Some(DatabaseTestOptions {
107            executor,
108            runtime,
109            query_failure_probability: parking_lot::Mutex::new(0.0),
110        });
111
112        Self {
113            db: Some(Arc::new(db)),
114            connection: None,
115        }
116    }
117
118    pub fn db(&self) -> &Arc<Database> {
119        self.db.as_ref().unwrap()
120    }
121
122    pub fn set_query_failure_probability(&self, probability: f64) {
123        let database = self.db.as_ref().unwrap();
124        let test_options = database.test_options.as_ref().unwrap();
125        *test_options.query_failure_probability.lock() = probability;
126    }
127}
128
/// Expands to two `#[gpui::test]` functions that each run the async test function
/// `$test_name`, passing it the `&Arc<Database>` returned by `TestDb::db`:
///
/// * `$sqlite_test_name` runs against an in-memory SQLite database (`TestDb::sqlite`).
/// * `$postgres_test_name` runs against a fresh Postgres database (`TestDb::postgres`)
///   and is only compiled on macOS.
///
/// The `TestDb` binding lives until the generated function ends, so its `Drop`
/// cleanup runs only after `$test_name` has completed.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $sqlite_test_name(cx: &mut gpui::TestAppContext) {
            let sqlite_db = $crate::db::TestDb::sqlite(cx.executor().clone());
            $test_name(sqlite_db.db()).await;
        }

        #[cfg(target_os = "macos")]
        #[gpui::test]
        async fn $postgres_test_name(cx: &mut gpui::TestAppContext) {
            let postgres_db = $crate::db::TestDb::postgres(cx.executor().clone());
            $test_name(postgres_db.db()).await;
        }
    };
}
146
147impl Drop for TestDb {
148    fn drop(&mut self) {
149        let db = self.db.take().unwrap();
150        if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
151            db.test_options.as_ref().unwrap().runtime.block_on(async {
152                use util::ResultExt;
153                let query = "
154                        SELECT pg_terminate_backend(pg_stat_activity.pid)
155                        FROM pg_stat_activity
156                        WHERE
157                            pg_stat_activity.datname = current_database() AND
158                            pid <> pg_backend_pid();
159                    ";
160                db.pool
161                    .execute(sea_orm::Statement::from_string(
162                        db.pool.get_database_backend(),
163                        query,
164                    ))
165                    .await
166                    .log_err();
167                sqlx::Postgres::drop_database(db.options.get_url())
168                    .await
169                    .log_err();
170            })
171        }
172    }
173}
174
175#[track_caller]
176fn assert_channel_tree_matches(actual: Vec<Channel>, expected: Vec<Channel>) {
177    let expected_channels = expected.into_iter().collect::<HashSet<_>>();
178    let actual_channels = actual.into_iter().collect::<HashSet<_>>();
179    pretty_assertions::assert_eq!(expected_channels, actual_channels);
180}
181
182fn channel_tree(channels: &[(ChannelId, &[ChannelId], &'static str)]) -> Vec<Channel> {
183    use std::collections::HashMap;
184
185    let mut result = Vec::new();
186    let mut order_by_parent: HashMap<Vec<ChannelId>, i32> = HashMap::new();
187
188    for (id, parent_path, name) in channels {
189        let parent_key = parent_path.to_vec();
190        let order = if parent_key.is_empty() {
191            1
192        } else {
193            *order_by_parent
194                .entry(parent_key.clone())
195                .and_modify(|e| *e += 1)
196                .or_insert(1)
197        };
198
199        result.push(Channel {
200            id: *id,
201            name: name.to_string(),
202            visibility: ChannelVisibility::Members,
203            parent_path: parent_key,
204            channel_order: order,
205        });
206    }
207
208    result
209}
210
211static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
212
213async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
214    db.create_user(
215        email,
216        None,
217        false,
218        NewUserParams {
219            github_login: email[0..email.find('@').unwrap()].to_string(),
220            github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
221        },
222    )
223    .await
224    .unwrap()
225    .user_id
226}
227
228static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
229fn new_test_connection(server: ServerId) -> ConnectionId {
230    ConnectionId {
231        id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
232        owner_id: server.0 as u32,
233    }
234}