1use anyhow::Context;
2use async_std::task::{block_on, yield_now};
3use serde::Serialize;
4use sqlx::{types::Uuid, FromRow, Result};
5use time::OffsetDateTime;
6
7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9
/// Wraps a database operation so it can run either normally or in test mode.
///
/// Must be expanded inside an `async fn` where `$self` is a `Db`. The `{ ... }`
/// tokens become the body of an `async` block:
///
/// * when `$self.test_mode` is false, that block is simply awaited;
/// * when it is true, the macro first yields to the executor once and then drives
///   the block to completion synchronously with `block_on`.
///
/// NOTE(review): the test-mode path presumably exists so a deterministic test
/// executor observes each query as a single step — confirm with the test harness.
macro_rules! test_support {
    ($self:ident, { $($body:tt)* }) => {{
        let operation = async { $($body)* };
        if !$self.test_mode {
            operation.await
        } else {
            yield_now().await;
            block_on(operation)
        }
    }};
}
23
24#[derive(Clone)]
25pub struct Db {
26 pool: sqlx::PgPool,
27 test_mode: bool,
28}
29
30impl Db {
31 pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
32 let pool = DbOptions::new()
33 .max_connections(max_connections)
34 .connect(url)
35 .await
36 .context("failed to connect to postgres database")?;
37 Ok(Self {
38 pool,
39 test_mode: false,
40 })
41 }
42
43 // signups
44
45 pub async fn create_signup(
46 &self,
47 github_login: &str,
48 email_address: &str,
49 about: &str,
50 wants_releases: bool,
51 wants_updates: bool,
52 wants_community: bool,
53 ) -> Result<SignupId> {
54 test_support!(self, {
55 let query = "
56 INSERT INTO signups (
57 github_login,
58 email_address,
59 about,
60 wants_releases,
61 wants_updates,
62 wants_community
63 )
64 VALUES ($1, $2, $3, $4, $5, $6)
65 RETURNING id
66 ";
67 sqlx::query_scalar(query)
68 .bind(github_login)
69 .bind(email_address)
70 .bind(about)
71 .bind(wants_releases)
72 .bind(wants_updates)
73 .bind(wants_community)
74 .fetch_one(&self.pool)
75 .await
76 .map(SignupId)
77 })
78 }
79
80 pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
81 test_support!(self, {
82 let query = "SELECT * FROM signups ORDER BY github_login ASC";
83 sqlx::query_as(query).fetch_all(&self.pool).await
84 })
85 }
86
87 pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
88 test_support!(self, {
89 let query = "DELETE FROM signups WHERE id = $1";
90 sqlx::query(query)
91 .bind(id.0)
92 .execute(&self.pool)
93 .await
94 .map(drop)
95 })
96 }
97
98 // users
99
100 pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
101 test_support!(self, {
102 let query = "
103 INSERT INTO users (github_login, admin)
104 VALUES ($1, $2)
105 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
106 RETURNING id
107 ";
108 sqlx::query_scalar(query)
109 .bind(github_login)
110 .bind(admin)
111 .fetch_one(&self.pool)
112 .await
113 .map(UserId)
114 })
115 }
116
117 pub async fn get_all_users(&self) -> Result<Vec<User>> {
118 test_support!(self, {
119 let query = "SELECT * FROM users ORDER BY github_login ASC";
120 sqlx::query_as(query).fetch_all(&self.pool).await
121 })
122 }
123
124 pub async fn get_users_by_ids(
125 &self,
126 ids: impl IntoIterator<Item = UserId>,
127 ) -> Result<Vec<User>> {
128 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
129 test_support!(self, {
130 let query = "
131 SELECT users.*
132 FROM users
133 WHERE users.id = ANY ($1)
134 ";
135
136 sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
137 })
138 }
139
140 pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
141 test_support!(self, {
142 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
143 sqlx::query_as(query)
144 .bind(github_login)
145 .fetch_optional(&self.pool)
146 .await
147 })
148 }
149
150 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
151 test_support!(self, {
152 let query = "UPDATE users SET admin = $1 WHERE id = $2";
153 sqlx::query(query)
154 .bind(is_admin)
155 .bind(id.0)
156 .execute(&self.pool)
157 .await
158 .map(drop)
159 })
160 }
161
162 pub async fn delete_user(&self, id: UserId) -> Result<()> {
163 test_support!(self, {
164 let query = "DELETE FROM users WHERE id = $1;";
165 sqlx::query(query)
166 .bind(id.0)
167 .execute(&self.pool)
168 .await
169 .map(drop)
170 })
171 }
172
173 // access tokens
174
175 pub async fn create_access_token_hash(
176 &self,
177 user_id: UserId,
178 access_token_hash: String,
179 ) -> Result<()> {
180 test_support!(self, {
181 let query = "
182 INSERT INTO access_tokens (user_id, hash)
183 VALUES ($1, $2)
184 ";
185 sqlx::query(query)
186 .bind(user_id.0)
187 .bind(access_token_hash)
188 .execute(&self.pool)
189 .await
190 .map(drop)
191 })
192 }
193
194 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
195 test_support!(self, {
196 let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
197 sqlx::query_scalar(query)
198 .bind(user_id.0)
199 .fetch_all(&self.pool)
200 .await
201 })
202 }
203
204 // orgs
205
206 #[allow(unused)] // Help rust-analyzer
207 #[cfg(any(test, feature = "seed-support"))]
208 pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
209 test_support!(self, {
210 let query = "
211 SELECT *
212 FROM orgs
213 WHERE slug = $1
214 ";
215 sqlx::query_as(query)
216 .bind(slug)
217 .fetch_optional(&self.pool)
218 .await
219 })
220 }
221
222 #[cfg(any(test, feature = "seed-support"))]
223 pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
224 test_support!(self, {
225 let query = "
226 INSERT INTO orgs (name, slug)
227 VALUES ($1, $2)
228 RETURNING id
229 ";
230 sqlx::query_scalar(query)
231 .bind(name)
232 .bind(slug)
233 .fetch_one(&self.pool)
234 .await
235 .map(OrgId)
236 })
237 }
238
239 #[cfg(any(test, feature = "seed-support"))]
240 pub async fn add_org_member(
241 &self,
242 org_id: OrgId,
243 user_id: UserId,
244 is_admin: bool,
245 ) -> Result<()> {
246 test_support!(self, {
247 let query = "
248 INSERT INTO org_memberships (org_id, user_id, admin)
249 VALUES ($1, $2, $3)
250 ON CONFLICT DO NOTHING
251 ";
252 sqlx::query(query)
253 .bind(org_id.0)
254 .bind(user_id.0)
255 .bind(is_admin)
256 .execute(&self.pool)
257 .await
258 .map(drop)
259 })
260 }
261
262 // channels
263
264 #[cfg(any(test, feature = "seed-support"))]
265 pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
266 test_support!(self, {
267 let query = "
268 INSERT INTO channels (owner_id, owner_is_user, name)
269 VALUES ($1, false, $2)
270 RETURNING id
271 ";
272 sqlx::query_scalar(query)
273 .bind(org_id.0)
274 .bind(name)
275 .fetch_one(&self.pool)
276 .await
277 .map(ChannelId)
278 })
279 }
280
281 #[allow(unused)] // Help rust-analyzer
282 #[cfg(any(test, feature = "seed-support"))]
283 pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
284 test_support!(self, {
285 let query = "
286 SELECT *
287 FROM channels
288 WHERE
289 channels.owner_is_user = false AND
290 channels.owner_id = $1
291 ";
292 sqlx::query_as(query)
293 .bind(org_id.0)
294 .fetch_all(&self.pool)
295 .await
296 })
297 }
298
299 pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
300 test_support!(self, {
301 let query = "
302 SELECT
303 channels.id, channels.name
304 FROM
305 channel_memberships, channels
306 WHERE
307 channel_memberships.user_id = $1 AND
308 channel_memberships.channel_id = channels.id
309 ";
310 sqlx::query_as(query)
311 .bind(user_id.0)
312 .fetch_all(&self.pool)
313 .await
314 })
315 }
316
317 pub async fn can_user_access_channel(
318 &self,
319 user_id: UserId,
320 channel_id: ChannelId,
321 ) -> Result<bool> {
322 test_support!(self, {
323 let query = "
324 SELECT id
325 FROM channel_memberships
326 WHERE user_id = $1 AND channel_id = $2
327 LIMIT 1
328 ";
329 sqlx::query_scalar::<_, i32>(query)
330 .bind(user_id.0)
331 .bind(channel_id.0)
332 .fetch_optional(&self.pool)
333 .await
334 .map(|e| e.is_some())
335 })
336 }
337
338 #[cfg(any(test, feature = "seed-support"))]
339 pub async fn add_channel_member(
340 &self,
341 channel_id: ChannelId,
342 user_id: UserId,
343 is_admin: bool,
344 ) -> Result<()> {
345 test_support!(self, {
346 let query = "
347 INSERT INTO channel_memberships (channel_id, user_id, admin)
348 VALUES ($1, $2, $3)
349 ON CONFLICT DO NOTHING
350 ";
351 sqlx::query(query)
352 .bind(channel_id.0)
353 .bind(user_id.0)
354 .bind(is_admin)
355 .execute(&self.pool)
356 .await
357 .map(drop)
358 })
359 }
360
361 // messages
362
363 pub async fn create_channel_message(
364 &self,
365 channel_id: ChannelId,
366 sender_id: UserId,
367 body: &str,
368 timestamp: OffsetDateTime,
369 nonce: u128,
370 ) -> Result<MessageId> {
371 test_support!(self, {
372 let query = "
373 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
374 VALUES ($1, $2, $3, $4, $5)
375 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
376 RETURNING id
377 ";
378 sqlx::query_scalar(query)
379 .bind(channel_id.0)
380 .bind(sender_id.0)
381 .bind(body)
382 .bind(timestamp)
383 .bind(Uuid::from_u128(nonce))
384 .fetch_one(&self.pool)
385 .await
386 .map(MessageId)
387 })
388 }
389
390 pub async fn get_channel_messages(
391 &self,
392 channel_id: ChannelId,
393 count: usize,
394 before_id: Option<MessageId>,
395 ) -> Result<Vec<ChannelMessage>> {
396 test_support!(self, {
397 let query = r#"
398 SELECT * FROM (
399 SELECT
400 id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
401 FROM
402 channel_messages
403 WHERE
404 channel_id = $1 AND
405 id < $2
406 ORDER BY id DESC
407 LIMIT $3
408 ) as recent_messages
409 ORDER BY id ASC
410 "#;
411 sqlx::query_as(query)
412 .bind(channel_id.0)
413 .bind(before_id.unwrap_or(MessageId::MAX))
414 .bind(count as i64)
415 .fetch_all(&self.pool)
416 .await
417 })
418 }
419}
420
/// Declares a `Copy` newtype around an `i32` database id.
///
/// Through `#[sqlx(transparent)]` and `#[serde(transparent)]`, the generated type is
/// bound, decoded and serialized exactly like its inner `i32`. It also gets `MAX`
/// plus conversions to and from `u64`.
///
/// NOTE(review): the `u64` form is presumably the id representation used by the
/// wire protocol; confirm against the proto definitions.
macro_rules! id_type {
    ($id:ident) => {
        #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, sqlx::Type)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $id(pub i32);

        impl $id {
            /// The id wrapping `i32::MAX`.
            #[allow(unused)]
            pub const MAX: Self = $id(i32::MAX);

            /// Builds an id from its `u64` form.
            ///
            /// This is an `as` cast, so only the low 32 bits of `value` are kept.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $id(value as i32)
            }

            /// Returns the `u64` form of this id (negative ids sign-extend).
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }
    };
}
444
445id_type!(UserId);
446#[derive(Debug, FromRow, Serialize, PartialEq)]
447pub struct User {
448 pub id: UserId,
449 pub github_login: String,
450 pub admin: bool,
451}
452
453id_type!(OrgId);
454#[derive(FromRow)]
455pub struct Org {
456 pub id: OrgId,
457 pub name: String,
458 pub slug: String,
459}
460
461id_type!(SignupId);
462#[derive(Debug, FromRow, Serialize)]
463pub struct Signup {
464 pub id: SignupId,
465 pub github_login: String,
466 pub email_address: String,
467 pub about: String,
468 pub wants_releases: Option<bool>,
469 pub wants_updates: Option<bool>,
470 pub wants_community: Option<bool>,
471}
472
473id_type!(ChannelId);
474#[derive(Debug, FromRow, Serialize)]
475pub struct Channel {
476 pub id: ChannelId,
477 pub name: String,
478}
479
480id_type!(MessageId);
481#[derive(Debug, FromRow)]
482pub struct ChannelMessage {
483 pub id: MessageId,
484 pub sender_id: UserId,
485 pub body: String,
486 pub sent_at: OffsetDateTime,
487 pub nonce: Uuid,
488}
489
#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

    /// A uniquely named, fully migrated Postgres database for a single test.
    ///
    /// The database lives on `postgres://postgres@localhost` and is dropped again
    /// when the `TestDb` is dropped.
    pub struct TestDb {
        /// Handle to the test database, with `test_mode` enabled.
        pub db: Db,
        /// Name of the database created for this test (`zed-test-<random u128>`).
        pub name: String,
        /// Connection URL of the database created for this test.
        pub url: String,
    }

    impl TestDb {
        /// Creates a randomly named database and connects to it with a pool of 5
        /// connections. It then enables `test_mode` and runs the migrations in
        /// `$CARGO_MANIFEST_DIR/migrations`.
        ///
        /// # Panics
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub fn new() -> Self {
            // Enable tests to run in parallel by serializing the creation of each test database.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                {
                    // The mutex guards no data, only creation order. If another test
                    // panicked while holding it, keep serializing by recovering the
                    // guard from the poisoned lock.
                    let _lock = DB_CREATION
                        .lock()
                        .unwrap_or_else(std::sync::PoisonError::into_inner);
                    Postgres::create_database(&url)
                        .await
                        .expect("failed to create test db");
                }
                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });

            Self { db, name, url }
        }

        /// Returns the handle to this test's database.
        pub fn db(&self) -> &Db {
            &self.db
        }
    }

    impl Drop for TestDb {
        /// Terminates the other sessions connected to this test's database, closes
        /// the pool, and drops the database.
        fn drop(&mut self) {
            block_on(async {
                // Postgres refuses to drop a database that other sessions are still
                // connected to. The database name is bound as `$1`, so only sessions
                // on this test's database (other than the current one) are terminated.
                let query = "
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = $1 AND pid <> pg_backend_pid();
                ";
                sqlx::query(query)
                    .bind(&self.name)
                    .execute(&self.db.pool)
                    .await
                    .unwrap();
                self.db.pool.close().await;
                Postgres::drop_database(&self.url).await.unwrap();
            });
        }
    }

    #[gpui::test]
    async fn test_get_users_by_ids() {
        let test_db = TestDb::new();
        let db = test_db.db();

        let user = db.create_user("user", false).await.unwrap();
        let friend1 = db.create_user("friend-1", false).await.unwrap();
        let friend2 = db.create_user("friend-2", false).await.unwrap();
        let friend3 = db.create_user("friend-3", false).await.unwrap();

        assert_eq!(
            db.get_users_by_ids([user, friend1, friend2, friend3])
                .await
                .unwrap(),
            vec![
                User {
                    id: user,
                    github_login: "user".to_string(),
                    admin: false,
                },
                User {
                    id: friend1,
                    github_login: "friend-1".to_string(),
                    admin: false,
                },
                User {
                    id: friend2,
                    github_login: "friend-2".to_string(),
                    admin: false,
                },
                User {
                    id: friend3,
                    github_login: "friend-3".to_string(),
                    admin: false,
                }
            ]
        );
    }

    #[gpui::test]
    async fn test_recent_channel_messages() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();
        for i in 0..10 {
            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
                .await
                .unwrap();
        }

        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
        assert_eq!(
            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["5", "6", "7", "8", "9"]
        );

        let prev_messages = db
            .get_channel_messages(channel, 4, Some(messages[0].id))
            .await
            .unwrap();
        assert_eq!(
            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["1", "2", "3", "4"]
        );
    }

    #[gpui::test]
    async fn test_channel_message_nonces() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();

        let msg1_id = db
            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg2_id = db
            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();
        let msg3_id = db
            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg4_id = db
            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();

        assert_ne!(msg1_id, msg2_id);
        assert_eq!(msg1_id, msg3_id);
        assert_eq!(msg2_id, msg4_id);
    }
}