use anyhow::Context;
use async_std::task::{block_on, yield_now};
use serde::Serialize;
use sqlx::{types::Uuid, FromRow, Result};
use time::OffsetDateTime;

pub use async_sqlx_session::PostgresSessionStore as SessionStore;
pub use sqlx::postgres::PgPoolOptions as DbOptions;

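/// Wraps the body of each `Db` query method so the same code can run in
/// production and in tests. The body is built as an async block; in test mode it
/// is driven to completion synchronously with `block_on` after yielding once to
/// the surrounding executor, otherwise it is awaited normally. The yield appears
/// intended to let other tasks make progress before the thread blocks.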
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };
        if $self.test_mode {
            yield_now().await;
            block_on(body)
        } else {
            body.await
        }
    }};
}

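/// Handle to the Postgres database, backed by a cloneable connection pool.
/// `test_mode` is set only by the test harness and switches `test_support!`
/// to its blocking execution path.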
#[derive(Clone)]
pub struct Db {
    pool: sqlx::PgPool,
    test_mode: bool,
}

impl Db {
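    /// Connects to the database at `url` with a pool of at most `max_connections`
    /// connections.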
    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
        let pool = DbOptions::new()
            .max_connections(max_connections)
            .connect(url)
            .await
            .context("failed to connect to postgres database")?;
        Ok(Self {
            pool,
            test_mode: false,
        })
    }

    // signups

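    /// Records a signup request and returns the id of the newly inserted row.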
    pub async fn create_signup(
        &self,
        github_login: &str,
        email_address: &str,
        about: &str,
        wants_releases: bool,
        wants_updates: bool,
        wants_community: bool,
    ) -> Result<SignupId> {
        test_support!(self, {
            let query = "
                INSERT INTO signups (
                    github_login,
                    email_address,
                    about,
                    wants_releases,
                    wants_updates,
                    wants_community
                )
                VALUES ($1, $2, $3, $4, $5, $6)
                RETURNING id
            ";
            sqlx::query_scalar(query)
                .bind(github_login)
                .bind(email_address)
                .bind(about)
                .bind(wants_releases)
                .bind(wants_updates)
                .bind(wants_community)
                .fetch_one(&self.pool)
                .await
                .map(SignupId)
        })
    }

    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
        test_support!(self, {
            let query = "SELECT * FROM signups ORDER BY github_login ASC";
            sqlx::query_as(query).fetch_all(&self.pool).await
        })
    }

    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
        test_support!(self, {
            let query = "DELETE FROM signups WHERE id = $1";
            sqlx::query(query)
                .bind(id.0)
                .execute(&self.pool)
                .await
                .map(drop)
        })
    }

    // users

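    /// Creates a user, or returns the existing user's id if one with the same
    /// `github_login` already exists. The conflicting insert is turned into a
    /// no-op update so that `RETURNING id` still produces a row; note that the
    /// `admin` flag is not updated for an existing user.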
    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
        test_support!(self, {
            let query = "
                INSERT INTO users (github_login, admin)
                VALUES ($1, $2)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id
            ";
            sqlx::query_scalar(query)
                .bind(github_login)
                .bind(admin)
                .fetch_one(&self.pool)
                .await
                .map(UserId)
        })
    }

    pub async fn get_all_users(&self) -> Result<Vec<User>> {
        test_support!(self, {
            let query = "SELECT * FROM users ORDER BY github_login ASC";
            sqlx::query_as(query).fetch_all(&self.pool).await
        })
    }

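    /// Fetches the users with the given ids. The ids are collected into a
    /// `Vec<i32>` up front so they can be bound as a single array parameter to
    /// `ANY ($1)`; ids that match no user are silently skipped.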
    pub async fn get_users_by_ids(
        &self,
        ids: impl IntoIterator<Item = UserId>,
    ) -> Result<Vec<User>> {
        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
        test_support!(self, {
            let query = "
                SELECT users.*
                FROM users
                WHERE users.id = ANY ($1)
            ";

            sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
        })
    }

    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
        test_support!(self, {
            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
            sqlx::query_as(query)
                .bind(github_login)
                .fetch_optional(&self.pool)
                .await
        })
    }

    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
        test_support!(self, {
            let query = "UPDATE users SET admin = $1 WHERE id = $2";
            sqlx::query(query)
                .bind(is_admin)
                .bind(id.0)
                .execute(&self.pool)
                .await
                .map(drop)
        })
    }

    pub async fn delete_user(&self, id: UserId) -> Result<()> {
        test_support!(self, {
            let query = "DELETE FROM users WHERE id = $1;";
            sqlx::query(query)
                .bind(id.0)
                .execute(&self.pool)
                .await
                .map(drop)
        })
    }

    // access tokens

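    /// Stores a new access token hash for `user_id` and, in the same transaction,
    /// deletes all but the `max_access_token_count` most recent hashes for that
    /// user. The cleanup statement only references `$1` and `$3`; the token hash
    /// is bound as `$2` but left unused there.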
    pub async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()> {
        test_support!(self, {
            let insert_query = "
                INSERT INTO access_tokens (user_id, hash)
                VALUES ($1, $2);
            ";
            let cleanup_query = "
                DELETE FROM access_tokens
                WHERE id IN (
                    SELECT id from access_tokens
                    WHERE user_id = $1
                    ORDER BY id DESC
                    OFFSET $3
                )
            ";

            let mut tx = self.pool.begin().await?;
            sqlx::query(insert_query)
                .bind(user_id.0)
                .bind(access_token_hash)
                .execute(&mut tx)
                .await?;
            sqlx::query(cleanup_query)
                .bind(user_id.0)
                .bind(access_token_hash)
                .bind(max_access_token_count as u32)
                .execute(&mut tx)
                .await?;
            tx.commit().await
        })
    }

    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
        test_support!(self, {
            let query = "
                SELECT hash
                FROM access_tokens
                WHERE user_id = $1
                ORDER BY id DESC
            ";
            sqlx::query_scalar(query)
                .bind(user_id.0)
                .fetch_all(&self.pool)
                .await
        })
    }

    // orgs

    #[allow(unused)] // Help rust-analyzer
    #[cfg(any(test, feature = "seed-support"))]
    pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
        test_support!(self, {
            let query = "
                SELECT *
                FROM orgs
                WHERE slug = $1
            ";
            sqlx::query_as(query)
                .bind(slug)
                .fetch_optional(&self.pool)
                .await
        })
    }

    #[cfg(any(test, feature = "seed-support"))]
    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
        test_support!(self, {
            let query = "
                INSERT INTO orgs (name, slug)
                VALUES ($1, $2)
                RETURNING id
            ";
            sqlx::query_scalar(query)
                .bind(name)
                .bind(slug)
                .fetch_one(&self.pool)
                .await
                .map(OrgId)
        })
    }

    #[cfg(any(test, feature = "seed-support"))]
    pub async fn add_org_member(
        &self,
        org_id: OrgId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()> {
        test_support!(self, {
            let query = "
                INSERT INTO org_memberships (org_id, user_id, admin)
                VALUES ($1, $2, $3)
                ON CONFLICT DO NOTHING
            ";
            sqlx::query(query)
                .bind(org_id.0)
                .bind(user_id.0)
                .bind(is_admin)
                .execute(&self.pool)
                .await
                .map(drop)
        })
    }

    // channels

    #[cfg(any(test, feature = "seed-support"))]
    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
        test_support!(self, {
            let query = "
                INSERT INTO channels (owner_id, owner_is_user, name)
                VALUES ($1, false, $2)
                RETURNING id
            ";
            sqlx::query_scalar(query)
                .bind(org_id.0)
                .bind(name)
                .fetch_one(&self.pool)
                .await
                .map(ChannelId)
        })
    }

    #[allow(unused)] // Help rust-analyzer
    #[cfg(any(test, feature = "seed-support"))]
    pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
        test_support!(self, {
            let query = "
                SELECT *
                FROM channels
                WHERE
                    channels.owner_is_user = false AND
                    channels.owner_id = $1
            ";
            sqlx::query_as(query)
                .bind(org_id.0)
                .fetch_all(&self.pool)
                .await
        })
    }

    pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
        test_support!(self, {
            let query = "
                SELECT
                    channels.id, channels.name
                FROM
                    channel_memberships, channels
                WHERE
                    channel_memberships.user_id = $1 AND
                    channel_memberships.channel_id = channels.id
            ";
            sqlx::query_as(query)
                .bind(user_id.0)
                .fetch_all(&self.pool)
                .await
        })
    }

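    /// Returns whether `user_id` has a membership row for `channel_id`, by
    /// fetching at most one membership id and mapping its presence to a boolean.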
    pub async fn can_user_access_channel(
        &self,
        user_id: UserId,
        channel_id: ChannelId,
    ) -> Result<bool> {
        test_support!(self, {
            let query = "
                SELECT id
                FROM channel_memberships
                WHERE user_id = $1 AND channel_id = $2
                LIMIT 1
            ";
            sqlx::query_scalar::<_, i32>(query)
                .bind(user_id.0)
                .bind(channel_id.0)
                .fetch_optional(&self.pool)
                .await
                .map(|e| e.is_some())
        })
    }

    #[cfg(any(test, feature = "seed-support"))]
    pub async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()> {
        test_support!(self, {
            let query = "
                INSERT INTO channel_memberships (channel_id, user_id, admin)
                VALUES ($1, $2, $3)
                ON CONFLICT DO NOTHING
            ";
            sqlx::query(query)
                .bind(channel_id.0)
                .bind(user_id.0)
                .bind(is_admin)
                .execute(&self.pool)
                .await
                .map(drop)
        })
    }

    // messages

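    /// Inserts a channel message. The client-supplied `nonce` makes the insert
    /// idempotent: on a nonce conflict the row is left effectively unchanged (the
    /// no-op update exists so that `RETURNING id` yields the existing message's id).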
    pub async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId> {
        test_support!(self, {
            let query = "
                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
                RETURNING id
            ";
            sqlx::query_scalar(query)
                .bind(channel_id.0)
                .bind(sender_id.0)
                .bind(body)
                .bind(timestamp)
                .bind(Uuid::from_u128(nonce))
                .fetch_one(&self.pool)
                .await
                .map(MessageId)
        })
    }

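    /// Returns up to `count` messages from `channel_id` with ids below `before_id`
    /// (or the newest messages when `before_id` is `None`). The inner query selects
    /// the most recent matching rows; the outer query re-sorts them into ascending
    /// id order.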
    pub async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>> {
        test_support!(self, {
            let query = r#"
                SELECT * FROM (
                    SELECT
                        id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
                    FROM
                        channel_messages
                    WHERE
                        channel_id = $1 AND
                        id < $2
                    ORDER BY id DESC
                    LIMIT $3
                ) as recent_messages
                ORDER BY id ASC
            "#;
            sqlx::query_as(query)
                .bind(channel_id.0)
                .bind(before_id.unwrap_or(MessageId::MAX))
                .bind(count as i64)
                .fetch_all(&self.pool)
                .await
        })
    }
}

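/// Declares a strongly typed wrapper around an `i32` database id. The wrapper is
/// transparent to both sqlx and serde, so it binds and serializes as a plain
/// integer while keeping the different id kinds distinct in Rust code.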
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }
    };
}

id_type!(UserId);
#[derive(Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub admin: bool,
}

id_type!(OrgId);
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}

id_type!(SignupId);
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
    pub id: SignupId,
    pub github_login: String,
    pub email_address: String,
    pub about: String,
    pub wants_releases: Option<bool>,
    pub wants_updates: Option<bool>,
    pub wants_community: Option<bool>,
}

id_type!(ChannelId);
#[derive(Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
}

id_type!(MessageId);
#[derive(Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub sender_id: UserId,
    pub body: String,
    pub sent_at: OffsetDateTime,
    pub nonce: Uuid,
}

#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

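    /// A uniquely named Postgres database created for a single test: `new` creates
    /// the database and runs the migrations, and `Drop` tears it down again.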
    pub struct TestDb {
        pub db: Db,
        pub name: String,
        pub url: String,
    }

    impl TestDb {
        pub fn new() -> Self {
            // Enable tests to run in parallel by serializing the creation of each test database.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                {
                    let _lock = DB_CREATION.lock();
                    Postgres::create_database(&url)
                        .await
                        .expect("failed to create test db");
                }
                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });

            Self { db, name, url }
        }

        pub fn db(&self) -> &Db {
            &self.db
        }
    }

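    /// Cleans up the test database: terminates any other backends still connected
    /// to it, closes the pool, and then drops the database itself.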
    impl Drop for TestDb {
        fn drop(&mut self) {
            block_on(async {
                let query = "
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = $1 AND pid <> pg_backend_pid();
                ";
                sqlx::query(query)
                    .bind(&self.name)
                    .execute(&self.db.pool)
                    .await
                    .unwrap();
                self.db.pool.close().await;
                Postgres::drop_database(&self.url).await.unwrap();
            });
        }
    }

    #[gpui::test]
    async fn test_get_users_by_ids() {
        let test_db = TestDb::new();
        let db = test_db.db();

        let user = db.create_user("user", false).await.unwrap();
        let friend1 = db.create_user("friend-1", false).await.unwrap();
        let friend2 = db.create_user("friend-2", false).await.unwrap();
        let friend3 = db.create_user("friend-3", false).await.unwrap();

        assert_eq!(
            db.get_users_by_ids([user, friend1, friend2, friend3])
                .await
                .unwrap(),
            vec![
                User {
                    id: user,
                    github_login: "user".to_string(),
                    admin: false,
                },
                User {
                    id: friend1,
                    github_login: "friend-1".to_string(),
                    admin: false,
                },
                User {
                    id: friend2,
                    github_login: "friend-2".to_string(),
                    admin: false,
                },
                User {
                    id: friend3,
                    github_login: "friend-3".to_string(),
                    admin: false,
                }
            ]
        );
    }

    #[gpui::test]
    async fn test_recent_channel_messages() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();
        for i in 0..10 {
            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
                .await
                .unwrap();
        }

        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
        assert_eq!(
            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["5", "6", "7", "8", "9"]
        );

        let prev_messages = db
            .get_channel_messages(channel, 4, Some(messages[0].id))
            .await
            .unwrap();
        assert_eq!(
            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["1", "2", "3", "4"]
        );
    }

    #[gpui::test]
    async fn test_channel_message_nonces() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();

        let msg1_id = db
            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg2_id = db
            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();
        let msg3_id = db
            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg4_id = db
            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();

        assert_ne!(msg1_id, msg2_id);
        assert_eq!(msg1_id, msg3_id);
        assert_eq!(msg2_id, msg4_id);
    }

    #[gpui::test]
    async fn test_create_access_tokens() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("the-user", false).await.unwrap();

        db.create_access_token_hash(user, "h1", 3).await.unwrap();
        db.create_access_token_hash(user, "h2", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h3", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h3".to_string(), "h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h4", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h4".to_string(), "h3".to_string(), "h2".to_string()]
        );

        db.create_access_token_hash(user, "h5", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
        );
    }
}