1use anyhow::Context;
2use async_std::task::{block_on, yield_now};
3use serde::Serialize;
4use sqlx::{types::Uuid, FromRow, Result};
5use time::OffsetDateTime;
6
7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9
/// Wraps a query body (given in braces) in an async block and runs it.
///
/// Outside of test mode the block is simply awaited. When `$self.test_mode` is
/// set, the macro yields to the executor once and then drives the block to
/// completion synchronously with `block_on`, so the query itself never suspends
/// the calling task.
/// NOTE(review): presumably this keeps real database I/O from interleaving with a
/// deterministic test scheduler — confirm intent.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let future = async { $($token)* };
        if !$self.test_mode {
            future.await
        } else {
            yield_now().await;
            block_on(future)
        }
    }};
}
23
24#[derive(Clone)]
25pub struct Db {
26 pool: sqlx::PgPool,
27 test_mode: bool,
28}
29
30impl Db {
31 pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
32 let pool = DbOptions::new()
33 .max_connections(max_connections)
34 .connect(url)
35 .await
36 .context("failed to connect to postgres database")?;
37 Ok(Self {
38 pool,
39 test_mode: false,
40 })
41 }
42
43 // signups
44
45 pub async fn create_signup(
46 &self,
47 github_login: &str,
48 email_address: &str,
49 about: &str,
50 ) -> Result<SignupId> {
51 test_support!(self, {
52 let query = "
53 INSERT INTO signups (github_login, email_address, about)
54 VALUES ($1, $2, $3)
55 RETURNING id
56 ";
57 sqlx::query_scalar(query)
58 .bind(github_login)
59 .bind(email_address)
60 .bind(about)
61 .fetch_one(&self.pool)
62 .await
63 .map(SignupId)
64 })
65 }
66
67 pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
68 test_support!(self, {
69 let query = "SELECT * FROM signups ORDER BY github_login ASC";
70 sqlx::query_as(query).fetch_all(&self.pool).await
71 })
72 }
73
74 pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
75 test_support!(self, {
76 let query = "DELETE FROM signups WHERE id = $1";
77 sqlx::query(query)
78 .bind(id.0)
79 .execute(&self.pool)
80 .await
81 .map(drop)
82 })
83 }
84
85 // users
86
87 pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
88 test_support!(self, {
89 let query = "
90 INSERT INTO users (github_login, admin)
91 VALUES ($1, $2)
92 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
93 RETURNING id
94 ";
95 sqlx::query_scalar(query)
96 .bind(github_login)
97 .bind(admin)
98 .fetch_one(&self.pool)
99 .await
100 .map(UserId)
101 })
102 }
103
104 pub async fn get_all_users(&self) -> Result<Vec<User>> {
105 test_support!(self, {
106 let query = "SELECT * FROM users ORDER BY github_login ASC";
107 sqlx::query_as(query).fetch_all(&self.pool).await
108 })
109 }
110
111 pub async fn get_users_by_ids(
112 &self,
113 ids: impl IntoIterator<Item = UserId>,
114 ) -> Result<Vec<User>> {
115 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
116 test_support!(self, {
117 let query = "
118 SELECT users.*
119 FROM users
120 WHERE users.id = ANY ($1)
121 ";
122
123 sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
124 })
125 }
126
127 pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
128 test_support!(self, {
129 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
130 sqlx::query_as(query)
131 .bind(github_login)
132 .fetch_optional(&self.pool)
133 .await
134 })
135 }
136
137 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
138 test_support!(self, {
139 let query = "UPDATE users SET admin = $1 WHERE id = $2";
140 sqlx::query(query)
141 .bind(is_admin)
142 .bind(id.0)
143 .execute(&self.pool)
144 .await
145 .map(drop)
146 })
147 }
148
149 pub async fn delete_user(&self, id: UserId) -> Result<()> {
150 test_support!(self, {
151 let query = "DELETE FROM users WHERE id = $1;";
152 sqlx::query(query)
153 .bind(id.0)
154 .execute(&self.pool)
155 .await
156 .map(drop)
157 })
158 }
159
160 // access tokens
161
162 pub async fn create_access_token_hash(
163 &self,
164 user_id: UserId,
165 access_token_hash: &str,
166 max_access_token_count: usize,
167 ) -> Result<()> {
168 test_support!(self, {
169 let insert_query = "
170 INSERT INTO access_tokens (user_id, hash)
171 VALUES ($1, $2);
172 ";
173 let cleanup_query = "
174 DELETE FROM access_tokens
175 WHERE id IN (
176 SELECT id from access_tokens
177 WHERE user_id = $1
178 ORDER BY id DESC
179 OFFSET $3
180 )
181 ";
182
183 let mut tx = self.pool.begin().await?;
184 sqlx::query(insert_query)
185 .bind(user_id.0)
186 .bind(access_token_hash)
187 .execute(&mut tx)
188 .await?;
189 sqlx::query(cleanup_query)
190 .bind(user_id.0)
191 .bind(access_token_hash)
192 .bind(max_access_token_count as u32)
193 .execute(&mut tx)
194 .await?;
195 tx.commit().await
196 })
197 }
198
199 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
200 test_support!(self, {
201 let query = "
202 SELECT hash
203 FROM access_tokens
204 WHERE user_id = $1
205 ORDER BY id DESC
206 ";
207 sqlx::query_scalar(query)
208 .bind(user_id.0)
209 .fetch_all(&self.pool)
210 .await
211 })
212 }
213
214 // orgs
215
216 #[allow(unused)] // Help rust-analyzer
217 #[cfg(any(test, feature = "seed-support"))]
218 pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
219 test_support!(self, {
220 let query = "
221 SELECT *
222 FROM orgs
223 WHERE slug = $1
224 ";
225 sqlx::query_as(query)
226 .bind(slug)
227 .fetch_optional(&self.pool)
228 .await
229 })
230 }
231
232 #[cfg(any(test, feature = "seed-support"))]
233 pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
234 test_support!(self, {
235 let query = "
236 INSERT INTO orgs (name, slug)
237 VALUES ($1, $2)
238 RETURNING id
239 ";
240 sqlx::query_scalar(query)
241 .bind(name)
242 .bind(slug)
243 .fetch_one(&self.pool)
244 .await
245 .map(OrgId)
246 })
247 }
248
249 #[cfg(any(test, feature = "seed-support"))]
250 pub async fn add_org_member(
251 &self,
252 org_id: OrgId,
253 user_id: UserId,
254 is_admin: bool,
255 ) -> Result<()> {
256 test_support!(self, {
257 let query = "
258 INSERT INTO org_memberships (org_id, user_id, admin)
259 VALUES ($1, $2, $3)
260 ON CONFLICT DO NOTHING
261 ";
262 sqlx::query(query)
263 .bind(org_id.0)
264 .bind(user_id.0)
265 .bind(is_admin)
266 .execute(&self.pool)
267 .await
268 .map(drop)
269 })
270 }
271
272 // channels
273
274 #[cfg(any(test, feature = "seed-support"))]
275 pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
276 test_support!(self, {
277 let query = "
278 INSERT INTO channels (owner_id, owner_is_user, name)
279 VALUES ($1, false, $2)
280 RETURNING id
281 ";
282 sqlx::query_scalar(query)
283 .bind(org_id.0)
284 .bind(name)
285 .fetch_one(&self.pool)
286 .await
287 .map(ChannelId)
288 })
289 }
290
291 #[allow(unused)] // Help rust-analyzer
292 #[cfg(any(test, feature = "seed-support"))]
293 pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
294 test_support!(self, {
295 let query = "
296 SELECT *
297 FROM channels
298 WHERE
299 channels.owner_is_user = false AND
300 channels.owner_id = $1
301 ";
302 sqlx::query_as(query)
303 .bind(org_id.0)
304 .fetch_all(&self.pool)
305 .await
306 })
307 }
308
309 pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
310 test_support!(self, {
311 let query = "
312 SELECT
313 channels.id, channels.name
314 FROM
315 channel_memberships, channels
316 WHERE
317 channel_memberships.user_id = $1 AND
318 channel_memberships.channel_id = channels.id
319 ";
320 sqlx::query_as(query)
321 .bind(user_id.0)
322 .fetch_all(&self.pool)
323 .await
324 })
325 }
326
327 pub async fn can_user_access_channel(
328 &self,
329 user_id: UserId,
330 channel_id: ChannelId,
331 ) -> Result<bool> {
332 test_support!(self, {
333 let query = "
334 SELECT id
335 FROM channel_memberships
336 WHERE user_id = $1 AND channel_id = $2
337 LIMIT 1
338 ";
339 sqlx::query_scalar::<_, i32>(query)
340 .bind(user_id.0)
341 .bind(channel_id.0)
342 .fetch_optional(&self.pool)
343 .await
344 .map(|e| e.is_some())
345 })
346 }
347
348 #[cfg(any(test, feature = "seed-support"))]
349 pub async fn add_channel_member(
350 &self,
351 channel_id: ChannelId,
352 user_id: UserId,
353 is_admin: bool,
354 ) -> Result<()> {
355 test_support!(self, {
356 let query = "
357 INSERT INTO channel_memberships (channel_id, user_id, admin)
358 VALUES ($1, $2, $3)
359 ON CONFLICT DO NOTHING
360 ";
361 sqlx::query(query)
362 .bind(channel_id.0)
363 .bind(user_id.0)
364 .bind(is_admin)
365 .execute(&self.pool)
366 .await
367 .map(drop)
368 })
369 }
370
371 // messages
372
373 pub async fn create_channel_message(
374 &self,
375 channel_id: ChannelId,
376 sender_id: UserId,
377 body: &str,
378 timestamp: OffsetDateTime,
379 nonce: u128,
380 ) -> Result<MessageId> {
381 test_support!(self, {
382 let query = "
383 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
384 VALUES ($1, $2, $3, $4, $5)
385 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
386 RETURNING id
387 ";
388 sqlx::query_scalar(query)
389 .bind(channel_id.0)
390 .bind(sender_id.0)
391 .bind(body)
392 .bind(timestamp)
393 .bind(Uuid::from_u128(nonce))
394 .fetch_one(&self.pool)
395 .await
396 .map(MessageId)
397 })
398 }
399
400 pub async fn get_channel_messages(
401 &self,
402 channel_id: ChannelId,
403 count: usize,
404 before_id: Option<MessageId>,
405 ) -> Result<Vec<ChannelMessage>> {
406 test_support!(self, {
407 let query = r#"
408 SELECT * FROM (
409 SELECT
410 id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
411 FROM
412 channel_messages
413 WHERE
414 channel_id = $1 AND
415 id < $2
416 ORDER BY id DESC
417 LIMIT $3
418 ) as recent_messages
419 ORDER BY id ASC
420 "#;
421 sqlx::query_as(query)
422 .bind(channel_id.0)
423 .bind(before_id.unwrap_or(MessageId::MAX))
424 .bind(count as i64)
425 .fetch_all(&self.pool)
426 .await
427 })
428 }
429}
430
/// Declares `pub struct $name(pub i32)`, a newtype for an integer database id.
///
/// The generated type is `Copy` and hashable. It is encoded to SQL as its inner
/// `i32` (`#[sqlx(transparent)]`) and serialized as a bare number
/// (`#[serde(transparent)]`).
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id, wrapping `i32::MAX`.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from its `u64` protocol representation.
            /// NOTE: this is an `as` cast, so values above `i32::MAX` wrap.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Returns the id as a `u64` for the protocol.
            /// NOTE: this is an `as` cast, so a negative id sign-extends.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let Self(raw) = *self;
                raw as u64
            }
        }
    };
}
454
455id_type!(UserId);
456#[derive(Debug, FromRow, Serialize, PartialEq)]
457pub struct User {
458 pub id: UserId,
459 pub github_login: String,
460 pub admin: bool,
461}
462
463id_type!(OrgId);
464#[derive(FromRow)]
465pub struct Org {
466 pub id: OrgId,
467 pub name: String,
468 pub slug: String,
469}
470
471id_type!(SignupId);
472#[derive(Debug, FromRow, Serialize)]
473pub struct Signup {
474 pub id: SignupId,
475 pub github_login: String,
476 pub email_address: String,
477 pub about: String,
478}
479
480id_type!(ChannelId);
481#[derive(Debug, FromRow, Serialize)]
482pub struct Channel {
483 pub id: ChannelId,
484 pub name: String,
485}
486
487id_type!(MessageId);
488#[derive(Debug, FromRow)]
489pub struct ChannelMessage {
490 pub id: MessageId,
491 pub sender_id: UserId,
492 pub body: String,
493 pub sent_at: OffsetDateTime,
494 pub nonce: Uuid,
495}
496
#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

    /// A freshly created and migrated Postgres database, dropped again when this
    /// value goes out of scope.
    pub struct TestDb {
        pub db: Db,
        pub name: String,
        pub url: String,
    }

    impl TestDb {
        /// Creates a uniquely named database (`zed-test-<random u128>`) on the local
        /// Postgres server as user `postgres`, and runs this crate's `migrations`
        /// directory against it. Returns a handle whose `Db` has `test_mode` enabled.
        ///
        /// # Panics
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub fn new() -> Self {
            // Enable tests to run in parallel by serializing the creation of each test database.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                {
                    let _lock = DB_CREATION.lock();
                    Postgres::create_database(&url)
                        .await
                        .expect("failed to create test db");
                }
                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });

            Self { db, name, url }
        }

        /// Returns the `Db` connected to this test database.
        pub fn db(&self) -> &Db {
            &self.db
        }
    }

    impl Drop for TestDb {
        /// Disconnects other sessions, closes the pool, and drops the test database.
        fn drop(&mut self) {
            block_on(async {
                // Postgres refuses to drop a database that still has other sessions, so
                // terminate them first. The name must be bound as `$1`; the previous
                // `'{}'` literal never matched any database.
                let query = "
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = $1 AND pid <> pg_backend_pid();
                ";
                sqlx::query(query)
                    .bind(&self.name)
                    .execute(&self.db.pool)
                    .await
                    .unwrap();
                self.db.pool.close().await;
                Postgres::drop_database(&self.url).await.unwrap();
            });
        }
    }

    #[gpui::test]
    async fn test_get_users_by_ids() {
        let test_db = TestDb::new();
        let db = test_db.db();

        let user = db.create_user("user", false).await.unwrap();
        let friend1 = db.create_user("friend-1", false).await.unwrap();
        let friend2 = db.create_user("friend-2", false).await.unwrap();
        let friend3 = db.create_user("friend-3", false).await.unwrap();

        assert_eq!(
            db.get_users_by_ids([user, friend1, friend2, friend3])
                .await
                .unwrap(),
            vec![
                User {
                    id: user,
                    github_login: "user".to_string(),
                    admin: false,
                },
                User {
                    id: friend1,
                    github_login: "friend-1".to_string(),
                    admin: false,
                },
                User {
                    id: friend2,
                    github_login: "friend-2".to_string(),
                    admin: false,
                },
                User {
                    id: friend3,
                    github_login: "friend-3".to_string(),
                    admin: false,
                }
            ]
        );
    }

    #[gpui::test]
    async fn test_recent_channel_messages() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();
        for i in 0..10 {
            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
                .await
                .unwrap();
        }

        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
        assert_eq!(
            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["5", "6", "7", "8", "9"]
        );

        let prev_messages = db
            .get_channel_messages(channel, 4, Some(messages[0].id))
            .await
            .unwrap();
        assert_eq!(
            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["1", "2", "3", "4"]
        );
    }

    #[gpui::test]
    async fn test_channel_message_nonces() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();

        let msg1_id = db
            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg2_id = db
            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();
        let msg3_id = db
            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg4_id = db
            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();

        assert_ne!(msg1_id, msg2_id);
        assert_eq!(msg1_id, msg3_id);
        assert_eq!(msg2_id, msg4_id);
    }

    #[gpui::test]
    async fn test_create_access_tokens() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("the-user", false).await.unwrap();

        db.create_access_token_hash(user, "h1", 3).await.unwrap();
        db.create_access_token_hash(user, "h2", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h3", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
        );

        db.create_access_token_hash(user, "h4", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
        );

        db.create_access_token_hash(user, "h5", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
        );
    }
}