1mod buffer_tests;
2mod channel_tests;
3mod db_tests;
4mod feature_flag_tests;
5mod message_tests;
6
7use super::*;
8use gpui::executor::Background;
9use parking_lot::Mutex;
10use rpc::proto::ChannelEdge;
11use sea_orm::ConnectionTrait;
12use sqlx::migrate::MigrateDatabase;
13use std::sync::Arc;
14
/// Release-channel string shared by the db tests.
///
/// NOTE(review): presumably passed wherever the db APIs expect a release
/// channel name — confirm at the call sites in the child test modules.
const TEST_RELEASE_CHANNEL: &str = "test";
16
17pub struct TestDb {
18 pub db: Option<Arc<Database>>,
19 pub connection: Option<sqlx::AnyConnection>,
20}
21
22impl TestDb {
23 pub fn sqlite(background: Arc<Background>) -> Self {
24 let url = format!("sqlite::memory:");
25 let runtime = tokio::runtime::Builder::new_current_thread()
26 .enable_io()
27 .enable_time()
28 .build()
29 .unwrap();
30
31 let mut db = runtime.block_on(async {
32 let mut options = ConnectOptions::new(url);
33 options.max_connections(5);
34 let db = Database::new(options, Executor::Deterministic(background))
35 .await
36 .unwrap();
37 let sql = include_str!(concat!(
38 env!("CARGO_MANIFEST_DIR"),
39 "/migrations.sqlite/20221109000000_test_schema.sql"
40 ));
41 db.pool
42 .execute(sea_orm::Statement::from_string(
43 db.pool.get_database_backend(),
44 sql,
45 ))
46 .await
47 .unwrap();
48 db
49 });
50
51 db.runtime = Some(runtime);
52
53 Self {
54 db: Some(Arc::new(db)),
55 connection: None,
56 }
57 }
58
59 pub fn postgres(background: Arc<Background>) -> Self {
60 static LOCK: Mutex<()> = Mutex::new(());
61
62 let _guard = LOCK.lock();
63 let mut rng = StdRng::from_entropy();
64 let url = format!(
65 "postgres://postgres@localhost/zed-test-{}",
66 rng.gen::<u128>()
67 );
68 let runtime = tokio::runtime::Builder::new_current_thread()
69 .enable_io()
70 .enable_time()
71 .build()
72 .unwrap();
73
74 let mut db = runtime.block_on(async {
75 sqlx::Postgres::create_database(&url)
76 .await
77 .expect("failed to create test db");
78 let mut options = ConnectOptions::new(url);
79 options
80 .max_connections(5)
81 .idle_timeout(Duration::from_secs(0));
82 let db = Database::new(options, Executor::Deterministic(background))
83 .await
84 .unwrap();
85 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
86 db.migrate(Path::new(migrations_path), false).await.unwrap();
87 db
88 });
89
90 db.runtime = Some(runtime);
91
92 Self {
93 db: Some(Arc::new(db)),
94 connection: None,
95 }
96 }
97
98 pub fn db(&self) -> &Arc<Database> {
99 self.db.as_ref().unwrap()
100 }
101}
102
/// Generates two `#[gpui::test]` functions that run the same async test body,
/// once against Postgres and once against SQLite.
///
/// `$test_name` must be an async fn that accepts the `&Arc<Database>` returned
/// by `TestDb::db`. `$postgres_test_name` and `$sqlite_test_name` name the
/// generated tests. Each generated test builds its own `TestDb` on a background
/// executor from `gpui::executor::Deterministic::new(0)`.
///
/// `$crate` (rather than `crate`) is used so the path resolves to this crate
/// wherever the exported macro is expanded.
#[macro_export]
macro_rules! test_both_dbs {
    ($test_name:ident, $postgres_test_name:ident, $sqlite_test_name:ident) => {
        #[gpui::test]
        async fn $postgres_test_name() {
            let test_db = $crate::db::TestDb::postgres(
                gpui::executor::Deterministic::new(0).build_background(),
            );
            $test_name(test_db.db()).await;
        }

        #[gpui::test]
        async fn $sqlite_test_name() {
            let test_db =
                $crate::db::TestDb::sqlite(gpui::executor::Deterministic::new(0).build_background());
            $test_name(test_db.db()).await;
        }
    };
}
122
123impl Drop for TestDb {
124 fn drop(&mut self) {
125 let db = self.db.take().unwrap();
126 if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
127 db.runtime.as_ref().unwrap().block_on(async {
128 use util::ResultExt;
129 let query = "
130 SELECT pg_terminate_backend(pg_stat_activity.pid)
131 FROM pg_stat_activity
132 WHERE
133 pg_stat_activity.datname = current_database() AND
134 pid <> pg_backend_pid();
135 ";
136 db.pool
137 .execute(sea_orm::Statement::from_string(
138 db.pool.get_database_backend(),
139 query,
140 ))
141 .await
142 .log_err();
143 sqlx::Postgres::drop_database(db.options.get_url())
144 .await
145 .log_err();
146 })
147 }
148 }
149}
150
151/// The second tuples are (channel_id, parent)
152fn graph(channels: &[(ChannelId, &'static str)], edges: &[(ChannelId, ChannelId)]) -> ChannelGraph {
153 let mut graph = ChannelGraph {
154 channels: vec![],
155 edges: vec![],
156 };
157
158 for (id, name) in channels {
159 graph.channels.push(Channel {
160 id: *id,
161 name: name.to_string(),
162 })
163 }
164
165 for (channel, parent) in edges {
166 graph.edges.push(ChannelEdge {
167 channel_id: channel.to_proto(),
168 parent_id: parent.to_proto(),
169 })
170 }
171
172 graph
173}