db.rs

  1use anyhow::Context;
  2use async_std::task::{block_on, yield_now};
  3use serde::Serialize;
  4use sqlx::{types::Uuid, FromRow, Result};
  5use time::OffsetDateTime;
  6
  7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
  8pub use sqlx::postgres::PgPoolOptions as DbOptions;
  9
 10macro_rules! test_support {
 11    ($self:ident, { $($token:tt)* }) => {{
 12        let body = async {
 13            $($token)*
 14        };
 15        if $self.test_mode {
 16            yield_now().await;
 17            block_on(body)
 18        } else {
 19            body.await
 20        }
 21    }};
 22}
 23
 24#[derive(Clone)]
 25pub struct Db {
 26    pool: sqlx::PgPool,
 27    test_mode: bool,
 28}
 29
 30impl Db {
 31    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
 32        let pool = DbOptions::new()
 33            .max_connections(max_connections)
 34            .connect(url)
 35            .await
 36            .context("failed to connect to postgres database")?;
 37        Ok(Self {
 38            pool,
 39            test_mode: false,
 40        })
 41    }
 42
 43    // signups
 44
 45    pub async fn create_signup(
 46        &self,
 47        github_login: &str,
 48        email_address: &str,
 49        about: &str,
 50    ) -> Result<SignupId> {
 51        test_support!(self, {
 52            let query = "
 53                INSERT INTO signups (github_login, email_address, about)
 54                VALUES ($1, $2, $3)
 55                RETURNING id
 56            ";
 57            sqlx::query_scalar(query)
 58                .bind(github_login)
 59                .bind(email_address)
 60                .bind(about)
 61                .fetch_one(&self.pool)
 62                .await
 63                .map(SignupId)
 64        })
 65    }
 66
 67    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
 68        test_support!(self, {
 69            let query = "SELECT * FROM signups ORDER BY github_login ASC";
 70            sqlx::query_as(query).fetch_all(&self.pool).await
 71        })
 72    }
 73
 74    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
 75        test_support!(self, {
 76            let query = "DELETE FROM signups WHERE id = $1";
 77            sqlx::query(query)
 78                .bind(id.0)
 79                .execute(&self.pool)
 80                .await
 81                .map(drop)
 82        })
 83    }
 84
 85    // users
 86
 87    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 88        test_support!(self, {
 89            let query = "
 90                INSERT INTO users (github_login, admin)
 91                VALUES ($1, $2)
 92                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 93                RETURNING id
 94            ";
 95            sqlx::query_scalar(query)
 96                .bind(github_login)
 97                .bind(admin)
 98                .fetch_one(&self.pool)
 99                .await
100                .map(UserId)
101        })
102    }
103
104    pub async fn get_all_users(&self) -> Result<Vec<User>> {
105        test_support!(self, {
106            let query = "SELECT * FROM users ORDER BY github_login ASC";
107            sqlx::query_as(query).fetch_all(&self.pool).await
108        })
109    }
110
111    pub async fn get_users_by_ids(
112        &self,
113        ids: impl IntoIterator<Item = UserId>,
114    ) -> Result<Vec<User>> {
115        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
116        test_support!(self, {
117            let query = "
118                SELECT users.*
119                FROM users
120                WHERE users.id = ANY ($1)
121            ";
122
123            sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
124        })
125    }
126
127    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
128        test_support!(self, {
129            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
130            sqlx::query_as(query)
131                .bind(github_login)
132                .fetch_optional(&self.pool)
133                .await
134        })
135    }
136
137    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
138        test_support!(self, {
139            let query = "UPDATE users SET admin = $1 WHERE id = $2";
140            sqlx::query(query)
141                .bind(is_admin)
142                .bind(id.0)
143                .execute(&self.pool)
144                .await
145                .map(drop)
146        })
147    }
148
149    pub async fn delete_user(&self, id: UserId) -> Result<()> {
150        test_support!(self, {
151            let query = "DELETE FROM users WHERE id = $1;";
152            sqlx::query(query)
153                .bind(id.0)
154                .execute(&self.pool)
155                .await
156                .map(drop)
157        })
158    }
159
160    // access tokens
161
162    pub async fn create_access_token_hash(
163        &self,
164        user_id: UserId,
165        access_token_hash: String,
166    ) -> Result<()> {
167        test_support!(self, {
168            let query = "
169            INSERT INTO access_tokens (user_id, hash)
170            VALUES ($1, $2)
171        ";
172            sqlx::query(query)
173                .bind(user_id.0)
174                .bind(access_token_hash)
175                .execute(&self.pool)
176                .await
177                .map(drop)
178        })
179    }
180
181    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
182        test_support!(self, {
183            let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
184            sqlx::query_scalar(query)
185                .bind(user_id.0)
186                .fetch_all(&self.pool)
187                .await
188        })
189    }
190
191    // orgs
192
193    #[allow(unused)] // Help rust-analyzer
194    #[cfg(any(test, feature = "seed-support"))]
195    pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
196        test_support!(self, {
197            let query = "
198                SELECT *
199                FROM orgs
200                WHERE slug = $1
201            ";
202            sqlx::query_as(query)
203                .bind(slug)
204                .fetch_optional(&self.pool)
205                .await
206        })
207    }
208
209    #[cfg(any(test, feature = "seed-support"))]
210    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
211        test_support!(self, {
212            let query = "
213                INSERT INTO orgs (name, slug)
214                VALUES ($1, $2)
215                RETURNING id
216            ";
217            sqlx::query_scalar(query)
218                .bind(name)
219                .bind(slug)
220                .fetch_one(&self.pool)
221                .await
222                .map(OrgId)
223        })
224    }
225
226    #[cfg(any(test, feature = "seed-support"))]
227    pub async fn add_org_member(
228        &self,
229        org_id: OrgId,
230        user_id: UserId,
231        is_admin: bool,
232    ) -> Result<()> {
233        test_support!(self, {
234            let query = "
235                INSERT INTO org_memberships (org_id, user_id, admin)
236                VALUES ($1, $2, $3)
237                ON CONFLICT DO NOTHING
238            ";
239            sqlx::query(query)
240                .bind(org_id.0)
241                .bind(user_id.0)
242                .bind(is_admin)
243                .execute(&self.pool)
244                .await
245                .map(drop)
246        })
247    }
248
249    // channels
250
251    #[cfg(any(test, feature = "seed-support"))]
252    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
253        test_support!(self, {
254            let query = "
255                INSERT INTO channels (owner_id, owner_is_user, name)
256                VALUES ($1, false, $2)
257                RETURNING id
258            ";
259            sqlx::query_scalar(query)
260                .bind(org_id.0)
261                .bind(name)
262                .fetch_one(&self.pool)
263                .await
264                .map(ChannelId)
265        })
266    }
267
268    #[allow(unused)] // Help rust-analyzer
269    #[cfg(any(test, feature = "seed-support"))]
270    pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
271        test_support!(self, {
272            let query = "
273                SELECT *
274                FROM channels
275                WHERE
276                    channels.owner_is_user = false AND
277                    channels.owner_id = $1
278            ";
279            sqlx::query_as(query)
280                .bind(org_id.0)
281                .fetch_all(&self.pool)
282                .await
283        })
284    }
285
286    pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
287        test_support!(self, {
288            let query = "
289                SELECT
290                    channels.id, channels.name
291                FROM
292                    channel_memberships, channels
293                WHERE
294                    channel_memberships.user_id = $1 AND
295                    channel_memberships.channel_id = channels.id
296            ";
297            sqlx::query_as(query)
298                .bind(user_id.0)
299                .fetch_all(&self.pool)
300                .await
301        })
302    }
303
304    pub async fn can_user_access_channel(
305        &self,
306        user_id: UserId,
307        channel_id: ChannelId,
308    ) -> Result<bool> {
309        test_support!(self, {
310            let query = "
311                SELECT id
312                FROM channel_memberships
313                WHERE user_id = $1 AND channel_id = $2
314                LIMIT 1
315            ";
316            sqlx::query_scalar::<_, i32>(query)
317                .bind(user_id.0)
318                .bind(channel_id.0)
319                .fetch_optional(&self.pool)
320                .await
321                .map(|e| e.is_some())
322        })
323    }
324
325    #[cfg(any(test, feature = "seed-support"))]
326    pub async fn add_channel_member(
327        &self,
328        channel_id: ChannelId,
329        user_id: UserId,
330        is_admin: bool,
331    ) -> Result<()> {
332        test_support!(self, {
333            let query = "
334                INSERT INTO channel_memberships (channel_id, user_id, admin)
335                VALUES ($1, $2, $3)
336                ON CONFLICT DO NOTHING
337            ";
338            sqlx::query(query)
339                .bind(channel_id.0)
340                .bind(user_id.0)
341                .bind(is_admin)
342                .execute(&self.pool)
343                .await
344                .map(drop)
345        })
346    }
347
348    // messages
349
350    pub async fn create_channel_message(
351        &self,
352        channel_id: ChannelId,
353        sender_id: UserId,
354        body: &str,
355        timestamp: OffsetDateTime,
356        nonce: u128,
357    ) -> Result<MessageId> {
358        test_support!(self, {
359            let query = "
360                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
361                VALUES ($1, $2, $3, $4, $5)
362                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
363                RETURNING id
364            ";
365            sqlx::query_scalar(query)
366                .bind(channel_id.0)
367                .bind(sender_id.0)
368                .bind(body)
369                .bind(timestamp)
370                .bind(Uuid::from_u128(nonce))
371                .fetch_one(&self.pool)
372                .await
373                .map(MessageId)
374        })
375    }
376
377    pub async fn get_channel_messages(
378        &self,
379        channel_id: ChannelId,
380        count: usize,
381        before_id: Option<MessageId>,
382    ) -> Result<Vec<ChannelMessage>> {
383        test_support!(self, {
384            let query = r#"
385                SELECT * FROM (
386                    SELECT
387                        id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
388                    FROM
389                        channel_messages
390                    WHERE
391                        channel_id = $1 AND
392                        id < $2
393                    ORDER BY id DESC
394                    LIMIT $3
395                ) as recent_messages
396                ORDER BY id ASC
397            "#;
398            sqlx::query_as(query)
399                .bind(channel_id.0)
400                .bind(before_id.unwrap_or(MessageId::MAX))
401                .bind(count as i64)
402                .fetch_all(&self.pool)
403                .await
404        })
405    }
406}
407
// Declares a public newtype `$name(pub i32)` for a database id.
//
// The generated type is transparent for both sqlx and serde, and gets:
// * `MAX`        - the wrapper around `i32::MAX`;
// * `from_proto` - builds an id from a `u64` using an `as i32` cast, which
//                  keeps only the low 32 bits of the value;
// * `to_proto`   - converts back with an `as u64` cast, which sign-extends
//                  negative ids.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                Self(raw)
            }

            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let Self(raw) = *self;
                raw as u64
            }
        }
    };
}
431
432id_type!(UserId);
433#[derive(Debug, FromRow, Serialize, PartialEq)]
434pub struct User {
435    pub id: UserId,
436    pub github_login: String,
437    pub admin: bool,
438}
439
440id_type!(OrgId);
441#[derive(FromRow)]
442pub struct Org {
443    pub id: OrgId,
444    pub name: String,
445    pub slug: String,
446}
447
448id_type!(SignupId);
449#[derive(Debug, FromRow, Serialize)]
450pub struct Signup {
451    pub id: SignupId,
452    pub github_login: String,
453    pub email_address: String,
454    pub about: String,
455}
456
457id_type!(ChannelId);
458#[derive(Debug, FromRow, Serialize)]
459pub struct Channel {
460    pub id: ChannelId,
461    pub name: String,
462}
463
464id_type!(MessageId);
465#[derive(Debug, FromRow)]
466pub struct ChannelMessage {
467    pub id: MessageId,
468    pub sender_id: UserId,
469    pub body: String,
470    pub sent_at: OffsetDateTime,
471    pub nonce: Uuid,
472}
473
474#[cfg(test)]
475pub mod tests {
476    use super::*;
477    use rand::prelude::*;
478    use sqlx::{
479        migrate::{MigrateDatabase, Migrator},
480        Postgres,
481    };
482    use std::path::Path;
483
484    pub struct TestDb {
485        pub db: Db,
486        pub name: String,
487        pub url: String,
488    }
489
490    impl TestDb {
491        pub fn new() -> Self {
492            // Enable tests to run in parallel by serializing the creation of each test database.
493            lazy_static::lazy_static! {
494                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
495            }
496
497            let mut rng = StdRng::from_entropy();
498            let name = format!("zed-test-{}", rng.gen::<u128>());
499            let url = format!("postgres://postgres@localhost/{}", name);
500            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
501            let db = block_on(async {
502                {
503                    let _lock = DB_CREATION.lock();
504                    Postgres::create_database(&url)
505                        .await
506                        .expect("failed to create test db");
507                }
508                let mut db = Db::new(&url, 5).await.unwrap();
509                db.test_mode = true;
510                let migrator = Migrator::new(migrations_path).await.unwrap();
511                migrator.run(&db.pool).await.unwrap();
512                db
513            });
514
515            Self { db, name, url }
516        }
517
518        pub fn db(&self) -> &Db {
519            &self.db
520        }
521    }
522
523    impl Drop for TestDb {
524        fn drop(&mut self) {
525            block_on(async {
526                let query = "
527                    SELECT pg_terminate_backend(pg_stat_activity.pid)
528                    FROM pg_stat_activity
529                    WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
530                ";
531                sqlx::query(query)
532                    .bind(&self.name)
533                    .execute(&self.db.pool)
534                    .await
535                    .unwrap();
536                self.db.pool.close().await;
537                Postgres::drop_database(&self.url).await.unwrap();
538            });
539        }
540    }
541
542    #[gpui::test]
543    async fn test_get_users_by_ids() {
544        let test_db = TestDb::new();
545        let db = test_db.db();
546
547        let user = db.create_user("user", false).await.unwrap();
548        let friend1 = db.create_user("friend-1", false).await.unwrap();
549        let friend2 = db.create_user("friend-2", false).await.unwrap();
550        let friend3 = db.create_user("friend-3", false).await.unwrap();
551
552        assert_eq!(
553            db.get_users_by_ids([user, friend1, friend2, friend3])
554                .await
555                .unwrap(),
556            vec![
557                User {
558                    id: user,
559                    github_login: "user".to_string(),
560                    admin: false,
561                },
562                User {
563                    id: friend1,
564                    github_login: "friend-1".to_string(),
565                    admin: false,
566                },
567                User {
568                    id: friend2,
569                    github_login: "friend-2".to_string(),
570                    admin: false,
571                },
572                User {
573                    id: friend3,
574                    github_login: "friend-3".to_string(),
575                    admin: false,
576                }
577            ]
578        );
579    }
580
581    #[gpui::test]
582    async fn test_recent_channel_messages() {
583        let test_db = TestDb::new();
584        let db = test_db.db();
585        let user = db.create_user("user", false).await.unwrap();
586        let org = db.create_org("org", "org").await.unwrap();
587        let channel = db.create_org_channel(org, "channel").await.unwrap();
588        for i in 0..10 {
589            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
590                .await
591                .unwrap();
592        }
593
594        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
595        assert_eq!(
596            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
597            ["5", "6", "7", "8", "9"]
598        );
599
600        let prev_messages = db
601            .get_channel_messages(channel, 4, Some(messages[0].id))
602            .await
603            .unwrap();
604        assert_eq!(
605            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
606            ["1", "2", "3", "4"]
607        );
608    }
609
610    #[gpui::test]
611    async fn test_channel_message_nonces() {
612        let test_db = TestDb::new();
613        let db = test_db.db();
614        let user = db.create_user("user", false).await.unwrap();
615        let org = db.create_org("org", "org").await.unwrap();
616        let channel = db.create_org_channel(org, "channel").await.unwrap();
617
618        let msg1_id = db
619            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
620            .await
621            .unwrap();
622        let msg2_id = db
623            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
624            .await
625            .unwrap();
626        let msg3_id = db
627            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
628            .await
629            .unwrap();
630        let msg4_id = db
631            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
632            .await
633            .unwrap();
634
635        assert_ne!(msg1_id, msg2_id);
636        assert_eq!(msg1_id, msg3_id);
637        assert_eq!(msg2_id, msg4_id);
638    }
639}