1mod buffer_tests;
2mod channel_tests;
3mod db_tests;
4mod feature_flag_tests;
5mod message_tests;
6
7use super::*;
8use gpui::executor::Background;
9use parking_lot::Mutex;
10use rpc::proto::ChannelEdge;
11use sea_orm::ConnectionTrait;
12use sqlx::migrate::MigrateDatabase;
13use std::sync::Arc;
14
15pub struct TestDb {
16 pub db: Option<Arc<Database>>,
17 pub connection: Option<sqlx::AnyConnection>,
18}
19
20impl TestDb {
21 pub fn sqlite(background: Arc<Background>) -> Self {
22 let url = format!("sqlite::memory:");
23 let runtime = tokio::runtime::Builder::new_current_thread()
24 .enable_io()
25 .enable_time()
26 .build()
27 .unwrap();
28
29 let mut db = runtime.block_on(async {
30 let mut options = ConnectOptions::new(url);
31 options.max_connections(5);
32 let db = Database::new(options, Executor::Deterministic(background))
33 .await
34 .unwrap();
35 let sql = include_str!(concat!(
36 env!("CARGO_MANIFEST_DIR"),
37 "/migrations.sqlite/20221109000000_test_schema.sql"
38 ));
39 db.pool
40 .execute(sea_orm::Statement::from_string(
41 db.pool.get_database_backend(),
42 sql,
43 ))
44 .await
45 .unwrap();
46 db
47 });
48
49 db.runtime = Some(runtime);
50
51 Self {
52 db: Some(Arc::new(db)),
53 connection: None,
54 }
55 }
56
57 pub fn postgres(background: Arc<Background>) -> Self {
58 static LOCK: Mutex<()> = Mutex::new(());
59
60 let _guard = LOCK.lock();
61 let mut rng = StdRng::from_entropy();
62 let url = format!(
63 "postgres://postgres@localhost/zed-test-{}",
64 rng.gen::<u128>()
65 );
66 let runtime = tokio::runtime::Builder::new_current_thread()
67 .enable_io()
68 .enable_time()
69 .build()
70 .unwrap();
71
72 let mut db = runtime.block_on(async {
73 sqlx::Postgres::create_database(&url)
74 .await
75 .expect("failed to create test db");
76 let mut options = ConnectOptions::new(url);
77 options
78 .max_connections(5)
79 .idle_timeout(Duration::from_secs(0));
80 let db = Database::new(options, Executor::Deterministic(background))
81 .await
82 .unwrap();
83 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
84 db.migrate(Path::new(migrations_path), false).await.unwrap();
85 db
86 });
87
88 db.runtime = Some(runtime);
89
90 Self {
91 db: Some(Arc::new(db)),
92 connection: None,
93 }
94 }
95
96 pub fn db(&self) -> &Arc<Database> {
97 self.db.as_ref().unwrap()
98 }
99}
100
/// Generates two `#[gpui::test]` functions that run the same async test body
/// against both database backends.
///
/// * `$test_name` - the async function to run; it is called with the value
///   returned by `TestDb::db()` and the result is awaited.
/// * `$postgres_test_name` - name of the generated test that uses
///   `TestDb::postgres`.
/// * `$sqlite_test_name` - name of the generated test that uses
///   `TestDb::sqlite`.
///
/// Each generated test builds its background executor from
/// `gpui::executor::Deterministic::new(0)`.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $postgres_test_name() {
            let postgres_db = crate::db::TestDb::postgres(
                gpui::executor::Deterministic::new(0).build_background(),
            );
            $test_name(postgres_db.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name() {
            let sqlite_db = crate::db::TestDb::sqlite(
                gpui::executor::Deterministic::new(0).build_background(),
            );
            $test_name(sqlite_db.db()).await;
        }
    };
}
120
121impl Drop for TestDb {
122 fn drop(&mut self) {
123 let db = self.db.take().unwrap();
124 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
125 db.runtime.as_ref().unwrap().block_on(async {
126 use util::ResultExt;
127 let query = "
128 SELECT pg_terminate_backend(pg_stat_activity.pid)
129 FROM pg_stat_activity
130 WHERE
131 pg_stat_activity.datname = current_database() AND
132 pid <> pg_backend_pid();
133 ";
134 db.pool
135 .execute(sea_orm::Statement::from_string(
136 db.pool.get_database_backend(),
137 query,
138 ))
139 .await
140 .log_err();
141 sqlx::Postgres::drop_database(db.options.get_url())
142 .await
143 .log_err();
144 })
145 }
146 }
147}
148
149/// The second tuples are (channel_id, parent)
150fn graph(channels: &[(ChannelId, &'static str)], edges: &[(ChannelId, ChannelId)]) -> ChannelGraph {
151 let mut graph = ChannelGraph {
152 channels: vec![],
153 edges: vec![],
154 };
155
156 for (id, name) in channels {
157 graph.channels.push(Channel {
158 id: *id,
159 name: name.to_string(),
160 })
161 }
162
163 for (channel, parent) in edges {
164 graph.edges.push(ChannelEdge {
165 channel_id: channel.to_proto(),
166 parent_id: parent.to_proto(),
167 })
168 }
169
170 graph
171}