1use anyhow::Context;
2use async_std::task::{block_on, yield_now};
3use serde::Serialize;
4use sqlx::{types::Uuid, FromRow, Result};
5use time::OffsetDateTime;
6
7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9
/// Runs a block of statements as an `async` block, choosing how it is driven
/// based on `$self.test_mode`.
///
/// `$self` must be an identifier bound to a value with a `test_mode: bool`
/// field (in this file, a `Db`). The braces hold the body of an `async` block.
/// The macro evaluates to that block's output, so it must be used inside an
/// `async` context.
///
/// - `test_mode == false`: the block is simply `.await`ed.
/// - `test_mode == true`: the macro first yields once to the current executor,
///   then drives the block to completion with `block_on`, blocking the calling
///   thread until it finishes. NOTE(review): this presumably keeps real
///   database I/O out of the deterministic `#[gpui::test]` executor; confirm.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        // Build the future up front; nothing runs until it is polled.
        let body = async {
            $($token)*
        };
        if $self.test_mode {
            // Give the executor a scheduling point before blocking the thread.
            yield_now().await;
            block_on(body)
        } else {
            body.await
        }
    }};
}
23
/// Handle to the application's Postgres database; every query method lives on
/// this type.
///
/// NOTE(review): `Clone` clones the inner `sqlx::PgPool`. The sqlx docs
/// describe pools as reference-counted handles, so clones should share the same
/// connections; confirm against the sqlx version in use.
#[derive(Clone)]
pub struct Db {
    /// Connection pool that all queries run against.
    pool: sqlx::PgPool,
    /// When `true`, `test_support!` drives each query with `block_on` instead of
    /// awaiting it. `Db::new` sets it to `false`; the test harness turns it on.
    test_mode: bool,
}
29
30impl Db {
31 pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
32 let pool = DbOptions::new()
33 .max_connections(max_connections)
34 .connect(url)
35 .await
36 .context("failed to connect to postgres database")?;
37 Ok(Self {
38 pool,
39 test_mode: false,
40 })
41 }
42
43 // signups
44
45 pub async fn create_signup(
46 &self,
47 github_login: &str,
48 email_address: &str,
49 about: &str,
50 ) -> Result<SignupId> {
51 test_support!(self, {
52 let query = "
53 INSERT INTO signups (github_login, email_address, about)
54 VALUES ($1, $2, $3)
55 RETURNING id
56 ";
57 sqlx::query_scalar(query)
58 .bind(github_login)
59 .bind(email_address)
60 .bind(about)
61 .fetch_one(&self.pool)
62 .await
63 .map(SignupId)
64 })
65 }
66
67 pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
68 test_support!(self, {
69 let query = "SELECT * FROM signups ORDER BY github_login ASC";
70 sqlx::query_as(query).fetch_all(&self.pool).await
71 })
72 }
73
74 pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
75 test_support!(self, {
76 let query = "DELETE FROM signups WHERE id = $1";
77 sqlx::query(query)
78 .bind(id.0)
79 .execute(&self.pool)
80 .await
81 .map(drop)
82 })
83 }
84
85 // users
86
87 #[allow(unused)] // Help rust-analyzer
88 #[cfg(any(test, feature = "seed-support"))]
89 pub async fn get_user(&self, github_login: &str) -> Result<Option<UserId>> {
90 test_support!(self, {
91 let query = "
92 SELECT id
93 FROM users
94 WHERE github_login = $1
95 ";
96 sqlx::query_scalar(query)
97 .bind(github_login)
98 .fetch_optional(&self.pool)
99 .await
100 })
101 }
102
103 pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
104 test_support!(self, {
105 let query = "
106 INSERT INTO users (github_login, admin)
107 VALUES ($1, $2)
108 RETURNING id
109 ";
110 sqlx::query_scalar(query)
111 .bind(github_login)
112 .bind(admin)
113 .fetch_one(&self.pool)
114 .await
115 .map(UserId)
116 })
117 }
118
119 pub async fn get_all_users(&self) -> Result<Vec<User>> {
120 test_support!(self, {
121 let query = "SELECT * FROM users ORDER BY github_login ASC";
122 sqlx::query_as(query).fetch_all(&self.pool).await
123 })
124 }
125
126 pub async fn get_users_by_ids(
127 &self,
128 requester_id: UserId,
129 ids: impl Iterator<Item = UserId>,
130 ) -> Result<Vec<User>> {
131 let mut include_requester = false;
132 let ids = ids
133 .map(|id| {
134 if id == requester_id {
135 include_requester = true;
136 }
137 id.0
138 })
139 .collect::<Vec<_>>();
140
141 test_support!(self, {
142 // Only return users that are in a common channel with the requesting user.
143 // Also allow the requesting user to return their own data, even if they aren't
144 // in any channels.
145 let query = "
146 SELECT
147 users.*
148 FROM
149 users, channel_memberships
150 WHERE
151 users.id = ANY ($1) AND
152 channel_memberships.user_id = users.id AND
153 channel_memberships.channel_id IN (
154 SELECT channel_id
155 FROM channel_memberships
156 WHERE channel_memberships.user_id = $2
157 )
158 UNION
159 SELECT
160 users.*
161 FROM
162 users
163 WHERE
164 $3 AND users.id = $2
165 ";
166
167 sqlx::query_as(query)
168 .bind(&ids)
169 .bind(requester_id)
170 .bind(include_requester)
171 .fetch_all(&self.pool)
172 .await
173 })
174 }
175
176 pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
177 test_support!(self, {
178 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
179 sqlx::query_as(query)
180 .bind(github_login)
181 .fetch_optional(&self.pool)
182 .await
183 })
184 }
185
186 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
187 test_support!(self, {
188 let query = "UPDATE users SET admin = $1 WHERE id = $2";
189 sqlx::query(query)
190 .bind(is_admin)
191 .bind(id.0)
192 .execute(&self.pool)
193 .await
194 .map(drop)
195 })
196 }
197
198 pub async fn delete_user(&self, id: UserId) -> Result<()> {
199 test_support!(self, {
200 let query = "DELETE FROM users WHERE id = $1;";
201 sqlx::query(query)
202 .bind(id.0)
203 .execute(&self.pool)
204 .await
205 .map(drop)
206 })
207 }
208
209 // access tokens
210
211 pub async fn create_access_token_hash(
212 &self,
213 user_id: UserId,
214 access_token_hash: String,
215 ) -> Result<()> {
216 test_support!(self, {
217 let query = "
218 INSERT INTO access_tokens (user_id, hash)
219 VALUES ($1, $2)
220 ";
221 sqlx::query(query)
222 .bind(user_id.0)
223 .bind(access_token_hash)
224 .execute(&self.pool)
225 .await
226 .map(drop)
227 })
228 }
229
230 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
231 test_support!(self, {
232 let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
233 sqlx::query_scalar(query)
234 .bind(user_id.0)
235 .fetch_all(&self.pool)
236 .await
237 })
238 }
239
240 // orgs
241
242 #[allow(unused)] // Help rust-analyzer
243 #[cfg(any(test, feature = "seed-support"))]
244 pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
245 test_support!(self, {
246 let query = "
247 SELECT *
248 FROM orgs
249 WHERE slug = $1
250 ";
251 sqlx::query_as(query)
252 .bind(slug)
253 .fetch_optional(&self.pool)
254 .await
255 })
256 }
257
258 #[cfg(any(test, feature = "seed-support"))]
259 pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
260 test_support!(self, {
261 let query = "
262 INSERT INTO orgs (name, slug)
263 VALUES ($1, $2)
264 RETURNING id
265 ";
266 sqlx::query_scalar(query)
267 .bind(name)
268 .bind(slug)
269 .fetch_one(&self.pool)
270 .await
271 .map(OrgId)
272 })
273 }
274
275 #[cfg(any(test, feature = "seed-support"))]
276 pub async fn add_org_member(
277 &self,
278 org_id: OrgId,
279 user_id: UserId,
280 is_admin: bool,
281 ) -> Result<()> {
282 test_support!(self, {
283 let query = "
284 INSERT INTO org_memberships (org_id, user_id, admin)
285 VALUES ($1, $2, $3)
286 ON CONFLICT DO NOTHING
287 ";
288 sqlx::query(query)
289 .bind(org_id.0)
290 .bind(user_id.0)
291 .bind(is_admin)
292 .execute(&self.pool)
293 .await
294 .map(drop)
295 })
296 }
297
298 // channels
299
300 #[cfg(any(test, feature = "seed-support"))]
301 pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
302 test_support!(self, {
303 let query = "
304 INSERT INTO channels (owner_id, owner_is_user, name)
305 VALUES ($1, false, $2)
306 RETURNING id
307 ";
308 sqlx::query_scalar(query)
309 .bind(org_id.0)
310 .bind(name)
311 .fetch_one(&self.pool)
312 .await
313 .map(ChannelId)
314 })
315 }
316
317 #[allow(unused)] // Help rust-analyzer
318 #[cfg(any(test, feature = "seed-support"))]
319 pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
320 test_support!(self, {
321 let query = "
322 SELECT *
323 FROM channels
324 WHERE
325 channels.owner_is_user = false AND
326 channels.owner_id = $1
327 ";
328 sqlx::query_as(query)
329 .bind(org_id.0)
330 .fetch_all(&self.pool)
331 .await
332 })
333 }
334
335 pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
336 test_support!(self, {
337 let query = "
338 SELECT
339 channels.id, channels.name
340 FROM
341 channel_memberships, channels
342 WHERE
343 channel_memberships.user_id = $1 AND
344 channel_memberships.channel_id = channels.id
345 ";
346 sqlx::query_as(query)
347 .bind(user_id.0)
348 .fetch_all(&self.pool)
349 .await
350 })
351 }
352
353 pub async fn can_user_access_channel(
354 &self,
355 user_id: UserId,
356 channel_id: ChannelId,
357 ) -> Result<bool> {
358 test_support!(self, {
359 let query = "
360 SELECT id
361 FROM channel_memberships
362 WHERE user_id = $1 AND channel_id = $2
363 LIMIT 1
364 ";
365 sqlx::query_scalar::<_, i32>(query)
366 .bind(user_id.0)
367 .bind(channel_id.0)
368 .fetch_optional(&self.pool)
369 .await
370 .map(|e| e.is_some())
371 })
372 }
373
374 #[cfg(any(test, feature = "seed-support"))]
375 pub async fn add_channel_member(
376 &self,
377 channel_id: ChannelId,
378 user_id: UserId,
379 is_admin: bool,
380 ) -> Result<()> {
381 test_support!(self, {
382 let query = "
383 INSERT INTO channel_memberships (channel_id, user_id, admin)
384 VALUES ($1, $2, $3)
385 ON CONFLICT DO NOTHING
386 ";
387 sqlx::query(query)
388 .bind(channel_id.0)
389 .bind(user_id.0)
390 .bind(is_admin)
391 .execute(&self.pool)
392 .await
393 .map(drop)
394 })
395 }
396
397 // messages
398
399 pub async fn create_channel_message(
400 &self,
401 channel_id: ChannelId,
402 sender_id: UserId,
403 body: &str,
404 timestamp: OffsetDateTime,
405 nonce: u128,
406 ) -> Result<MessageId> {
407 test_support!(self, {
408 let query = "
409 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
410 VALUES ($1, $2, $3, $4, $5)
411 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
412 RETURNING id
413 ";
414 sqlx::query_scalar(query)
415 .bind(channel_id.0)
416 .bind(sender_id.0)
417 .bind(body)
418 .bind(timestamp)
419 .bind(Uuid::from_u128(nonce))
420 .fetch_one(&self.pool)
421 .await
422 .map(MessageId)
423 })
424 }
425
426 pub async fn get_channel_messages(
427 &self,
428 channel_id: ChannelId,
429 count: usize,
430 before_id: Option<MessageId>,
431 ) -> Result<Vec<ChannelMessage>> {
432 test_support!(self, {
433 let query = r#"
434 SELECT * FROM (
435 SELECT
436 id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
437 FROM
438 channel_messages
439 WHERE
440 channel_id = $1 AND
441 id < $2
442 ORDER BY id DESC
443 LIMIT $3
444 ) as recent_messages
445 ORDER BY id ASC
446 "#;
447 sqlx::query_as(query)
448 .bind(channel_id.0)
449 .bind(before_id.unwrap_or(MessageId::MAX))
450 .bind(count as i64)
451 .fetch_all(&self.pool)
452 .await
453 })
454 }
455}
456
/// Defines a `Copy` newtype around a raw `i32` database id.
///
/// The generated type is `transparent` for both sqlx and serde, so it is bound,
/// decoded and serialized exactly like the bare `i32`. It also gets a `MAX`
/// constant and lossy conversions to and from a `u64` representation
/// (presumably the one used by the wire protocol).
macro_rules! id_type {
    ($id_ty:ident) => {
        /// Strongly-typed wrapper around a raw `i32` row id.
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $id_ty(pub i32);

        impl $id_ty {
            /// The largest representable id (usable as an "everything before"
            /// cursor, as `get_channel_messages` does for `MessageId`).
            #[allow(unused)]
            pub const MAX: Self = $id_ty(i32::MAX);

            /// Builds an id from its `u64` form.
            ///
            /// The `as` cast keeps only the low 32 bits, so values above
            /// `i32::MAX` wrap instead of being rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $id_ty(raw)
            }

            /// Converts the id to its `u64` form.
            ///
            /// The `as` cast sign-extends, so a negative id becomes a very large
            /// `u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let Self(raw) = *self;
                raw as u64
            }
        }
    };
}
480
// Id of a row in the `users` table.
id_type!(UserId);
/// A row of the `users` table (as returned by `get_all_users`,
/// `get_users_by_ids` and `get_user_by_github_login`).
#[derive(Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    /// Primary key of the row.
    pub id: UserId,
    /// GitHub login; users can be looked up by it.
    pub github_login: String,
    /// Administrator flag, settable via `set_user_is_admin`.
    pub admin: bool,
}
488
489id_type!(OrgId);
490#[derive(FromRow)]
491pub struct Org {
492 pub id: OrgId,
493 pub name: String,
494 pub slug: String,
495}
496
// Id of a row in the `signups` table.
id_type!(SignupId);
/// A row of the `signups` table (as returned by `get_all_signups`).
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
    /// Primary key of the row.
    pub id: SignupId,
    /// GitHub login supplied with the signup.
    pub github_login: String,
    /// Email address supplied with the signup.
    pub email_address: String,
    /// The `about` text stored by `create_signup`.
    pub about: String,
}
505
// Id of a row in the `channels` table.
id_type!(ChannelId);
/// Id and name of a channel (as returned by `get_accessible_channels` and
/// `get_org_channels`).
#[derive(Debug, FromRow, Serialize)]
pub struct Channel {
    /// Primary key of the channel row.
    pub id: ChannelId,
    /// Channel name given at creation.
    pub name: String,
}
512
// Id of a row in the `channel_messages` table.
id_type!(MessageId);
/// A message as selected by `get_channel_messages`. The query does not select
/// the channel id, so there is no field for it.
#[derive(Debug, FromRow)]
pub struct ChannelMessage {
    /// Primary key of the message row; also the pagination cursor.
    pub id: MessageId,
    /// Author of the message.
    pub sender_id: UserId,
    /// Message text.
    pub body: String,
    /// Send time; the query converts it with `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied de-duplication key (the `u128` nonce passed to
    /// `create_channel_message`, stored as a UUID).
    pub nonce: Uuid,
}
522
#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

    /// A throwaway database created for a single test, with all migrations applied.
    ///
    /// Dropping it disconnects any other sessions and deletes the database.
    pub struct TestDb {
        pub db: Db,
        pub name: String,
        pub url: String,
    }

    impl TestDb {
        /// Creates a randomly named database on the local Postgres server, runs the
        /// crate's `migrations` directory against it, and returns a `Db` in test mode.
        ///
        /// # Panics
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub fn new() -> Self {
            // Enable tests to run in parallel by serializing the creation of each test database.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                {
                    // Binding the `Result` keeps the guard alive (a poisoned lock still
                    // carries it) until the end of this scope.
                    let _lock = DB_CREATION.lock();
                    Postgres::create_database(&url)
                        .await
                        .expect("failed to create test db");
                }
                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });

            Self { db, name, url }
        }

        /// Returns the database handle under test.
        pub fn db(&self) -> &Db {
            &self.db
        }
    }

    impl Drop for TestDb {
        fn drop(&mut self) {
            block_on(async {
                // Kick every other session off the test database so it can be dropped.
                // The pool is connected to the test database, so `current_database()`
                // names it directly; no parameter or string splicing is needed.
                let query = "
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
                ";
                sqlx::query(query)
                    .execute(&self.db.pool)
                    .await
                    .unwrap();
                self.db.pool.close().await;
                Postgres::drop_database(&self.url).await.unwrap();
            });
        }
    }

    #[gpui::test]
    async fn test_get_users_by_ids() {
        let test_db = TestDb::new();
        let db = test_db.db();

        let user = db.create_user("user", false).await.unwrap();
        let friend1 = db.create_user("friend-1", false).await.unwrap();
        let friend2 = db.create_user("friend-2", false).await.unwrap();
        let friend3 = db.create_user("friend-3", false).await.unwrap();
        let stranger = db.create_user("stranger", false).await.unwrap();

        // A user can read their own info, even if they aren't in any channels.
        assert_eq!(
            db.get_users_by_ids(
                user,
                [user, friend1, friend2, friend3, stranger].iter().copied()
            )
            .await
            .unwrap(),
            vec![User {
                id: user,
                github_login: "user".to_string(),
                admin: false,
            }],
        );

        // A user can read the info of any other user who is in a shared channel
        // with them.
        let org = db.create_org("test org", "test-org").await.unwrap();
        let chan1 = db.create_org_channel(org, "channel-1").await.unwrap();
        let chan2 = db.create_org_channel(org, "channel-2").await.unwrap();
        let chan3 = db.create_org_channel(org, "channel-3").await.unwrap();

        db.add_channel_member(chan1, user, false).await.unwrap();
        db.add_channel_member(chan2, user, false).await.unwrap();
        db.add_channel_member(chan1, friend1, false).await.unwrap();
        db.add_channel_member(chan1, friend2, false).await.unwrap();
        db.add_channel_member(chan2, friend2, false).await.unwrap();
        db.add_channel_member(chan2, friend3, false).await.unwrap();
        db.add_channel_member(chan3, stranger, false).await.unwrap();

        assert_eq!(
            db.get_users_by_ids(
                user,
                [user, friend1, friend2, friend3, stranger].iter().copied()
            )
            .await
            .unwrap(),
            vec![
                User {
                    id: user,
                    github_login: "user".to_string(),
                    admin: false,
                },
                User {
                    id: friend1,
                    github_login: "friend-1".to_string(),
                    admin: false,
                },
                User {
                    id: friend2,
                    github_login: "friend-2".to_string(),
                    admin: false,
                },
                User {
                    id: friend3,
                    github_login: "friend-3".to_string(),
                    admin: false,
                }
            ]
        );

        // The user's own info is only returned if they request it.
        assert_eq!(
            db.get_users_by_ids(user, [friend1].iter().copied())
                .await
                .unwrap(),
            vec![User {
                id: friend1,
                github_login: "friend-1".to_string(),
                admin: false,
            }]
        )
    }

    #[gpui::test]
    async fn test_recent_channel_messages() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();
        for i in 0..10 {
            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
                .await
                .unwrap();
        }

        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
        assert_eq!(
            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["5", "6", "7", "8", "9"]
        );

        let prev_messages = db
            .get_channel_messages(channel, 4, Some(messages[0].id))
            .await
            .unwrap();
        assert_eq!(
            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["1", "2", "3", "4"]
        );
    }

    #[gpui::test]
    async fn test_channel_message_nonces() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();

        let msg1_id = db
            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg2_id = db
            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();
        let msg3_id = db
            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg4_id = db
            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();

        // Reusing a nonce returns the id of the message first stored with it.
        assert_ne!(msg1_id, msg2_id);
        assert_eq!(msg1_id, msg3_id);
        assert_eq!(msg2_id, msg4_id);
    }
}