1use anyhow::Context;
2use async_std::task::{block_on, yield_now};
3use serde::Serialize;
4use sqlx::{types::Uuid, FromRow, Result};
5use time::OffsetDateTime;
6
7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9
/// Wraps a query body so that `Db` methods can run it in one of two ways.
///
/// The tokens are turned into an `async` block. If `$self.test_mode` is false
/// that block is simply awaited. If it is true, the macro first yields to the
/// executor with `yield_now` and then drives the block to completion with
/// `block_on` on the current thread.
// NOTE(review): presumably test mode exists so sqlx I/O is driven by
// async-std's reactor rather than the test executor — confirm.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let fut = async { $($token)* };
        match $self.test_mode {
            true => {
                yield_now().await;
                block_on(fut)
            }
            false => fut.await,
        }
    }};
}
23
/// Handle to the Postgres database; every query method on `Db` runs against
/// `pool`.
#[derive(Clone)]
pub struct Db {
    /// Connection pool opened by `Db::new`.
    pool: sqlx::PgPool,
    /// When true, query bodies go through the `yield_now` + `block_on` path of
    /// `test_support!`. `Db::new` sets this to false; the test-only `TestDb`
    /// sets it to true.
    test_mode: bool,
}
29
30impl Db {
31 pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
32 let pool = DbOptions::new()
33 .max_connections(max_connections)
34 .connect(url)
35 .await
36 .context("failed to connect to postgres database")?;
37 Ok(Self {
38 pool,
39 test_mode: false,
40 })
41 }
42
43 // signups
44
45 pub async fn create_signup(
46 &self,
47 github_login: &str,
48 email_address: &str,
49 about: &str,
50 ) -> Result<SignupId> {
51 test_support!(self, {
52 let query = "
53 INSERT INTO signups (github_login, email_address, about)
54 VALUES ($1, $2, $3)
55 RETURNING id
56 ";
57 sqlx::query_scalar(query)
58 .bind(github_login)
59 .bind(email_address)
60 .bind(about)
61 .fetch_one(&self.pool)
62 .await
63 .map(SignupId)
64 })
65 }
66
67 pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
68 test_support!(self, {
69 let query = "SELECT * FROM signups ORDER BY github_login ASC";
70 sqlx::query_as(query).fetch_all(&self.pool).await
71 })
72 }
73
74 pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
75 test_support!(self, {
76 let query = "DELETE FROM signups WHERE id = $1";
77 sqlx::query(query)
78 .bind(id.0)
79 .execute(&self.pool)
80 .await
81 .map(drop)
82 })
83 }
84
85 // users
86
87 pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
88 test_support!(self, {
89 let query = "
90 INSERT INTO users (github_login, admin)
91 VALUES ($1, $2)
92 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
93 RETURNING id
94 ";
95 sqlx::query_scalar(query)
96 .bind(github_login)
97 .bind(admin)
98 .fetch_one(&self.pool)
99 .await
100 .map(UserId)
101 })
102 }
103
104 pub async fn get_all_users(&self) -> Result<Vec<User>> {
105 test_support!(self, {
106 let query = "SELECT * FROM users ORDER BY github_login ASC";
107 sqlx::query_as(query).fetch_all(&self.pool).await
108 })
109 }
110
111 pub async fn get_users_by_ids(
112 &self,
113 ids: impl IntoIterator<Item = UserId>,
114 ) -> Result<Vec<User>> {
115 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
116 test_support!(self, {
117 let query = "
118 SELECT users.*
119 FROM users
120 WHERE users.id = ANY ($1)
121 ";
122
123 sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
124 })
125 }
126
127 pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
128 test_support!(self, {
129 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
130 sqlx::query_as(query)
131 .bind(github_login)
132 .fetch_optional(&self.pool)
133 .await
134 })
135 }
136
137 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
138 test_support!(self, {
139 let query = "UPDATE users SET admin = $1 WHERE id = $2";
140 sqlx::query(query)
141 .bind(is_admin)
142 .bind(id.0)
143 .execute(&self.pool)
144 .await
145 .map(drop)
146 })
147 }
148
149 pub async fn delete_user(&self, id: UserId) -> Result<()> {
150 test_support!(self, {
151 let query = "DELETE FROM users WHERE id = $1;";
152 sqlx::query(query)
153 .bind(id.0)
154 .execute(&self.pool)
155 .await
156 .map(drop)
157 })
158 }
159
160 // access tokens
161
162 pub async fn create_access_token_hash(
163 &self,
164 user_id: UserId,
165 access_token_hash: String,
166 ) -> Result<()> {
167 test_support!(self, {
168 let query = "
169 INSERT INTO access_tokens (user_id, hash)
170 VALUES ($1, $2)
171 ";
172 sqlx::query(query)
173 .bind(user_id.0)
174 .bind(access_token_hash)
175 .execute(&self.pool)
176 .await
177 .map(drop)
178 })
179 }
180
181 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
182 test_support!(self, {
183 let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
184 sqlx::query_scalar(query)
185 .bind(user_id.0)
186 .fetch_all(&self.pool)
187 .await
188 })
189 }
190
191 // orgs
192
193 #[allow(unused)] // Help rust-analyzer
194 #[cfg(any(test, feature = "seed-support"))]
195 pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
196 test_support!(self, {
197 let query = "
198 SELECT *
199 FROM orgs
200 WHERE slug = $1
201 ";
202 sqlx::query_as(query)
203 .bind(slug)
204 .fetch_optional(&self.pool)
205 .await
206 })
207 }
208
209 #[cfg(any(test, feature = "seed-support"))]
210 pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
211 test_support!(self, {
212 let query = "
213 INSERT INTO orgs (name, slug)
214 VALUES ($1, $2)
215 RETURNING id
216 ";
217 sqlx::query_scalar(query)
218 .bind(name)
219 .bind(slug)
220 .fetch_one(&self.pool)
221 .await
222 .map(OrgId)
223 })
224 }
225
226 #[cfg(any(test, feature = "seed-support"))]
227 pub async fn add_org_member(
228 &self,
229 org_id: OrgId,
230 user_id: UserId,
231 is_admin: bool,
232 ) -> Result<()> {
233 test_support!(self, {
234 let query = "
235 INSERT INTO org_memberships (org_id, user_id, admin)
236 VALUES ($1, $2, $3)
237 ON CONFLICT DO NOTHING
238 ";
239 sqlx::query(query)
240 .bind(org_id.0)
241 .bind(user_id.0)
242 .bind(is_admin)
243 .execute(&self.pool)
244 .await
245 .map(drop)
246 })
247 }
248
249 // channels
250
251 #[cfg(any(test, feature = "seed-support"))]
252 pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
253 test_support!(self, {
254 let query = "
255 INSERT INTO channels (owner_id, owner_is_user, name)
256 VALUES ($1, false, $2)
257 RETURNING id
258 ";
259 sqlx::query_scalar(query)
260 .bind(org_id.0)
261 .bind(name)
262 .fetch_one(&self.pool)
263 .await
264 .map(ChannelId)
265 })
266 }
267
268 #[allow(unused)] // Help rust-analyzer
269 #[cfg(any(test, feature = "seed-support"))]
270 pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
271 test_support!(self, {
272 let query = "
273 SELECT *
274 FROM channels
275 WHERE
276 channels.owner_is_user = false AND
277 channels.owner_id = $1
278 ";
279 sqlx::query_as(query)
280 .bind(org_id.0)
281 .fetch_all(&self.pool)
282 .await
283 })
284 }
285
286 pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
287 test_support!(self, {
288 let query = "
289 SELECT
290 channels.id, channels.name
291 FROM
292 channel_memberships, channels
293 WHERE
294 channel_memberships.user_id = $1 AND
295 channel_memberships.channel_id = channels.id
296 ";
297 sqlx::query_as(query)
298 .bind(user_id.0)
299 .fetch_all(&self.pool)
300 .await
301 })
302 }
303
304 pub async fn can_user_access_channel(
305 &self,
306 user_id: UserId,
307 channel_id: ChannelId,
308 ) -> Result<bool> {
309 test_support!(self, {
310 let query = "
311 SELECT id
312 FROM channel_memberships
313 WHERE user_id = $1 AND channel_id = $2
314 LIMIT 1
315 ";
316 sqlx::query_scalar::<_, i32>(query)
317 .bind(user_id.0)
318 .bind(channel_id.0)
319 .fetch_optional(&self.pool)
320 .await
321 .map(|e| e.is_some())
322 })
323 }
324
325 #[cfg(any(test, feature = "seed-support"))]
326 pub async fn add_channel_member(
327 &self,
328 channel_id: ChannelId,
329 user_id: UserId,
330 is_admin: bool,
331 ) -> Result<()> {
332 test_support!(self, {
333 let query = "
334 INSERT INTO channel_memberships (channel_id, user_id, admin)
335 VALUES ($1, $2, $3)
336 ON CONFLICT DO NOTHING
337 ";
338 sqlx::query(query)
339 .bind(channel_id.0)
340 .bind(user_id.0)
341 .bind(is_admin)
342 .execute(&self.pool)
343 .await
344 .map(drop)
345 })
346 }
347
348 // messages
349
350 pub async fn create_channel_message(
351 &self,
352 channel_id: ChannelId,
353 sender_id: UserId,
354 body: &str,
355 timestamp: OffsetDateTime,
356 nonce: u128,
357 ) -> Result<MessageId> {
358 test_support!(self, {
359 let query = "
360 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
361 VALUES ($1, $2, $3, $4, $5)
362 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
363 RETURNING id
364 ";
365 sqlx::query_scalar(query)
366 .bind(channel_id.0)
367 .bind(sender_id.0)
368 .bind(body)
369 .bind(timestamp)
370 .bind(Uuid::from_u128(nonce))
371 .fetch_one(&self.pool)
372 .await
373 .map(MessageId)
374 })
375 }
376
377 pub async fn get_channel_messages(
378 &self,
379 channel_id: ChannelId,
380 count: usize,
381 before_id: Option<MessageId>,
382 ) -> Result<Vec<ChannelMessage>> {
383 test_support!(self, {
384 let query = r#"
385 SELECT * FROM (
386 SELECT
387 id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
388 FROM
389 channel_messages
390 WHERE
391 channel_id = $1 AND
392 id < $2
393 ORDER BY id DESC
394 LIMIT $3
395 ) as recent_messages
396 ORDER BY id ASC
397 "#;
398 sqlx::query_as(query)
399 .bind(channel_id.0)
400 .bind(before_id.unwrap_or(MessageId::MAX))
401 .bind(count as i64)
402 .fetch_all(&self.pool)
403 .await
404 })
405 }
406}
407
/// Declares `pub struct $name(pub i32)`, a `Copy` newtype for a database id.
///
/// Both sqlx and serde treat the type as transparent, so it is stored and
/// serialized exactly like the inner `i32`.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        // Not every generated id type uses every helper.
        #[allow(unused)]
        impl $name {
            /// The largest representable id (`i32::MAX`).
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from its protocol form. This is a plain `as`
            /// cast, so values above `i32::MAX` wrap rather than fail.
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Returns the protocol form of the id. This is a plain `as`
            /// cast, so a negative id would wrap to a large `u64`.
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }
    };
}
431
id_type!(UserId);
/// A row of the `users` table (read via `SELECT * FROM users ...`).
#[derive(Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// The `admin` column; written by `Db::create_user` and
    /// `Db::set_user_is_admin`.
    pub admin: bool,
}
439
440id_type!(OrgId);
441#[derive(FromRow)]
442pub struct Org {
443 pub id: OrgId,
444 pub name: String,
445 pub slug: String,
446}
447
id_type!(SignupId);
/// A row of the `signups` table, as returned by `Db::get_all_signups`.
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
    pub id: SignupId,
    pub github_login: String,
    pub email_address: String,
    pub about: String,
}
456
id_type!(ChannelId);
/// A channel as returned by `Db::get_org_channels` and
/// `Db::get_accessible_channels`. Only `id` and `name` are mapped; the
/// `owner_id` / `owner_is_user` columns are not carried here.
#[derive(Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
}
463
id_type!(MessageId);
/// A message row as returned by `Db::get_channel_messages`.
#[derive(Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub sender_id: UserId,
    pub body: String,
    /// Selected as `sent_at AT TIME ZONE 'UTC'` by `Db::get_channel_messages`.
    pub sent_at: OffsetDateTime,
    /// Caller-supplied nonce (stored via `Uuid::from_u128`);
    /// `Db::create_channel_message` deduplicates on it.
    pub nonce: Uuid,
}
473
#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

    /// A throwaway Postgres database for a single test.
    ///
    /// It is created with a random `zed-test-<n>` name on the local server,
    /// migrated, and dropped again when the `TestDb` is dropped.
    pub struct TestDb {
        pub db: Db,
        pub name: String,
        pub url: String,
    }

    impl TestDb {
        /// Creates and migrates a fresh test database; the returned `Db` has
        /// `test_mode` enabled.
        ///
        /// # Panics
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub fn new() -> Self {
            // Enable tests to run in parallel by serializing the creation of each test database.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                {
                    let _lock = DB_CREATION.lock();
                    Postgres::create_database(&url)
                        .await
                        .expect("failed to create test db");
                }
                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });

            Self { db, name, url }
        }

        /// Returns the wrapped `Db` handle.
        pub fn db(&self) -> &Db {
            &self.db
        }
    }

    impl Drop for TestDb {
        fn drop(&mut self) {
            block_on(async {
                // Disconnect every other session on this test database so that
                // DROP DATABASE is not refused because of them. The name must
                // be a bind parameter (`$1`); a quoted `'{}'` is compared
                // literally and matches nothing.
                let query = "
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = $1 AND pid <> pg_backend_pid();
                ";
                sqlx::query(query)
                    .bind(&self.name)
                    .execute(&self.db.pool)
                    .await
                    .unwrap();
                self.db.pool.close().await;
                Postgres::drop_database(&self.url).await.unwrap();
            });
        }
    }

    #[gpui::test]
    async fn test_get_users_by_ids() {
        let test_db = TestDb::new();
        let db = test_db.db();

        let user = db.create_user("user", false).await.unwrap();
        let friend1 = db.create_user("friend-1", false).await.unwrap();
        let friend2 = db.create_user("friend-2", false).await.unwrap();
        let friend3 = db.create_user("friend-3", false).await.unwrap();

        assert_eq!(
            db.get_users_by_ids([user, friend1, friend2, friend3])
                .await
                .unwrap(),
            vec![
                User {
                    id: user,
                    github_login: "user".to_string(),
                    admin: false,
                },
                User {
                    id: friend1,
                    github_login: "friend-1".to_string(),
                    admin: false,
                },
                User {
                    id: friend2,
                    github_login: "friend-2".to_string(),
                    admin: false,
                },
                User {
                    id: friend3,
                    github_login: "friend-3".to_string(),
                    admin: false,
                }
            ]
        );
    }

    #[gpui::test]
    async fn test_recent_channel_messages() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();
        for i in 0..10 {
            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
                .await
                .unwrap();
        }

        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
        assert_eq!(
            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["5", "6", "7", "8", "9"]
        );

        let prev_messages = db
            .get_channel_messages(channel, 4, Some(messages[0].id))
            .await
            .unwrap();
        assert_eq!(
            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["1", "2", "3", "4"]
        );
    }

    #[gpui::test]
    async fn test_channel_message_nonces() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();

        let msg1_id = db
            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg2_id = db
            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();
        let msg3_id = db
            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg4_id = db
            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();

        // Reusing a nonce returns the original message's id.
        assert_ne!(msg1_id, msg2_id);
        assert_eq!(msg1_id, msg3_id);
        assert_eq!(msg2_id, msg4_id);
    }
}