1use anyhow::Context;
2use async_std::task::{block_on, yield_now};
3use serde::Serialize;
4use sqlx::{FromRow, Result};
5use time::OffsetDateTime;
6
7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9
/// Runs the query code in `{ ... }` on behalf of `$self` (a `Db`), choosing how
/// the resulting future is driven based on `$self.test_mode`.
///
/// - Normal mode: the body is simply `.await`ed.
/// - Test mode: the current task first yields once via `yield_now().await`, and
///   the body is then driven to completion with a blocking `block_on` call
///   instead of being awaited.
///
/// NOTE(review): presumably the test-mode path keeps real database I/O from
/// becoming an await point inside a deterministic test executor, so the
/// executor's scheduling is not perturbed by Postgres latency — TODO confirm.
///
/// The expansion evaluates to whatever the body evaluates to. Because it contains
/// `.await`, it can only be used inside an `async` context. `$self` must be an
/// identifier (in this file, always `self`).
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        // Build the future up front so both branches drive the same body.
        let body = async {
            $($token)*
        };
        if $self.test_mode {
            // Yield once before blocking so other tasks get a turn first.
            yield_now().await;
            block_on(body)
        } else {
            body.await
        }
    }};
}
23
24pub struct Db {
25 db: sqlx::PgPool,
26 test_mode: bool,
27}
28
/// A row of the `users` table.
///
/// Field order is significant: it determines the derived `Debug` and `Serialize`
/// output.
#[derive(Debug, FromRow, Serialize)]
pub struct User {
    /// Primary key, returned by `create_user`.
    pub id: UserId,
    /// GitHub login; `get_user_by_github_login` looks users up by this column.
    pub github_login: String,
    /// Whether the user is an admin; toggled by `set_user_is_admin`.
    pub admin: bool,
}
35
/// A signup record, as inserted into the `signups` table by `create_signup`.
///
/// Field order is significant: it determines the derived `Debug` and `Serialize`
/// output.
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
    /// Primary key of the `signups` row.
    pub id: SignupId,
    /// GitHub login supplied with the signup.
    pub github_login: String,
    /// Email address supplied with the signup.
    pub email_address: String,
    /// Free-form text supplied with the signup.
    pub about: String,
}
43
/// A channel's id and name, as selected from `channels` by
/// `get_channels_for_user`.
#[derive(Debug, FromRow, Serialize)]
pub struct Channel {
    /// Primary key of the `channels` row.
    pub id: ChannelId,
    /// Channel name.
    pub name: String,
}
49
/// A message read from the `channel_messages` table.
#[derive(Debug, FromRow)]
pub struct ChannelMessage {
    /// Primary key of the `channel_messages` row.
    pub id: MessageId,
    /// Author of the message.
    pub sender_id: UserId,
    /// Message text.
    pub body: String,
    /// Send time. `get_recent_channel_messages` selects it as
    /// `sent_at AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
}
57
58impl Db {
59 pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
60 let db = DbOptions::new()
61 .max_connections(max_connections)
62 .connect(url)
63 .await
64 .context("failed to connect to postgres database")?;
65 Ok(Self {
66 db,
67 test_mode: false,
68 })
69 }
70
71 // signups
72
73 pub async fn create_signup(
74 &self,
75 github_login: &str,
76 email_address: &str,
77 about: &str,
78 ) -> Result<SignupId> {
79 test_support!(self, {
80 let query = "
81 INSERT INTO signups (github_login, email_address, about)
82 VALUES ($1, $2, $3)
83 RETURNING id
84 ";
85 sqlx::query_scalar(query)
86 .bind(github_login)
87 .bind(email_address)
88 .bind(about)
89 .fetch_one(&self.db)
90 .await
91 .map(SignupId)
92 })
93 }
94
95 pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
96 test_support!(self, {
97 let query = "SELECT * FROM users ORDER BY github_login ASC";
98 sqlx::query_as(query).fetch_all(&self.db).await
99 })
100 }
101
102 pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
103 test_support!(self, {
104 let query = "DELETE FROM signups WHERE id = $1";
105 sqlx::query(query)
106 .bind(id.0)
107 .execute(&self.db)
108 .await
109 .map(drop)
110 })
111 }
112
113 // users
114
115 pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
116 test_support!(self, {
117 let query = "
118 INSERT INTO users (github_login, admin)
119 VALUES ($1, $2)
120 RETURNING id
121 ";
122 sqlx::query_scalar(query)
123 .bind(github_login)
124 .bind(admin)
125 .fetch_one(&self.db)
126 .await
127 .map(UserId)
128 })
129 }
130
131 pub async fn get_all_users(&self) -> Result<Vec<User>> {
132 test_support!(self, {
133 let query = "SELECT * FROM users ORDER BY github_login ASC";
134 sqlx::query_as(query).fetch_all(&self.db).await
135 })
136 }
137
138 pub async fn get_users_by_ids(
139 &self,
140 requester_id: UserId,
141 ids: impl Iterator<Item = UserId>,
142 ) -> Result<Vec<User>> {
143 test_support!(self, {
144 // Only return users that are in a common channel with the requesting user.
145 let query = "
146 SELECT users.*
147 FROM
148 users, channel_memberships
149 WHERE
150 users.id = ANY ($1) AND
151 channel_memberships.user_id = users.id AND
152 channel_memberships.channel_id IN (
153 SELECT channel_id
154 FROM channel_memberships
155 WHERE channel_memberships.user_id = $2
156 )
157 ";
158
159 sqlx::query_as(query)
160 .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
161 .bind(requester_id)
162 .fetch_all(&self.db)
163 .await
164 })
165 }
166
167 pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
168 test_support!(self, {
169 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
170 sqlx::query_as(query)
171 .bind(github_login)
172 .fetch_optional(&self.db)
173 .await
174 })
175 }
176
177 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
178 test_support!(self, {
179 let query = "UPDATE users SET admin = $1 WHERE id = $2";
180 sqlx::query(query)
181 .bind(is_admin)
182 .bind(id.0)
183 .execute(&self.db)
184 .await
185 .map(drop)
186 })
187 }
188
189 pub async fn delete_user(&self, id: UserId) -> Result<()> {
190 test_support!(self, {
191 let query = "DELETE FROM users WHERE id = $1;";
192 sqlx::query(query)
193 .bind(id.0)
194 .execute(&self.db)
195 .await
196 .map(drop)
197 })
198 }
199
200 // access tokens
201
202 pub async fn create_access_token_hash(
203 &self,
204 user_id: UserId,
205 access_token_hash: String,
206 ) -> Result<()> {
207 test_support!(self, {
208 let query = "
209 INSERT INTO access_tokens (user_id, hash)
210 VALUES ($1, $2)
211 ";
212 sqlx::query(query)
213 .bind(user_id.0)
214 .bind(access_token_hash)
215 .execute(&self.db)
216 .await
217 .map(drop)
218 })
219 }
220
221 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
222 test_support!(self, {
223 let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
224 sqlx::query_scalar(query)
225 .bind(user_id.0)
226 .fetch_all(&self.db)
227 .await
228 })
229 }
230
231 // orgs
232
233 #[cfg(test)]
234 pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
235 test_support!(self, {
236 let query = "
237 INSERT INTO orgs (name, slug)
238 VALUES ($1, $2)
239 RETURNING id
240 ";
241 sqlx::query_scalar(query)
242 .bind(name)
243 .bind(slug)
244 .fetch_one(&self.db)
245 .await
246 .map(OrgId)
247 })
248 }
249
250 #[cfg(test)]
251 pub async fn add_org_member(
252 &self,
253 org_id: OrgId,
254 user_id: UserId,
255 is_admin: bool,
256 ) -> Result<()> {
257 test_support!(self, {
258 let query = "
259 INSERT INTO org_memberships (org_id, user_id, admin)
260 VALUES ($1, $2, $3)
261 ";
262 sqlx::query(query)
263 .bind(org_id.0)
264 .bind(user_id.0)
265 .bind(is_admin)
266 .execute(&self.db)
267 .await
268 .map(drop)
269 })
270 }
271
272 // channels
273
274 #[cfg(test)]
275 pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
276 test_support!(self, {
277 let query = "
278 INSERT INTO channels (owner_id, owner_is_user, name)
279 VALUES ($1, false, $2)
280 RETURNING id
281 ";
282 sqlx::query_scalar(query)
283 .bind(org_id.0)
284 .bind(name)
285 .fetch_one(&self.db)
286 .await
287 .map(ChannelId)
288 })
289 }
290
291 pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
292 test_support!(self, {
293 let query = "
294 SELECT
295 channels.id, channels.name
296 FROM
297 channel_memberships, channels
298 WHERE
299 channel_memberships.user_id = $1 AND
300 channel_memberships.channel_id = channels.id
301 ";
302 sqlx::query_as(query)
303 .bind(user_id.0)
304 .fetch_all(&self.db)
305 .await
306 })
307 }
308
309 pub async fn can_user_access_channel(
310 &self,
311 user_id: UserId,
312 channel_id: ChannelId,
313 ) -> Result<bool> {
314 test_support!(self, {
315 let query = "
316 SELECT id
317 FROM channel_memberships
318 WHERE user_id = $1 AND channel_id = $2
319 LIMIT 1
320 ";
321 sqlx::query_scalar::<_, i32>(query)
322 .bind(user_id.0)
323 .bind(channel_id.0)
324 .fetch_optional(&self.db)
325 .await
326 .map(|e| e.is_some())
327 })
328 }
329
330 #[cfg(test)]
331 pub async fn add_channel_member(
332 &self,
333 channel_id: ChannelId,
334 user_id: UserId,
335 is_admin: bool,
336 ) -> Result<()> {
337 test_support!(self, {
338 let query = "
339 INSERT INTO channel_memberships (channel_id, user_id, admin)
340 VALUES ($1, $2, $3)
341 ";
342 sqlx::query(query)
343 .bind(channel_id.0)
344 .bind(user_id.0)
345 .bind(is_admin)
346 .execute(&self.db)
347 .await
348 .map(drop)
349 })
350 }
351
352 // messages
353
354 pub async fn create_channel_message(
355 &self,
356 channel_id: ChannelId,
357 sender_id: UserId,
358 body: &str,
359 timestamp: OffsetDateTime,
360 ) -> Result<MessageId> {
361 test_support!(self, {
362 let query = "
363 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
364 VALUES ($1, $2, $3, $4)
365 RETURNING id
366 ";
367 sqlx::query_scalar(query)
368 .bind(channel_id.0)
369 .bind(sender_id.0)
370 .bind(body)
371 .bind(timestamp)
372 .fetch_one(&self.db)
373 .await
374 .map(MessageId)
375 })
376 }
377
378 pub async fn get_recent_channel_messages(
379 &self,
380 channel_id: ChannelId,
381 count: usize,
382 ) -> Result<Vec<ChannelMessage>> {
383 test_support!(self, {
384 let query = r#"
385 SELECT
386 id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
387 FROM
388 channel_messages
389 WHERE
390 channel_id = $1
391 LIMIT $2
392 "#;
393 sqlx::query_as(query)
394 .bind(channel_id.0)
395 .bind(count as i64)
396 .fetch_all(&self.db)
397 .await
398 })
399 }
400
401 #[cfg(test)]
402 pub async fn close(&self, db_name: &str) {
403 test_support!(self, {
404 let query = "
405 SELECT pg_terminate_backend(pg_stat_activity.pid)
406 FROM pg_stat_activity
407 WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
408 ";
409 sqlx::query(query)
410 .bind(db_name)
411 .execute(&self.db)
412 .await
413 .unwrap();
414 self.db.close().await;
415 })
416 }
417}
418
419macro_rules! id_type {
420 ($name:ident) => {
421 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
422 #[sqlx(transparent)]
423 #[serde(transparent)]
424 pub struct $name(pub i32);
425
426 impl $name {
427 #[allow(unused)]
428 pub fn from_proto(value: u64) -> Self {
429 Self(value as i32)
430 }
431
432 #[allow(unused)]
433 pub fn to_proto(&self) -> u64 {
434 self.0 as u64
435 }
436 }
437 };
438}
439
440id_type!(UserId);
441id_type!(OrgId);
442id_type!(ChannelId);
443id_type!(SignupId);
444id_type!(MessageId);
445
446#[cfg(test)]
447pub mod tests {
448 use super::*;
449 use rand::prelude::*;
450 use sqlx::{
451 migrate::{MigrateDatabase, Migrator},
452 Postgres,
453 };
454 use std::path::Path;
455
456 pub struct TestDb {
457 pub name: String,
458 pub url: String,
459 }
460
461 impl TestDb {
462 pub fn new() -> (Self, Db) {
463 // Enable tests to run in parallel by serializing the creation of each test database.
464 lazy_static::lazy_static! {
465 static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
466 }
467
468 let mut rng = StdRng::from_entropy();
469 let name = format!("zed-test-{}", rng.gen::<u128>());
470 let url = format!("postgres://postgres@localhost/{}", name);
471 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
472 let db = block_on(async {
473 {
474 let _lock = DB_CREATION.lock();
475 Postgres::create_database(&url)
476 .await
477 .expect("failed to create test db");
478 }
479 let mut db = Db::new(&url, 5).await.unwrap();
480 db.test_mode = true;
481 let migrator = Migrator::new(migrations_path).await.unwrap();
482 migrator.run(&db.db).await.unwrap();
483 db
484 });
485
486 (Self { name, url }, db)
487 }
488 }
489
490 impl Drop for TestDb {
491 fn drop(&mut self) {
492 block_on(Postgres::drop_database(&self.url)).unwrap();
493 }
494 }
495}