1use anyhow::Context;
2use async_std::task::{block_on, yield_now};
3use serde::Serialize;
4use sqlx::{FromRow, Result};
5use time::OffsetDateTime;
6
7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9
/// Wraps the body of a `Db` query method so that it can be driven in two ways.
///
/// The tokens inside `{ ... }` are placed in an `async` block. When
/// `$self.test_mode` is `false`, that block is simply awaited. When it is `true`,
/// the macro first awaits `yield_now()` and then runs the block to completion
/// synchronously with `block_on`.
///
/// Both branches use `.await`, so the macro must be expanded in an async context.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };
        if $self.test_mode {
            // NOTE(review): presumably the yield + `block_on` pairing exists so queries
            // cooperate with the executor used by `#[gpui::test]` — confirm.
            yield_now().await;
            block_on(body)
        } else {
            body.await
        }
    }};
}
23
/// A Postgres connection pool together with a flag selecting how queries are driven.
#[derive(Clone)]
pub struct Db {
    /// Pool that every query in this `impl` is executed against.
    pool: sqlx::PgPool,
    /// When `true`, `test_support!` runs each query via `block_on` instead of awaiting it.
    test_mode: bool,
}
29
30impl Db {
31 pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
32 let pool = DbOptions::new()
33 .max_connections(max_connections)
34 .connect(url)
35 .await
36 .context("failed to connect to postgres database")?;
37 Ok(Self {
38 pool,
39 test_mode: false,
40 })
41 }
42
43 // signups
44
45 pub async fn create_signup(
46 &self,
47 github_login: &str,
48 email_address: &str,
49 about: &str,
50 ) -> Result<SignupId> {
51 test_support!(self, {
52 let query = "
53 INSERT INTO signups (github_login, email_address, about)
54 VALUES ($1, $2, $3)
55 RETURNING id
56 ";
57 sqlx::query_scalar(query)
58 .bind(github_login)
59 .bind(email_address)
60 .bind(about)
61 .fetch_one(&self.pool)
62 .await
63 .map(SignupId)
64 })
65 }
66
67 pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
68 test_support!(self, {
69 let query = "SELECT * FROM signups ORDER BY github_login ASC";
70 sqlx::query_as(query).fetch_all(&self.pool).await
71 })
72 }
73
74 pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
75 test_support!(self, {
76 let query = "DELETE FROM signups WHERE id = $1";
77 sqlx::query(query)
78 .bind(id.0)
79 .execute(&self.pool)
80 .await
81 .map(drop)
82 })
83 }
84
85 // users
86
87 #[allow(unused)] // Help rust-analyzer
88 #[cfg(any(test, feature = "seed-support"))]
89 pub async fn get_user(&self, github_login: &str) -> Result<Option<UserId>> {
90 test_support!(self, {
91 let query = "
92 SELECT id
93 FROM users
94 WHERE github_login = $1
95 ";
96 sqlx::query_scalar(query)
97 .bind(github_login)
98 .fetch_optional(&self.pool)
99 .await
100 })
101 }
102
103 pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
104 test_support!(self, {
105 let query = "
106 INSERT INTO users (github_login, admin)
107 VALUES ($1, $2)
108 RETURNING id
109 ";
110 sqlx::query_scalar(query)
111 .bind(github_login)
112 .bind(admin)
113 .fetch_one(&self.pool)
114 .await
115 .map(UserId)
116 })
117 }
118
119 pub async fn get_all_users(&self) -> Result<Vec<User>> {
120 test_support!(self, {
121 let query = "SELECT * FROM users ORDER BY github_login ASC";
122 sqlx::query_as(query).fetch_all(&self.pool).await
123 })
124 }
125
126 pub async fn get_users_by_ids(
127 &self,
128 requester_id: UserId,
129 ids: impl Iterator<Item = UserId>,
130 ) -> Result<Vec<User>> {
131 let mut include_requester = false;
132 let ids = ids
133 .map(|id| {
134 if id == requester_id {
135 include_requester = true;
136 }
137 id.0
138 })
139 .collect::<Vec<_>>();
140
141 test_support!(self, {
142 // Only return users that are in a common channel with the requesting user.
143 // Also allow the requesting user to return their own data, even if they aren't
144 // in any channels.
145 let query = "
146 SELECT
147 users.*
148 FROM
149 users, channel_memberships
150 WHERE
151 users.id = ANY ($1) AND
152 channel_memberships.user_id = users.id AND
153 channel_memberships.channel_id IN (
154 SELECT channel_id
155 FROM channel_memberships
156 WHERE channel_memberships.user_id = $2
157 )
158 UNION
159 SELECT
160 users.*
161 FROM
162 users
163 WHERE
164 $3 AND users.id = $2
165 ";
166
167 sqlx::query_as(query)
168 .bind(&ids)
169 .bind(requester_id)
170 .bind(include_requester)
171 .fetch_all(&self.pool)
172 .await
173 })
174 }
175
176 pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
177 test_support!(self, {
178 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
179 sqlx::query_as(query)
180 .bind(github_login)
181 .fetch_optional(&self.pool)
182 .await
183 })
184 }
185
186 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
187 test_support!(self, {
188 let query = "UPDATE users SET admin = $1 WHERE id = $2";
189 sqlx::query(query)
190 .bind(is_admin)
191 .bind(id.0)
192 .execute(&self.pool)
193 .await
194 .map(drop)
195 })
196 }
197
198 pub async fn delete_user(&self, id: UserId) -> Result<()> {
199 test_support!(self, {
200 let query = "DELETE FROM users WHERE id = $1;";
201 sqlx::query(query)
202 .bind(id.0)
203 .execute(&self.pool)
204 .await
205 .map(drop)
206 })
207 }
208
209 // access tokens
210
211 pub async fn create_access_token_hash(
212 &self,
213 user_id: UserId,
214 access_token_hash: String,
215 ) -> Result<()> {
216 test_support!(self, {
217 let query = "
218 INSERT INTO access_tokens (user_id, hash)
219 VALUES ($1, $2)
220 ";
221 sqlx::query(query)
222 .bind(user_id.0)
223 .bind(access_token_hash)
224 .execute(&self.pool)
225 .await
226 .map(drop)
227 })
228 }
229
230 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
231 test_support!(self, {
232 let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
233 sqlx::query_scalar(query)
234 .bind(user_id.0)
235 .fetch_all(&self.pool)
236 .await
237 })
238 }
239
240 // orgs
241
242 #[allow(unused)] // Help rust-analyzer
243 #[cfg(any(test, feature = "seed-support"))]
244 pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
245 test_support!(self, {
246 let query = "
247 SELECT *
248 FROM orgs
249 WHERE slug = $1
250 ";
251 sqlx::query_as(query)
252 .bind(slug)
253 .fetch_optional(&self.pool)
254 .await
255 })
256 }
257
258 #[cfg(any(test, feature = "seed-support"))]
259 pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
260 test_support!(self, {
261 let query = "
262 INSERT INTO orgs (name, slug)
263 VALUES ($1, $2)
264 RETURNING id
265 ";
266 sqlx::query_scalar(query)
267 .bind(name)
268 .bind(slug)
269 .fetch_one(&self.pool)
270 .await
271 .map(OrgId)
272 })
273 }
274
275 #[cfg(any(test, feature = "seed-support"))]
276 pub async fn add_org_member(
277 &self,
278 org_id: OrgId,
279 user_id: UserId,
280 is_admin: bool,
281 ) -> Result<()> {
282 test_support!(self, {
283 let query = "
284 INSERT INTO org_memberships (org_id, user_id, admin)
285 VALUES ($1, $2, $3)
286 ON CONFLICT DO NOTHING
287 ";
288 sqlx::query(query)
289 .bind(org_id.0)
290 .bind(user_id.0)
291 .bind(is_admin)
292 .execute(&self.pool)
293 .await
294 .map(drop)
295 })
296 }
297
298 // channels
299
300 #[cfg(any(test, feature = "seed-support"))]
301 pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
302 test_support!(self, {
303 let query = "
304 INSERT INTO channels (owner_id, owner_is_user, name)
305 VALUES ($1, false, $2)
306 RETURNING id
307 ";
308 sqlx::query_scalar(query)
309 .bind(org_id.0)
310 .bind(name)
311 .fetch_one(&self.pool)
312 .await
313 .map(ChannelId)
314 })
315 }
316
317 #[allow(unused)] // Help rust-analyzer
318 #[cfg(any(test, feature = "seed-support"))]
319 pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
320 test_support!(self, {
321 let query = "
322 SELECT *
323 FROM channels
324 WHERE
325 channels.owner_is_user = false AND
326 channels.owner_id = $1
327 ";
328 sqlx::query_as(query)
329 .bind(org_id.0)
330 .fetch_all(&self.pool)
331 .await
332 })
333 }
334
335 pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
336 test_support!(self, {
337 let query = "
338 SELECT
339 channels.id, channels.name
340 FROM
341 channel_memberships, channels
342 WHERE
343 channel_memberships.user_id = $1 AND
344 channel_memberships.channel_id = channels.id
345 ";
346 sqlx::query_as(query)
347 .bind(user_id.0)
348 .fetch_all(&self.pool)
349 .await
350 })
351 }
352
353 pub async fn can_user_access_channel(
354 &self,
355 user_id: UserId,
356 channel_id: ChannelId,
357 ) -> Result<bool> {
358 test_support!(self, {
359 let query = "
360 SELECT id
361 FROM channel_memberships
362 WHERE user_id = $1 AND channel_id = $2
363 LIMIT 1
364 ";
365 sqlx::query_scalar::<_, i32>(query)
366 .bind(user_id.0)
367 .bind(channel_id.0)
368 .fetch_optional(&self.pool)
369 .await
370 .map(|e| e.is_some())
371 })
372 }
373
374 #[cfg(any(test, feature = "seed-support"))]
375 pub async fn add_channel_member(
376 &self,
377 channel_id: ChannelId,
378 user_id: UserId,
379 is_admin: bool,
380 ) -> Result<()> {
381 test_support!(self, {
382 let query = "
383 INSERT INTO channel_memberships (channel_id, user_id, admin)
384 VALUES ($1, $2, $3)
385 ON CONFLICT DO NOTHING
386 ";
387 sqlx::query(query)
388 .bind(channel_id.0)
389 .bind(user_id.0)
390 .bind(is_admin)
391 .execute(&self.pool)
392 .await
393 .map(drop)
394 })
395 }
396
397 // messages
398
399 pub async fn create_channel_message(
400 &self,
401 channel_id: ChannelId,
402 sender_id: UserId,
403 body: &str,
404 timestamp: OffsetDateTime,
405 ) -> Result<MessageId> {
406 test_support!(self, {
407 let query = "
408 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
409 VALUES ($1, $2, $3, $4)
410 RETURNING id
411 ";
412 sqlx::query_scalar(query)
413 .bind(channel_id.0)
414 .bind(sender_id.0)
415 .bind(body)
416 .bind(timestamp)
417 .fetch_one(&self.pool)
418 .await
419 .map(MessageId)
420 })
421 }
422
423 pub async fn get_channel_messages(
424 &self,
425 channel_id: ChannelId,
426 count: usize,
427 before_id: Option<MessageId>,
428 ) -> Result<Vec<ChannelMessage>> {
429 test_support!(self, {
430 let query = r#"
431 SELECT * FROM (
432 SELECT
433 id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
434 FROM
435 channel_messages
436 WHERE
437 channel_id = $1 AND
438 id < $2
439 ORDER BY id DESC
440 LIMIT $3
441 ) as recent_messages
442 ORDER BY id ASC
443 "#;
444 sqlx::query_as(query)
445 .bind(channel_id.0)
446 .bind(before_id.unwrap_or(MessageId::MAX))
447 .bind(count as i64)
448 .fetch_all(&self.pool)
449 .await
450 })
451 }
452}
453
/// Declares a public newtype `$name` wrapping an `i32`.
///
/// The generated type derives `sqlx::Type` and `Serialize`, both marked
/// `transparent`, and gets a `MAX` constant plus `from_proto` / `to_proto`
/// conversions from and to `u64`.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from a `u64`.
            ///
            /// Uses an `as` cast, so a value that does not fit in an `i32` is
            /// truncated rather than rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts the id to a `u64`.
            ///
            /// Uses an `as` cast, so a negative id sign-extends to a large `u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }
    };
}
477
id_type!(UserId);
/// A row of the `users` table, as loaded by the user queries in `Db`.
#[derive(Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    /// Id of the user row.
    pub id: UserId,
    /// Value of the `github_login` column.
    pub github_login: String,
    /// Value of the `admin` column.
    pub admin: bool,
}
485
486id_type!(OrgId);
487#[derive(FromRow)]
488pub struct Org {
489 pub id: OrgId,
490 pub name: String,
491 pub slug: String,
492}
493
id_type!(SignupId);
/// A row of the `signups` table, as loaded by `Db::get_all_signups`.
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
    /// Id of the signup row.
    pub id: SignupId,
    /// Value of the `github_login` column.
    pub github_login: String,
    /// Value of the `email_address` column.
    pub email_address: String,
    /// Value of the `about` column.
    pub about: String,
}
502
id_type!(ChannelId);
/// The `id` and `name` of a `channels` row, as loaded by the channel queries in `Db`.
#[derive(Debug, FromRow, Serialize)]
pub struct Channel {
    /// Id of the channel row.
    pub id: ChannelId,
    /// Value of the `name` column.
    pub name: String,
}
509
id_type!(MessageId);
/// A message row as selected by `Db::get_channel_messages`.
#[derive(Debug, FromRow)]
pub struct ChannelMessage {
    /// Id of the message row.
    pub id: MessageId,
    /// Id of the user stored in `sender_id`.
    pub sender_id: UserId,
    /// Value of the `body` column.
    pub body: String,
    /// Send time; the query selects it as `sent_at AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
}
518
519#[cfg(test)]
520pub mod tests {
521 use super::*;
522 use rand::prelude::*;
523 use sqlx::{
524 migrate::{MigrateDatabase, Migrator},
525 Postgres,
526 };
527 use std::path::Path;
528
    /// A throwaway Postgres database for one test; its `Drop` impl drops the database.
    pub struct TestDb {
        /// Handle connected to the test database, with `test_mode` enabled.
        pub db: Db,
        /// Name of the test database (`zed-test-<random u128>`).
        pub name: String,
        /// Connection URL of the test database.
        pub url: String,
    }
534
535 impl TestDb {
536 pub fn new() -> Self {
537 // Enable tests to run in parallel by serializing the creation of each test database.
538 lazy_static::lazy_static! {
539 static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
540 }
541
542 let mut rng = StdRng::from_entropy();
543 let name = format!("zed-test-{}", rng.gen::<u128>());
544 let url = format!("postgres://postgres@localhost/{}", name);
545 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
546 let db = block_on(async {
547 {
548 let _lock = DB_CREATION.lock();
549 Postgres::create_database(&url)
550 .await
551 .expect("failed to create test db");
552 }
553 let mut db = Db::new(&url, 5).await.unwrap();
554 db.test_mode = true;
555 let migrator = Migrator::new(migrations_path).await.unwrap();
556 migrator.run(&db.pool).await.unwrap();
557 db
558 });
559
560 Self { db, name, url }
561 }
562
563 pub fn db(&self) -> &Db {
564 &self.db
565 }
566 }
567
568 impl Drop for TestDb {
569 fn drop(&mut self) {
570 block_on(async {
571 let query = "
572 SELECT pg_terminate_backend(pg_stat_activity.pid)
573 FROM pg_stat_activity
574 WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
575 ";
576 sqlx::query(query)
577 .bind(&self.name)
578 .execute(&self.db.pool)
579 .await
580 .unwrap();
581 self.db.pool.close().await;
582 Postgres::drop_database(&self.url).await.unwrap();
583 });
584 }
585 }
586
587 #[gpui::test]
588 async fn test_get_users_by_ids() {
589 let test_db = TestDb::new();
590 let db = test_db.db();
591
592 let user = db.create_user("user", false).await.unwrap();
593 let friend1 = db.create_user("friend-1", false).await.unwrap();
594 let friend2 = db.create_user("friend-2", false).await.unwrap();
595 let friend3 = db.create_user("friend-3", false).await.unwrap();
596 let stranger = db.create_user("stranger", false).await.unwrap();
597
598 // A user can read their own info, even if they aren't in any channels.
599 assert_eq!(
600 db.get_users_by_ids(
601 user,
602 [user, friend1, friend2, friend3, stranger].iter().copied()
603 )
604 .await
605 .unwrap(),
606 vec![User {
607 id: user,
608 github_login: "user".to_string(),
609 admin: false,
610 },],
611 );
612
613 // A user can read the info of any other user who is in a shared channel
614 // with them.
615 let org = db.create_org("test org", "test-org").await.unwrap();
616 let chan1 = db.create_org_channel(org, "channel-1").await.unwrap();
617 let chan2 = db.create_org_channel(org, "channel-2").await.unwrap();
618 let chan3 = db.create_org_channel(org, "channel-3").await.unwrap();
619
620 db.add_channel_member(chan1, user, false).await.unwrap();
621 db.add_channel_member(chan2, user, false).await.unwrap();
622 db.add_channel_member(chan1, friend1, false).await.unwrap();
623 db.add_channel_member(chan1, friend2, false).await.unwrap();
624 db.add_channel_member(chan2, friend2, false).await.unwrap();
625 db.add_channel_member(chan2, friend3, false).await.unwrap();
626 db.add_channel_member(chan3, stranger, false).await.unwrap();
627
628 assert_eq!(
629 db.get_users_by_ids(
630 user,
631 [user, friend1, friend2, friend3, stranger].iter().copied()
632 )
633 .await
634 .unwrap(),
635 vec![
636 User {
637 id: user,
638 github_login: "user".to_string(),
639 admin: false,
640 },
641 User {
642 id: friend1,
643 github_login: "friend-1".to_string(),
644 admin: false,
645 },
646 User {
647 id: friend2,
648 github_login: "friend-2".to_string(),
649 admin: false,
650 },
651 User {
652 id: friend3,
653 github_login: "friend-3".to_string(),
654 admin: false,
655 }
656 ]
657 );
658
659 // The user's own info is only returned if they request it.
660 assert_eq!(
661 db.get_users_by_ids(user, [friend1].iter().copied())
662 .await
663 .unwrap(),
664 vec![User {
665 id: friend1,
666 github_login: "friend-1".to_string(),
667 admin: false,
668 },]
669 )
670 }
671
672 #[gpui::test]
673 async fn test_recent_channel_messages() {
674 let test_db = TestDb::new();
675 let db = test_db.db();
676 let user = db.create_user("user", false).await.unwrap();
677 let org = db.create_org("org", "org").await.unwrap();
678 let channel = db.create_org_channel(org, "channel").await.unwrap();
679 for i in 0..10 {
680 db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc())
681 .await
682 .unwrap();
683 }
684
685 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
686 assert_eq!(
687 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
688 ["5", "6", "7", "8", "9"]
689 );
690
691 let prev_messages = db
692 .get_channel_messages(channel, 4, Some(messages[0].id))
693 .await
694 .unwrap();
695 assert_eq!(
696 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
697 ["1", "2", "3", "4"]
698 );
699 }
700}