use anyhow::Context;
use async_std::task::{block_on, yield_now};
use serde::Serialize;
use sqlx::{types::Uuid, FromRow, Result};
use time::OffsetDateTime;

pub use async_sqlx_session::PostgresSessionStore as SessionStore;
pub use sqlx::postgres::PgPoolOptions as DbOptions;

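// Wraps the body of each query method. In normal operation the body is simply
// awaited; in test mode we yield once to the calling executor and then run the
// query to completion on async-std's blocking executor, so tests driven by
// gpui's deterministic executor don't have to poll sqlx's futures themselves.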
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };
        if $self.test_mode {
            yield_now().await;
            block_on(body)
        } else {
            body.await
        }
    }};
}

#[derive(Clone)]
pub struct Db {
    pool: sqlx::PgPool,
    test_mode: bool,
}

impl Db {
    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
        let pool = DbOptions::new()
            .max_connections(max_connections)
            .connect(url)
            .await
            .context("failed to connect to postgres database")?;
        Ok(Self {
            pool,
            test_mode: false,
        })
    }

    // signups

    pub async fn create_signup(
        &self,
        github_login: &str,
        email_address: &str,
        about: &str,
        wants_releases: bool,
        wants_updates: bool,
        wants_community: bool,
    ) -> Result<SignupId> {
        test_support!(self, {
            let query = "
                INSERT INTO signups (
                    github_login,
                    email_address,
                    about,
                    wants_releases,
                    wants_updates,
                    wants_community
                )
                VALUES ($1, $2, $3, $4, $5, $6)
                RETURNING id
            ";
            sqlx::query_scalar(query)
                .bind(github_login)
                .bind(email_address)
                .bind(about)
                .bind(wants_releases)
                .bind(wants_updates)
                .bind(wants_community)
                .fetch_one(&self.pool)
                .await
                .map(SignupId)
        })
    }

    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
        test_support!(self, {
            let query = "SELECT * FROM signups ORDER BY github_login ASC";
            sqlx::query_as(query).fetch_all(&self.pool).await
        })
    }

    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
        test_support!(self, {
            let query = "DELETE FROM signups WHERE id = $1";
            sqlx::query(query)
                .bind(id.0)
                .execute(&self.pool)
                .await
                .map(drop)
        })
    }

    // users

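    // The no-op ON CONFLICT update lets RETURNING id also yield the id of an
    // already-existing user, so this doubles as "find or create by github_login".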
    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
        test_support!(self, {
            let query = "
                INSERT INTO users (github_login, admin)
                VALUES ($1, $2)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id
            ";
            sqlx::query_scalar(query)
                .bind(github_login)
                .bind(admin)
                .fetch_one(&self.pool)
                .await
                .map(UserId)
        })
    }

    pub async fn get_all_users(&self) -> Result<Vec<User>> {
        test_support!(self, {
            let query = "SELECT * FROM users ORDER BY github_login ASC";
            sqlx::query_as(query).fetch_all(&self.pool).await
        })
    }

    pub async fn get_users_by_ids(
        &self,
        ids: impl IntoIterator<Item = UserId>,
    ) -> Result<Vec<User>> {
        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
        test_support!(self, {
            let query = "
                SELECT users.*
                FROM users
                WHERE users.id = ANY ($1)
            ";

            sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
        })
    }

    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
        test_support!(self, {
            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
            sqlx::query_as(query)
                .bind(github_login)
                .fetch_optional(&self.pool)
                .await
        })
    }

    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
        test_support!(self, {
            let query = "UPDATE users SET admin = $1 WHERE id = $2";
            sqlx::query(query)
                .bind(is_admin)
                .bind(id.0)
                .execute(&self.pool)
                .await
                .map(drop)
        })
    }

    pub async fn delete_user(&self, id: UserId) -> Result<()> {
        test_support!(self, {
            let query = "DELETE FROM users WHERE id = $1";
            sqlx::query(query)
                .bind(id.0)
                .execute(&self.pool)
                .await
                .map(drop)
        })
    }

    // access tokens

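    // Stores a new access token hash for the user and, in the same transaction,
    // trims that user's tokens down to the newest `max_access_token_count` entries.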
    pub async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()> {
        test_support!(self, {
            let insert_query = "
                INSERT INTO access_tokens (user_id, hash)
                VALUES ($1, $2)
            ";
            let cleanup_query = "
                DELETE FROM access_tokens
                WHERE id IN (
                    SELECT id FROM access_tokens
                    WHERE user_id = $1
                    ORDER BY id DESC
                    OFFSET $2
                )
            ";

            let mut tx = self.pool.begin().await?;
            sqlx::query(insert_query)
                .bind(user_id.0)
                .bind(access_token_hash)
                .execute(&mut tx)
                .await?;
            sqlx::query(cleanup_query)
                .bind(user_id.0)
                .bind(max_access_token_count as i64)
                .execute(&mut tx)
                .await?;
            tx.commit().await
        })
    }

    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
        test_support!(self, {
            let query = "
                SELECT hash
                FROM access_tokens
                WHERE user_id = $1
                ORDER BY id DESC
            ";
            sqlx::query_scalar(query)
                .bind(user_id.0)
                .fetch_all(&self.pool)
                .await
        })
    }

    // orgs

    #[allow(unused)] // Help rust-analyzer
    #[cfg(any(test, feature = "seed-support"))]
    pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
        test_support!(self, {
            let query = "
                SELECT *
                FROM orgs
                WHERE slug = $1
            ";
            sqlx::query_as(query)
                .bind(slug)
                .fetch_optional(&self.pool)
                .await
        })
    }

    #[cfg(any(test, feature = "seed-support"))]
    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
        test_support!(self, {
            let query = "
                INSERT INTO orgs (name, slug)
                VALUES ($1, $2)
                RETURNING id
            ";
            sqlx::query_scalar(query)
                .bind(name)
                .bind(slug)
                .fetch_one(&self.pool)
                .await
                .map(OrgId)
        })
    }

    #[cfg(any(test, feature = "seed-support"))]
    pub async fn add_org_member(
        &self,
        org_id: OrgId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()> {
        test_support!(self, {
            let query = "
                INSERT INTO org_memberships (org_id, user_id, admin)
                VALUES ($1, $2, $3)
                ON CONFLICT DO NOTHING
            ";
            sqlx::query(query)
                .bind(org_id.0)
                .bind(user_id.0)
                .bind(is_admin)
                .execute(&self.pool)
                .await
                .map(drop)
        })
    }

    // channels

    #[cfg(any(test, feature = "seed-support"))]
    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
        test_support!(self, {
            let query = "
                INSERT INTO channels (owner_id, owner_is_user, name)
                VALUES ($1, false, $2)
                RETURNING id
            ";
            sqlx::query_scalar(query)
                .bind(org_id.0)
                .bind(name)
                .fetch_one(&self.pool)
                .await
                .map(ChannelId)
        })
    }

    #[allow(unused)] // Help rust-analyzer
    #[cfg(any(test, feature = "seed-support"))]
    pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
        test_support!(self, {
            let query = "
                SELECT *
                FROM channels
                WHERE
                    channels.owner_is_user = false AND
                    channels.owner_id = $1
            ";
            sqlx::query_as(query)
                .bind(org_id.0)
                .fetch_all(&self.pool)
                .await
        })
    }

    pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
        test_support!(self, {
            let query = "
                SELECT
                    channels.id, channels.name
                FROM
                    channel_memberships, channels
                WHERE
                    channel_memberships.user_id = $1 AND
                    channel_memberships.channel_id = channels.id
            ";
            sqlx::query_as(query)
                .bind(user_id.0)
                .fetch_all(&self.pool)
                .await
        })
    }

    pub async fn can_user_access_channel(
        &self,
        user_id: UserId,
        channel_id: ChannelId,
    ) -> Result<bool> {
        test_support!(self, {
            let query = "
                SELECT id
                FROM channel_memberships
                WHERE user_id = $1 AND channel_id = $2
                LIMIT 1
            ";
            sqlx::query_scalar::<_, i32>(query)
                .bind(user_id.0)
                .bind(channel_id.0)
                .fetch_optional(&self.pool)
                .await
                .map(|e| e.is_some())
        })
    }

    #[cfg(any(test, feature = "seed-support"))]
    pub async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()> {
        test_support!(self, {
            let query = "
                INSERT INTO channel_memberships (channel_id, user_id, admin)
                VALUES ($1, $2, $3)
                ON CONFLICT DO NOTHING
            ";
            sqlx::query(query)
                .bind(channel_id.0)
                .bind(user_id.0)
                .bind(is_admin)
                .execute(&self.pool)
                .await
                .map(drop)
        })
    }

    // messages

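    // Messages carry a caller-supplied nonce. On a nonce conflict the no-op
    // update makes RETURNING id hand back the originally inserted row's id,
    // so re-sending the same message is idempotent.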
    pub async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId> {
        test_support!(self, {
            let query = "
                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
                RETURNING id
            ";
            sqlx::query_scalar(query)
                .bind(channel_id.0)
                .bind(sender_id.0)
                .bind(body)
                .bind(timestamp)
                .bind(Uuid::from_u128(nonce))
                .fetch_one(&self.pool)
                .await
                .map(MessageId)
        })
    }

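    // Returns up to `count` messages older than `before_id` (or the newest
    // messages when `before_id` is None), ordered oldest-to-newest.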
    pub async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>> {
        test_support!(self, {
            let query = r#"
                SELECT * FROM (
                    SELECT
                        id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
                    FROM
                        channel_messages
                    WHERE
                        channel_id = $1 AND
                        id < $2
                    ORDER BY id DESC
                    LIMIT $3
                ) as recent_messages
                ORDER BY id ASC
            "#;
            sqlx::query_as(query)
                .bind(channel_id.0)
                .bind(before_id.unwrap_or(MessageId::MAX))
                .bind(count as i64)
                .fetch_all(&self.pool)
                .await
        })
    }
}

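// Generates a Copy newtype wrapper around an i32 primary key. `#[sqlx(transparent)]`
// and `#[serde(transparent)]` make it encode and decode exactly like the underlying
// integer, and the proto helpers convert to and from the u64 form used in protocol
// messages.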
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }
    };
}

id_type!(UserId);
#[derive(Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub admin: bool,
}

id_type!(OrgId);
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}

id_type!(SignupId);
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
    pub id: SignupId,
    pub github_login: String,
    pub email_address: String,
    pub about: String,
    pub wants_releases: Option<bool>,
    pub wants_updates: Option<bool>,
    pub wants_community: Option<bool>,
}

id_type!(ChannelId);
#[derive(Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
}

id_type!(MessageId);
#[derive(Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub sender_id: UserId,
    pub body: String,
    pub sent_at: OffsetDateTime,
    pub nonce: Uuid,
}

#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

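    // Creates a throwaway, randomly named Postgres database, runs the crate's
    // migrations against it, and returns a `Db` in test mode; the database is
    // torn down again when the `TestDb` is dropped.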
    pub struct TestDb {
        pub db: Db,
        pub name: String,
        pub url: String,
    }

    impl TestDb {
        pub fn new() -> Self {
            // Enable tests to run in parallel by serializing the creation of each test database.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                {
                    let _lock = DB_CREATION.lock();
                    Postgres::create_database(&url)
                        .await
                        .expect("failed to create test db");
                }
                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });

            Self { db, name, url }
        }

        pub fn db(&self) -> &Db {
            &self.db
        }
    }

    impl Drop for TestDb {
        fn drop(&mut self) {
            block_on(async {
                // Terminate any other connections to the test database before dropping it.
                let query = "
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = $1 AND pid <> pg_backend_pid();
                ";
                sqlx::query(query)
                    .bind(&self.name)
                    .execute(&self.db.pool)
                    .await
                    .unwrap();
                self.db.pool.close().await;
                Postgres::drop_database(&self.url).await.unwrap();
            });
        }
    }

    #[gpui::test]
    async fn test_get_users_by_ids() {
        let test_db = TestDb::new();
        let db = test_db.db();

        let user = db.create_user("user", false).await.unwrap();
        let friend1 = db.create_user("friend-1", false).await.unwrap();
        let friend2 = db.create_user("friend-2", false).await.unwrap();
        let friend3 = db.create_user("friend-3", false).await.unwrap();

        assert_eq!(
            db.get_users_by_ids([user, friend1, friend2, friend3])
                .await
                .unwrap(),
            vec![
                User {
                    id: user,
                    github_login: "user".to_string(),
                    admin: false,
                },
                User {
                    id: friend1,
                    github_login: "friend-1".to_string(),
                    admin: false,
                },
                User {
                    id: friend2,
                    github_login: "friend-2".to_string(),
                    admin: false,
                },
                User {
                    id: friend3,
                    github_login: "friend-3".to_string(),
                    admin: false,
                }
            ]
        );
    }

    #[gpui::test]
    async fn test_recent_channel_messages() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();
        for i in 0..10 {
            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
                .await
                .unwrap();
        }

        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
        assert_eq!(
            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["5", "6", "7", "8", "9"]
        );

        let prev_messages = db
            .get_channel_messages(channel, 4, Some(messages[0].id))
            .await
            .unwrap();
        assert_eq!(
            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["1", "2", "3", "4"]
        );
    }

    #[gpui::test]
    async fn test_channel_message_nonces() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();

        let msg1_id = db
            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg2_id = db
            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();
        let msg3_id = db
            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg4_id = db
            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();

        assert_ne!(msg1_id, msg2_id);
        assert_eq!(msg1_id, msg3_id);
        assert_eq!(msg2_id, msg4_id);
    }

    #[gpui::test]
    async fn test_create_access_tokens() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("the-user", false).await.unwrap();

        db.create_access_token_hash(user, "h1", 3).await.unwrap();
        db.create_access_token_hash(user, "h2", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h3", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h3".to_string(), "h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h4", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h4".to_string(), "h3".to_string(), "h2".to_string()]
        );

        db.create_access_token_hash(user, "h5", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
        );
    }
}