1use anyhow::Context;
2use async_std::task::{block_on, yield_now};
3use serde::Serialize;
4use sqlx::{types::Uuid, FromRow, Result};
5use time::OffsetDateTime;
6
7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9
/// Evaluates a block of async database code, choosing how it is driven based on
/// `$db.test_mode`.
///
/// * In test mode, the current task yields once (`yield_now().await`) and the
///   body is then run to completion with a nested `block_on`.
/// * Otherwise, the body is simply `.await`ed on the current executor.
///
/// NOTE(review): presumably the test-mode path keeps real Postgres I/O from
/// being scheduled on a deterministic test executor; confirm against the test
/// harness before relying on this.
macro_rules! test_support {
    ($db:ident, { $($stmt:tt)* }) => {{
        let fut = async { $($stmt)* };
        match $db.test_mode {
            true => {
                yield_now().await;
                block_on(fut)
            }
            false => fut.await,
        }
    }};
}
23
24#[derive(Clone)]
25pub struct Db {
26 pool: sqlx::PgPool,
27 test_mode: bool,
28}
29
30impl Db {
31 pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
32 let pool = DbOptions::new()
33 .max_connections(max_connections)
34 .connect(url)
35 .await
36 .context("failed to connect to postgres database")?;
37 Ok(Self {
38 pool,
39 test_mode: false,
40 })
41 }
42
43 // signups
44
45 pub async fn create_signup(
46 &self,
47 github_login: &str,
48 email_address: &str,
49 about: &str,
50 ) -> Result<SignupId> {
51 test_support!(self, {
52 let query = "
53 INSERT INTO signups (github_login, email_address, about)
54 VALUES ($1, $2, $3)
55 RETURNING id
56 ";
57 sqlx::query_scalar(query)
58 .bind(github_login)
59 .bind(email_address)
60 .bind(about)
61 .fetch_one(&self.pool)
62 .await
63 .map(SignupId)
64 })
65 }
66
67 pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
68 test_support!(self, {
69 let query = "SELECT * FROM signups ORDER BY github_login ASC";
70 sqlx::query_as(query).fetch_all(&self.pool).await
71 })
72 }
73
74 pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
75 test_support!(self, {
76 let query = "DELETE FROM signups WHERE id = $1";
77 sqlx::query(query)
78 .bind(id.0)
79 .execute(&self.pool)
80 .await
81 .map(drop)
82 })
83 }
84
85 // users
86
87 pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
88 test_support!(self, {
89 let query = "
90 INSERT INTO users (github_login, admin)
91 VALUES ($1, $2)
92 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
93 RETURNING id
94 ";
95 sqlx::query_scalar(query)
96 .bind(github_login)
97 .bind(admin)
98 .fetch_one(&self.pool)
99 .await
100 .map(UserId)
101 })
102 }
103
104 pub async fn get_all_users(&self) -> Result<Vec<User>> {
105 test_support!(self, {
106 let query = "SELECT * FROM users ORDER BY github_login ASC";
107 sqlx::query_as(query).fetch_all(&self.pool).await
108 })
109 }
110
111 pub async fn get_users_by_ids(
112 &self,
113 requester_id: UserId,
114 ids: impl Iterator<Item = UserId>,
115 ) -> Result<Vec<User>> {
116 let mut include_requester = false;
117 let ids = ids
118 .map(|id| {
119 if id == requester_id {
120 include_requester = true;
121 }
122 id.0
123 })
124 .collect::<Vec<_>>();
125
126 test_support!(self, {
127 // Only return users that are in a common channel with the requesting user.
128 // Also allow the requesting user to return their own data, even if they aren't
129 // in any channels.
130 let query = "
131 SELECT
132 users.*
133 FROM
134 users, channel_memberships
135 WHERE
136 users.id = ANY ($1) AND
137 channel_memberships.user_id = users.id AND
138 channel_memberships.channel_id IN (
139 SELECT channel_id
140 FROM channel_memberships
141 WHERE channel_memberships.user_id = $2
142 )
143 UNION
144 SELECT
145 users.*
146 FROM
147 users
148 WHERE
149 $3 AND users.id = $2
150 ";
151
152 sqlx::query_as(query)
153 .bind(&ids)
154 .bind(requester_id)
155 .bind(include_requester)
156 .fetch_all(&self.pool)
157 .await
158 })
159 }
160
161 pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
162 test_support!(self, {
163 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
164 sqlx::query_as(query)
165 .bind(github_login)
166 .fetch_optional(&self.pool)
167 .await
168 })
169 }
170
171 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
172 test_support!(self, {
173 let query = "UPDATE users SET admin = $1 WHERE id = $2";
174 sqlx::query(query)
175 .bind(is_admin)
176 .bind(id.0)
177 .execute(&self.pool)
178 .await
179 .map(drop)
180 })
181 }
182
183 pub async fn delete_user(&self, id: UserId) -> Result<()> {
184 test_support!(self, {
185 let query = "DELETE FROM users WHERE id = $1;";
186 sqlx::query(query)
187 .bind(id.0)
188 .execute(&self.pool)
189 .await
190 .map(drop)
191 })
192 }
193
194 // access tokens
195
196 pub async fn create_access_token_hash(
197 &self,
198 user_id: UserId,
199 access_token_hash: String,
200 ) -> Result<()> {
201 test_support!(self, {
202 let query = "
203 INSERT INTO access_tokens (user_id, hash)
204 VALUES ($1, $2)
205 ";
206 sqlx::query(query)
207 .bind(user_id.0)
208 .bind(access_token_hash)
209 .execute(&self.pool)
210 .await
211 .map(drop)
212 })
213 }
214
215 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
216 test_support!(self, {
217 let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
218 sqlx::query_scalar(query)
219 .bind(user_id.0)
220 .fetch_all(&self.pool)
221 .await
222 })
223 }
224
225 // orgs
226
227 #[allow(unused)] // Help rust-analyzer
228 #[cfg(any(test, feature = "seed-support"))]
229 pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
230 test_support!(self, {
231 let query = "
232 SELECT *
233 FROM orgs
234 WHERE slug = $1
235 ";
236 sqlx::query_as(query)
237 .bind(slug)
238 .fetch_optional(&self.pool)
239 .await
240 })
241 }
242
243 #[cfg(any(test, feature = "seed-support"))]
244 pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
245 test_support!(self, {
246 let query = "
247 INSERT INTO orgs (name, slug)
248 VALUES ($1, $2)
249 RETURNING id
250 ";
251 sqlx::query_scalar(query)
252 .bind(name)
253 .bind(slug)
254 .fetch_one(&self.pool)
255 .await
256 .map(OrgId)
257 })
258 }
259
260 #[cfg(any(test, feature = "seed-support"))]
261 pub async fn add_org_member(
262 &self,
263 org_id: OrgId,
264 user_id: UserId,
265 is_admin: bool,
266 ) -> Result<()> {
267 test_support!(self, {
268 let query = "
269 INSERT INTO org_memberships (org_id, user_id, admin)
270 VALUES ($1, $2, $3)
271 ON CONFLICT DO NOTHING
272 ";
273 sqlx::query(query)
274 .bind(org_id.0)
275 .bind(user_id.0)
276 .bind(is_admin)
277 .execute(&self.pool)
278 .await
279 .map(drop)
280 })
281 }
282
283 // channels
284
285 #[cfg(any(test, feature = "seed-support"))]
286 pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
287 test_support!(self, {
288 let query = "
289 INSERT INTO channels (owner_id, owner_is_user, name)
290 VALUES ($1, false, $2)
291 RETURNING id
292 ";
293 sqlx::query_scalar(query)
294 .bind(org_id.0)
295 .bind(name)
296 .fetch_one(&self.pool)
297 .await
298 .map(ChannelId)
299 })
300 }
301
302 #[allow(unused)] // Help rust-analyzer
303 #[cfg(any(test, feature = "seed-support"))]
304 pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
305 test_support!(self, {
306 let query = "
307 SELECT *
308 FROM channels
309 WHERE
310 channels.owner_is_user = false AND
311 channels.owner_id = $1
312 ";
313 sqlx::query_as(query)
314 .bind(org_id.0)
315 .fetch_all(&self.pool)
316 .await
317 })
318 }
319
320 pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
321 test_support!(self, {
322 let query = "
323 SELECT
324 channels.id, channels.name
325 FROM
326 channel_memberships, channels
327 WHERE
328 channel_memberships.user_id = $1 AND
329 channel_memberships.channel_id = channels.id
330 ";
331 sqlx::query_as(query)
332 .bind(user_id.0)
333 .fetch_all(&self.pool)
334 .await
335 })
336 }
337
338 pub async fn can_user_access_channel(
339 &self,
340 user_id: UserId,
341 channel_id: ChannelId,
342 ) -> Result<bool> {
343 test_support!(self, {
344 let query = "
345 SELECT id
346 FROM channel_memberships
347 WHERE user_id = $1 AND channel_id = $2
348 LIMIT 1
349 ";
350 sqlx::query_scalar::<_, i32>(query)
351 .bind(user_id.0)
352 .bind(channel_id.0)
353 .fetch_optional(&self.pool)
354 .await
355 .map(|e| e.is_some())
356 })
357 }
358
359 #[cfg(any(test, feature = "seed-support"))]
360 pub async fn add_channel_member(
361 &self,
362 channel_id: ChannelId,
363 user_id: UserId,
364 is_admin: bool,
365 ) -> Result<()> {
366 test_support!(self, {
367 let query = "
368 INSERT INTO channel_memberships (channel_id, user_id, admin)
369 VALUES ($1, $2, $3)
370 ON CONFLICT DO NOTHING
371 ";
372 sqlx::query(query)
373 .bind(channel_id.0)
374 .bind(user_id.0)
375 .bind(is_admin)
376 .execute(&self.pool)
377 .await
378 .map(drop)
379 })
380 }
381
382 // messages
383
384 pub async fn create_channel_message(
385 &self,
386 channel_id: ChannelId,
387 sender_id: UserId,
388 body: &str,
389 timestamp: OffsetDateTime,
390 nonce: u128,
391 ) -> Result<MessageId> {
392 test_support!(self, {
393 let query = "
394 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
395 VALUES ($1, $2, $3, $4, $5)
396 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
397 RETURNING id
398 ";
399 sqlx::query_scalar(query)
400 .bind(channel_id.0)
401 .bind(sender_id.0)
402 .bind(body)
403 .bind(timestamp)
404 .bind(Uuid::from_u128(nonce))
405 .fetch_one(&self.pool)
406 .await
407 .map(MessageId)
408 })
409 }
410
411 pub async fn get_channel_messages(
412 &self,
413 channel_id: ChannelId,
414 count: usize,
415 before_id: Option<MessageId>,
416 ) -> Result<Vec<ChannelMessage>> {
417 test_support!(self, {
418 let query = r#"
419 SELECT * FROM (
420 SELECT
421 id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
422 FROM
423 channel_messages
424 WHERE
425 channel_id = $1 AND
426 id < $2
427 ORDER BY id DESC
428 LIMIT $3
429 ) as recent_messages
430 ORDER BY id ASC
431 "#;
432 sqlx::query_as(query)
433 .bind(channel_id.0)
434 .bind(before_id.unwrap_or(MessageId::MAX))
435 .bind(count as i64)
436 .fetch_all(&self.pool)
437 .await
438 })
439 }
440}
441
/// Declares `pub struct $ty(pub i32)`, a strongly typed database id.
///
/// The generated type is `Copy`, `Eq` and `Hash`. It is encoded and decoded by
/// sqlx and serialized by serde exactly like its inner `i32` (both derives are
/// `transparent`). It also gets helpers for converting to and from a `u64`
/// representation; given the names, presumably that is the protocol message
/// form (TODO confirm).
macro_rules! id_type {
    ($ty:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $ty(pub i32);

        impl $ty {
            /// The id wrapping `i32::MAX`, the largest representable value.
            #[allow(unused)]
            pub const MAX: Self = $ty(i32::MAX);

            /// Builds an id from its `u64` form.
            ///
            /// This is an `as` cast: values above `i32::MAX` are truncated to
            /// the low 32 bits rather than rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $ty(raw)
            }

            /// Converts the id to its `u64` form.
            ///
            /// This is an `as` cast: a negative inner value sign-extends to a
            /// very large `u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let raw: i32 = self.0;
                raw as u64
            }
        }
    };
}
465
466id_type!(UserId);
467#[derive(Debug, FromRow, Serialize, PartialEq)]
468pub struct User {
469 pub id: UserId,
470 pub github_login: String,
471 pub admin: bool,
472}
473
474id_type!(OrgId);
475#[derive(FromRow)]
476pub struct Org {
477 pub id: OrgId,
478 pub name: String,
479 pub slug: String,
480}
481
482id_type!(SignupId);
483#[derive(Debug, FromRow, Serialize)]
484pub struct Signup {
485 pub id: SignupId,
486 pub github_login: String,
487 pub email_address: String,
488 pub about: String,
489}
490
491id_type!(ChannelId);
492#[derive(Debug, FromRow, Serialize)]
493pub struct Channel {
494 pub id: ChannelId,
495 pub name: String,
496}
497
498id_type!(MessageId);
499#[derive(Debug, FromRow)]
500pub struct ChannelMessage {
501 pub id: MessageId,
502 pub sender_id: UserId,
503 pub body: String,
504 pub sent_at: OffsetDateTime,
505 pub nonce: Uuid,
506}
507
#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

    /// A freshly created, fully migrated Postgres database for a single test.
    ///
    /// The database is terminated and dropped when the `TestDb` is dropped.
    pub struct TestDb {
        /// Handle to the test database, with `test_mode` enabled.
        pub db: Db,
        /// Randomly generated database name (`zed-test-<u128>`).
        pub name: String,
        /// Connection URL for the database.
        pub url: String,
    }

    impl TestDb {
        /// Creates a uniquely named database on `postgres://postgres@localhost`,
        /// runs this crate's `migrations` directory against it, and returns a
        /// `Db` handle in test mode.
        ///
        /// # Panics
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub fn new() -> Self {
            // Enable tests to run in parallel by serializing the creation of each test database.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                {
                    // The guard only spans `create_database`; it is released
                    // before connecting and migrating.
                    let _lock = DB_CREATION.lock();
                    Postgres::create_database(&url)
                        .await
                        .expect("failed to create test db");
                }
                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });

            Self { db, name, url }
        }

        /// Returns the database handle.
        pub fn db(&self) -> &Db {
            &self.db
        }
    }

    impl Drop for TestDb {
        fn drop(&mut self) {
            block_on(async {
                // Kill every other session connected to this test database so
                // that `drop_database` is not blocked by lingering connections.
                // The query runs on the test database's own pool, so
                // `current_database()` names exactly that database; a quoted
                // literal here would match nothing.
                let query = "
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
                ";
                sqlx::query(query)
                    .execute(&self.db.pool)
                    .await
                    .unwrap();
                self.db.pool.close().await;
                Postgres::drop_database(&self.url).await.unwrap();
            });
        }
    }

    #[gpui::test]
    async fn test_get_users_by_ids() {
        let test_db = TestDb::new();
        let db = test_db.db();

        let user = db.create_user("user", false).await.unwrap();
        let friend1 = db.create_user("friend-1", false).await.unwrap();
        let friend2 = db.create_user("friend-2", false).await.unwrap();
        let friend3 = db.create_user("friend-3", false).await.unwrap();
        let stranger = db.create_user("stranger", false).await.unwrap();

        // A user can read their own info, even if they aren't in any channels.
        assert_eq!(
            db.get_users_by_ids(
                user,
                [user, friend1, friend2, friend3, stranger].iter().copied()
            )
            .await
            .unwrap(),
            vec![User {
                id: user,
                github_login: "user".to_string(),
                admin: false,
            },],
        );

        // A user can read the info of any other user who is in a shared channel
        // with them.
        let org = db.create_org("test org", "test-org").await.unwrap();
        let chan1 = db.create_org_channel(org, "channel-1").await.unwrap();
        let chan2 = db.create_org_channel(org, "channel-2").await.unwrap();
        let chan3 = db.create_org_channel(org, "channel-3").await.unwrap();

        db.add_channel_member(chan1, user, false).await.unwrap();
        db.add_channel_member(chan2, user, false).await.unwrap();
        db.add_channel_member(chan1, friend1, false).await.unwrap();
        db.add_channel_member(chan1, friend2, false).await.unwrap();
        db.add_channel_member(chan2, friend2, false).await.unwrap();
        db.add_channel_member(chan2, friend3, false).await.unwrap();
        db.add_channel_member(chan3, stranger, false).await.unwrap();

        assert_eq!(
            db.get_users_by_ids(
                user,
                [user, friend1, friend2, friend3, stranger].iter().copied()
            )
            .await
            .unwrap(),
            vec![
                User {
                    id: user,
                    github_login: "user".to_string(),
                    admin: false,
                },
                User {
                    id: friend1,
                    github_login: "friend-1".to_string(),
                    admin: false,
                },
                User {
                    id: friend2,
                    github_login: "friend-2".to_string(),
                    admin: false,
                },
                User {
                    id: friend3,
                    github_login: "friend-3".to_string(),
                    admin: false,
                }
            ]
        );

        // The user's own info is only returned if they request it.
        assert_eq!(
            db.get_users_by_ids(user, [friend1].iter().copied())
                .await
                .unwrap(),
            vec![User {
                id: friend1,
                github_login: "friend-1".to_string(),
                admin: false,
            },]
        )
    }

    #[gpui::test]
    async fn test_recent_channel_messages() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();
        for i in 0..10 {
            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
                .await
                .unwrap();
        }

        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
        assert_eq!(
            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["5", "6", "7", "8", "9"]
        );

        let prev_messages = db
            .get_channel_messages(channel, 4, Some(messages[0].id))
            .await
            .unwrap();
        assert_eq!(
            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["1", "2", "3", "4"]
        );
    }

    #[gpui::test]
    async fn test_channel_message_nonces() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();

        let msg1_id = db
            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg2_id = db
            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();
        let msg3_id = db
            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg4_id = db
            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();

        assert_ne!(msg1_id, msg2_id);
        assert_eq!(msg1_id, msg3_id);
        assert_eq!(msg2_id, msg4_id);
    }
}