1mod buffer_tests;
2mod channel_tests;
3mod db_tests;
4mod feature_flag_tests;
5mod message_tests;
6
7use super::*;
8use gpui::executor::Background;
9use parking_lot::Mutex;
10use rpc::proto::ChannelEdge;
11use sea_orm::ConnectionTrait;
12use sqlx::migrate::MigrateDatabase;
13use std::sync::Arc;
14
/// Release-channel name used by tests in this module tree.
///
/// NOTE(review): not referenced in this file; presumably consumed by the child test
/// modules through `use super::*` — confirm before removing.
const TEST_RELEASE_CHANNEL: &str = "test";
16
/// A database fixture for tests, created by `TestDb::sqlite` or `TestDb::postgres`.
///
/// The `Drop` impl tears the database down (for Postgres: terminates other backends
/// and drops the database).
pub struct TestDb {
    /// The database under test. The constructors always set this to `Some`; `Drop`
    /// takes it out with `Option::take` to perform teardown.
    pub db: Option<Arc<Database>>,
    /// An optional raw sqlx connection. Both constructors in this file leave it `None`.
    /// NOTE(review): confirm whether anything still populates this field.
    pub connection: Option<sqlx::AnyConnection>,
}
21
22impl TestDb {
23 pub fn sqlite(background: Arc<Background>) -> Self {
24 let url = format!("sqlite::memory:");
25 let runtime = tokio::runtime::Builder::new_current_thread()
26 .enable_io()
27 .enable_time()
28 .build()
29 .unwrap();
30
31 let mut db = runtime.block_on(async {
32 let mut options = ConnectOptions::new(url);
33 options.max_connections(5);
34 let mut db = Database::new(options, Executor::Deterministic(background))
35 .await
36 .unwrap();
37 let sql = include_str!(concat!(
38 env!("CARGO_MANIFEST_DIR"),
39 "/migrations.sqlite/20221109000000_test_schema.sql"
40 ));
41 db.pool
42 .execute(sea_orm::Statement::from_string(
43 db.pool.get_database_backend(),
44 sql,
45 ))
46 .await
47 .unwrap();
48 db.initialize_notification_enum().await.unwrap();
49 db
50 });
51
52 db.runtime = Some(runtime);
53
54 Self {
55 db: Some(Arc::new(db)),
56 connection: None,
57 }
58 }
59
60 pub fn postgres(background: Arc<Background>) -> Self {
61 static LOCK: Mutex<()> = Mutex::new(());
62
63 let _guard = LOCK.lock();
64 let mut rng = StdRng::from_entropy();
65 let url = format!(
66 "postgres://postgres@localhost/zed-test-{}",
67 rng.gen::<u128>()
68 );
69 let runtime = tokio::runtime::Builder::new_current_thread()
70 .enable_io()
71 .enable_time()
72 .build()
73 .unwrap();
74
75 let mut db = runtime.block_on(async {
76 sqlx::Postgres::create_database(&url)
77 .await
78 .expect("failed to create test db");
79 let mut options = ConnectOptions::new(url);
80 options
81 .max_connections(5)
82 .idle_timeout(Duration::from_secs(0));
83 let mut db = Database::new(options, Executor::Deterministic(background))
84 .await
85 .unwrap();
86 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
87 db.migrate(Path::new(migrations_path), false).await.unwrap();
88 db.initialize_notification_enum().await.unwrap();
89 db
90 });
91
92 db.runtime = Some(runtime);
93
94 Self {
95 db: Some(Arc::new(db)),
96 connection: None,
97 }
98 }
99
100 pub fn db(&self) -> &Arc<Database> {
101 self.db.as_ref().unwrap()
102 }
103}
104
/// Generates two `#[gpui::test]` functions that run the same async test body against
/// both database backends.
///
/// * `$test_name` — an async function, called with `test_db.db()` (a `&Arc<Database>`)
///   and awaited.
/// * `$postgres_test_name` — name of the generated test that uses `TestDb::postgres`.
/// * `$sqlite_test_name` — name of the generated test that uses `TestDb::sqlite`.
///
/// Each generated test builds its own deterministic background executor with seed 0.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $postgres_test_name() {
            // `$crate` (not `crate`) so the path always resolves to the defining crate,
            // as required for `#[macro_export]` macros.
            let test_db = $crate::db::TestDb::postgres(
                gpui::executor::Deterministic::new(0).build_background(),
            );
            $test_name(test_db.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name() {
            let test_db = $crate::db::TestDb::sqlite(
                gpui::executor::Deterministic::new(0).build_background(),
            );
            $test_name(test_db.db()).await;
        }
    };
}
124
125impl Drop for TestDb {
126 fn drop(&mut self) {
127 let db = self.db.take().unwrap();
128 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
129 db.runtime.as_ref().unwrap().block_on(async {
130 use util::ResultExt;
131 let query = "
132 SELECT pg_terminate_backend(pg_stat_activity.pid)
133 FROM pg_stat_activity
134 WHERE
135 pg_stat_activity.datname = current_database() AND
136 pid <> pg_backend_pid();
137 ";
138 db.pool
139 .execute(sea_orm::Statement::from_string(
140 db.pool.get_database_backend(),
141 query,
142 ))
143 .await
144 .log_err();
145 sqlx::Postgres::drop_database(db.options.get_url())
146 .await
147 .log_err();
148 })
149 }
150 }
151}
152
153/// The second tuples are (channel_id, parent)
154fn graph(channels: &[(ChannelId, &'static str)], edges: &[(ChannelId, ChannelId)]) -> ChannelGraph {
155 let mut graph = ChannelGraph {
156 channels: vec![],
157 edges: vec![],
158 };
159
160 for (id, name) in channels {
161 graph.channels.push(Channel {
162 id: *id,
163 name: name.to_string(),
164 })
165 }
166
167 for (channel, parent) in edges {
168 graph.edges.push(ChannelEdge {
169 channel_id: channel.to_proto(),
170 parent_id: parent.to_proto(),
171 })
172 }
173
174 graph
175}