1use anyhow::Context;
2use async_std::task::{block_on, yield_now};
3use serde::Serialize;
4use sqlx::{types::Uuid, FromRow, Result};
5use time::OffsetDateTime;
6
7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9
/// Evaluates `{ ... }` as an async block on behalf of a `Db` method.
///
/// `$self` must expose a `test_mode: bool` field. When it is `false` the block is simply
/// `.await`ed. When it is `true`, the macro first yields to the current executor once
/// (`yield_now`) and then drives the block to completion synchronously with `block_on`.
// NOTE(review): presumably the `block_on` path keeps real database I/O off the deterministic
// executor used by `#[gpui::test]` while still giving it a scheduling point — confirm.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let query_future = async { $($token)* };
        if !$self.test_mode {
            query_future.await
        } else {
            yield_now().await;
            block_on(query_future)
        }
    }};
}
23
/// Handle to the Postgres database used by the server, wrapping a `sqlx` connection pool.
#[derive(Clone)]
pub struct Db {
    /// Connection pool that every query in `impl Db` executes against.
    pool: sqlx::PgPool,
    /// When `true`, queries run through the `block_on` path of `test_support!`.
    /// `Db::new` sets it to `false`; only the in-file test harness (`TestDb::new`) sets it
    /// to `true`.
    test_mode: bool,
}
29
30impl Db {
31 pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
32 let pool = DbOptions::new()
33 .max_connections(max_connections)
34 .connect(url)
35 .await
36 .context("failed to connect to postgres database")?;
37 Ok(Self {
38 pool,
39 test_mode: false,
40 })
41 }
42
43 // signups
44
45 pub async fn create_signup(
46 &self,
47 github_login: &str,
48 email_address: &str,
49 about: &str,
50 ) -> Result<SignupId> {
51 test_support!(self, {
52 let query = "
53 INSERT INTO signups (github_login, email_address, about)
54 VALUES ($1, $2, $3)
55 RETURNING id
56 ";
57 sqlx::query_scalar(query)
58 .bind(github_login)
59 .bind(email_address)
60 .bind(about)
61 .fetch_one(&self.pool)
62 .await
63 .map(SignupId)
64 })
65 }
66
67 pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
68 test_support!(self, {
69 let query = "SELECT * FROM signups ORDER BY github_login ASC";
70 sqlx::query_as(query).fetch_all(&self.pool).await
71 })
72 }
73
74 pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
75 test_support!(self, {
76 let query = "DELETE FROM signups WHERE id = $1";
77 sqlx::query(query)
78 .bind(id.0)
79 .execute(&self.pool)
80 .await
81 .map(drop)
82 })
83 }
84
85 // users
86
87 pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
88 test_support!(self, {
89 let query = "
90 INSERT INTO users (github_login, admin)
91 VALUES ($1, $2)
92 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
93 RETURNING id
94 ";
95 sqlx::query_scalar(query)
96 .bind(github_login)
97 .bind(admin)
98 .fetch_one(&self.pool)
99 .await
100 .map(UserId)
101 })
102 }
103
104 pub async fn get_all_users(&self) -> Result<Vec<User>> {
105 test_support!(self, {
106 let query = "SELECT * FROM users ORDER BY github_login ASC";
107 sqlx::query_as(query).fetch_all(&self.pool).await
108 })
109 }
110
111 pub async fn get_users_by_ids(&self, ids: impl Iterator<Item = UserId>) -> Result<Vec<User>> {
112 let ids = ids.map(|id| id.0).collect::<Vec<_>>();
113 test_support!(self, {
114 let query = "
115 SELECT users.*
116 FROM users
117 WHERE users.id = ANY ($1)
118 ";
119
120 sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
121 })
122 }
123
124 pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
125 test_support!(self, {
126 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
127 sqlx::query_as(query)
128 .bind(github_login)
129 .fetch_optional(&self.pool)
130 .await
131 })
132 }
133
134 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
135 test_support!(self, {
136 let query = "UPDATE users SET admin = $1 WHERE id = $2";
137 sqlx::query(query)
138 .bind(is_admin)
139 .bind(id.0)
140 .execute(&self.pool)
141 .await
142 .map(drop)
143 })
144 }
145
146 pub async fn delete_user(&self, id: UserId) -> Result<()> {
147 test_support!(self, {
148 let query = "DELETE FROM users WHERE id = $1;";
149 sqlx::query(query)
150 .bind(id.0)
151 .execute(&self.pool)
152 .await
153 .map(drop)
154 })
155 }
156
157 // access tokens
158
159 pub async fn create_access_token_hash(
160 &self,
161 user_id: UserId,
162 access_token_hash: String,
163 ) -> Result<()> {
164 test_support!(self, {
165 let query = "
166 INSERT INTO access_tokens (user_id, hash)
167 VALUES ($1, $2)
168 ";
169 sqlx::query(query)
170 .bind(user_id.0)
171 .bind(access_token_hash)
172 .execute(&self.pool)
173 .await
174 .map(drop)
175 })
176 }
177
178 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
179 test_support!(self, {
180 let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
181 sqlx::query_scalar(query)
182 .bind(user_id.0)
183 .fetch_all(&self.pool)
184 .await
185 })
186 }
187
188 // orgs
189
190 #[allow(unused)] // Help rust-analyzer
191 #[cfg(any(test, feature = "seed-support"))]
192 pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
193 test_support!(self, {
194 let query = "
195 SELECT *
196 FROM orgs
197 WHERE slug = $1
198 ";
199 sqlx::query_as(query)
200 .bind(slug)
201 .fetch_optional(&self.pool)
202 .await
203 })
204 }
205
206 #[cfg(any(test, feature = "seed-support"))]
207 pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
208 test_support!(self, {
209 let query = "
210 INSERT INTO orgs (name, slug)
211 VALUES ($1, $2)
212 RETURNING id
213 ";
214 sqlx::query_scalar(query)
215 .bind(name)
216 .bind(slug)
217 .fetch_one(&self.pool)
218 .await
219 .map(OrgId)
220 })
221 }
222
223 #[cfg(any(test, feature = "seed-support"))]
224 pub async fn add_org_member(
225 &self,
226 org_id: OrgId,
227 user_id: UserId,
228 is_admin: bool,
229 ) -> Result<()> {
230 test_support!(self, {
231 let query = "
232 INSERT INTO org_memberships (org_id, user_id, admin)
233 VALUES ($1, $2, $3)
234 ON CONFLICT DO NOTHING
235 ";
236 sqlx::query(query)
237 .bind(org_id.0)
238 .bind(user_id.0)
239 .bind(is_admin)
240 .execute(&self.pool)
241 .await
242 .map(drop)
243 })
244 }
245
246 // channels
247
248 #[cfg(any(test, feature = "seed-support"))]
249 pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
250 test_support!(self, {
251 let query = "
252 INSERT INTO channels (owner_id, owner_is_user, name)
253 VALUES ($1, false, $2)
254 RETURNING id
255 ";
256 sqlx::query_scalar(query)
257 .bind(org_id.0)
258 .bind(name)
259 .fetch_one(&self.pool)
260 .await
261 .map(ChannelId)
262 })
263 }
264
265 #[allow(unused)] // Help rust-analyzer
266 #[cfg(any(test, feature = "seed-support"))]
267 pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
268 test_support!(self, {
269 let query = "
270 SELECT *
271 FROM channels
272 WHERE
273 channels.owner_is_user = false AND
274 channels.owner_id = $1
275 ";
276 sqlx::query_as(query)
277 .bind(org_id.0)
278 .fetch_all(&self.pool)
279 .await
280 })
281 }
282
283 pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
284 test_support!(self, {
285 let query = "
286 SELECT
287 channels.id, channels.name
288 FROM
289 channel_memberships, channels
290 WHERE
291 channel_memberships.user_id = $1 AND
292 channel_memberships.channel_id = channels.id
293 ";
294 sqlx::query_as(query)
295 .bind(user_id.0)
296 .fetch_all(&self.pool)
297 .await
298 })
299 }
300
301 pub async fn can_user_access_channel(
302 &self,
303 user_id: UserId,
304 channel_id: ChannelId,
305 ) -> Result<bool> {
306 test_support!(self, {
307 let query = "
308 SELECT id
309 FROM channel_memberships
310 WHERE user_id = $1 AND channel_id = $2
311 LIMIT 1
312 ";
313 sqlx::query_scalar::<_, i32>(query)
314 .bind(user_id.0)
315 .bind(channel_id.0)
316 .fetch_optional(&self.pool)
317 .await
318 .map(|e| e.is_some())
319 })
320 }
321
322 #[cfg(any(test, feature = "seed-support"))]
323 pub async fn add_channel_member(
324 &self,
325 channel_id: ChannelId,
326 user_id: UserId,
327 is_admin: bool,
328 ) -> Result<()> {
329 test_support!(self, {
330 let query = "
331 INSERT INTO channel_memberships (channel_id, user_id, admin)
332 VALUES ($1, $2, $3)
333 ON CONFLICT DO NOTHING
334 ";
335 sqlx::query(query)
336 .bind(channel_id.0)
337 .bind(user_id.0)
338 .bind(is_admin)
339 .execute(&self.pool)
340 .await
341 .map(drop)
342 })
343 }
344
345 // messages
346
347 pub async fn create_channel_message(
348 &self,
349 channel_id: ChannelId,
350 sender_id: UserId,
351 body: &str,
352 timestamp: OffsetDateTime,
353 nonce: u128,
354 ) -> Result<MessageId> {
355 test_support!(self, {
356 let query = "
357 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
358 VALUES ($1, $2, $3, $4, $5)
359 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
360 RETURNING id
361 ";
362 sqlx::query_scalar(query)
363 .bind(channel_id.0)
364 .bind(sender_id.0)
365 .bind(body)
366 .bind(timestamp)
367 .bind(Uuid::from_u128(nonce))
368 .fetch_one(&self.pool)
369 .await
370 .map(MessageId)
371 })
372 }
373
374 pub async fn get_channel_messages(
375 &self,
376 channel_id: ChannelId,
377 count: usize,
378 before_id: Option<MessageId>,
379 ) -> Result<Vec<ChannelMessage>> {
380 test_support!(self, {
381 let query = r#"
382 SELECT * FROM (
383 SELECT
384 id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
385 FROM
386 channel_messages
387 WHERE
388 channel_id = $1 AND
389 id < $2
390 ORDER BY id DESC
391 LIMIT $3
392 ) as recent_messages
393 ORDER BY id ASC
394 "#;
395 sqlx::query_as(query)
396 .bind(channel_id.0)
397 .bind(before_id.unwrap_or(MessageId::MAX))
398 .bind(count as i64)
399 .fetch_all(&self.pool)
400 .await
401 })
402 }
403}
404
/// Declares a `Copy` newtype wrapping an `i32` database id.
///
/// The wrapper is `#[sqlx(transparent)]` and `#[serde(transparent)]`, so it is encoded
/// exactly like the inner `i32` in queries and in serialized output. Each generated type
/// also gets:
/// - `MAX`: the id wrapping `i32::MAX`;
/// - `from_proto` / `to_proto`: `as`-cast conversions to and from a `u64` wire id.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Not every generated id type uses every helper; silence dead-code lints per item.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from its protocol representation.
            ///
            /// NOTE(review): values outside the `i32` range are silently truncated by the
            /// `as` cast.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Converts the id to its protocol representation. A negative inner value
            /// sign-extends to a large `u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let $name(raw) = *self;
                raw as u64
            }
        }
    };
}
428
// Id of a row in the `users` table.
id_type!(UserId);
/// A row of the `users` table, as decoded by `Db::get_all_users` and related queries.
#[derive(Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// Written by `Db::create_user` (new rows only) and `Db::set_user_is_admin`.
    pub admin: bool,
}
436
// Id of a row in the `orgs` table.
id_type!(OrgId);
/// A row of the `orgs` table, as returned by `Db::find_org_by_slug`.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Lookup key used by `Db::find_org_by_slug`.
    pub slug: String,
}
444
// Id of a row in the `signups` table.
id_type!(SignupId);
/// A row of the `signups` table, as created by `Db::create_signup`.
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
    pub id: SignupId,
    pub github_login: String,
    pub email_address: String,
    /// Free-form text supplied with the signup.
    pub about: String,
}
453
// Id of a row in the `channels` table.
id_type!(ChannelId);
/// The id and name of a channel. `Db::get_accessible_channels` selects exactly these two
/// columns.
#[derive(Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
}
460
// Id of a row in the `channel_messages` table.
id_type!(MessageId);
/// A message as returned by `Db::get_channel_messages`.
#[derive(Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub sender_id: UserId,
    pub body: String,
    /// Decoded from `sent_at AT TIME ZONE 'UTC'` in `Db::get_channel_messages`.
    pub sent_at: OffsetDateTime,
    /// The `u128` nonce passed to `Db::create_channel_message`, stored as a UUID. A repeated
    /// nonce resolves to the existing message rather than creating a new one.
    pub nonce: Uuid,
}
470
#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

    /// A throwaway Postgres database for a single test. It is created with a random name and
    /// migrated in `TestDb::new`, and torn down when the `TestDb` is dropped.
    pub struct TestDb {
        pub db: Db,
        pub name: String,
        pub url: String,
    }

    impl TestDb {
        /// Creates a uniquely named database on `postgres://postgres@localhost`, runs the
        /// crate's `migrations` directory against it, and returns it with `test_mode` enabled.
        ///
        /// # Panics
        /// Panics if the database cannot be created, connected to, or migrated.
        pub fn new() -> Self {
            // Enable tests to run in parallel by serializing the creation of each test database.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                {
                    // Both variants of the `LockResult` carry the guard (a poisoned lock still
                    // hands it back), so the lock is held for this scope either way.
                    let _lock = DB_CREATION.lock();
                    Postgres::create_database(&url)
                        .await
                        .expect("failed to create test db");
                }
                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });

            Self { db, name, url }
        }

        /// Borrows the underlying `Db`.
        pub fn db(&self) -> &Db {
            &self.db
        }
    }

    impl Drop for TestDb {
        /// Terminates other sessions on the test database, closes the pool, and drops the
        /// database.
        fn drop(&mut self) {
            block_on(async {
                // `$1` is the test database's name. Other sessions must be gone before
                // `DROP DATABASE` can succeed; the session running this query is excluded.
                let query = "
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = $1 AND pid <> pg_backend_pid();
                ";
                sqlx::query(query)
                    .bind(&self.name)
                    .execute(&self.db.pool)
                    .await
                    .unwrap();
                self.db.pool.close().await;
                Postgres::drop_database(&self.url).await.unwrap();
            });
        }
    }

    #[gpui::test]
    async fn test_get_users_by_ids() {
        let test_db = TestDb::new();
        let db = test_db.db();

        let user = db.create_user("user", false).await.unwrap();
        let friend1 = db.create_user("friend-1", false).await.unwrap();
        let friend2 = db.create_user("friend-2", false).await.unwrap();
        let friend3 = db.create_user("friend-3", false).await.unwrap();

        assert_eq!(
            db.get_users_by_ids([user, friend1, friend2, friend3].iter().copied())
                .await
                .unwrap(),
            vec![
                User {
                    id: user,
                    github_login: "user".to_string(),
                    admin: false,
                },
                User {
                    id: friend1,
                    github_login: "friend-1".to_string(),
                    admin: false,
                },
                User {
                    id: friend2,
                    github_login: "friend-2".to_string(),
                    admin: false,
                },
                User {
                    id: friend3,
                    github_login: "friend-3".to_string(),
                    admin: false,
                }
            ]
        );
    }

    #[gpui::test]
    async fn test_recent_channel_messages() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();
        for i in 0..10 {
            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
                .await
                .unwrap();
        }

        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
        assert_eq!(
            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["5", "6", "7", "8", "9"]
        );

        let prev_messages = db
            .get_channel_messages(channel, 4, Some(messages[0].id))
            .await
            .unwrap();
        assert_eq!(
            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["1", "2", "3", "4"]
        );
    }

    #[gpui::test]
    async fn test_channel_message_nonces() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();

        let msg1_id = db
            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg2_id = db
            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();
        let msg3_id = db
            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg4_id = db
            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();

        assert_ne!(msg1_id, msg2_id);
        assert_eq!(msg1_id, msg3_id);
        assert_eq!(msg2_id, msg4_id);
    }
}