db.rs

  1use anyhow::Context;
  2use async_std::task::{block_on, yield_now};
  3use serde::Serialize;
  4use sqlx::{types::Uuid, FromRow, Result};
  5use time::OffsetDateTime;
  6
  7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
  8pub use sqlx::postgres::PgPoolOptions as DbOptions;
  9
 10macro_rules! test_support {
 11    ($self:ident, { $($token:tt)* }) => {{
 12        let body = async {
 13            $($token)*
 14        };
 15        if $self.test_mode {
 16            yield_now().await;
 17            block_on(body)
 18        } else {
 19            body.await
 20        }
 21    }};
 22}
 23
 24#[derive(Clone)]
 25pub struct Db {
 26    pool: sqlx::PgPool,
 27    test_mode: bool,
 28}
 29
 30impl Db {
 31    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
 32        let pool = DbOptions::new()
 33            .max_connections(max_connections)
 34            .connect(url)
 35            .await
 36            .context("failed to connect to postgres database")?;
 37        Ok(Self {
 38            pool,
 39            test_mode: false,
 40        })
 41    }
 42
 43    // signups
 44
 45    pub async fn create_signup(
 46        &self,
 47        github_login: &str,
 48        email_address: &str,
 49        about: &str,
 50        wants_releases: bool,
 51        wants_updates: bool,
 52        wants_community: bool,
 53    ) -> Result<SignupId> {
 54        test_support!(self, {
 55            let query = "
 56                INSERT INTO signups (
 57                    github_login,
 58                    email_address,
 59                    about,
 60                    wants_releases,
 61                    wants_updates,
 62                    wants_community
 63                )
 64                VALUES ($1, $2, $3, $4, $5, $6)
 65                RETURNING id
 66            ";
 67            sqlx::query_scalar(query)
 68                .bind(github_login)
 69                .bind(email_address)
 70                .bind(about)
 71                .bind(wants_releases)
 72                .bind(wants_updates)
 73                .bind(wants_community)
 74                .fetch_one(&self.pool)
 75                .await
 76                .map(SignupId)
 77        })
 78    }
 79
 80    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
 81        test_support!(self, {
 82            let query = "SELECT * FROM signups ORDER BY github_login ASC";
 83            sqlx::query_as(query).fetch_all(&self.pool).await
 84        })
 85    }
 86
 87    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
 88        test_support!(self, {
 89            let query = "DELETE FROM signups WHERE id = $1";
 90            sqlx::query(query)
 91                .bind(id.0)
 92                .execute(&self.pool)
 93                .await
 94                .map(drop)
 95        })
 96    }
 97
 98    // users
 99
100    #[allow(unused)] // Help rust-analyzer
101    #[cfg(any(test, feature = "seed-support"))]
102    pub async fn get_user(&self, github_login: &str) -> Result<Option<UserId>> {
103        test_support!(self, {
104            let query = "
105                SELECT id
106                FROM users
107                WHERE github_login = $1
108            ";
109            sqlx::query_scalar(query)
110                .bind(github_login)
111                .fetch_optional(&self.pool)
112                .await
113        })
114    }
115
116    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
117        test_support!(self, {
118            let query = "
119                INSERT INTO users (github_login, admin)
120                VALUES ($1, $2)
121                RETURNING id
122            ";
123            sqlx::query_scalar(query)
124                .bind(github_login)
125                .bind(admin)
126                .fetch_one(&self.pool)
127                .await
128                .map(UserId)
129        })
130    }
131
132    pub async fn get_all_users(&self) -> Result<Vec<User>> {
133        test_support!(self, {
134            let query = "SELECT * FROM users ORDER BY github_login ASC";
135            sqlx::query_as(query).fetch_all(&self.pool).await
136        })
137    }
138
139    pub async fn get_users_by_ids(
140        &self,
141        requester_id: UserId,
142        ids: impl Iterator<Item = UserId>,
143    ) -> Result<Vec<User>> {
144        let mut include_requester = false;
145        let ids = ids
146            .map(|id| {
147                if id == requester_id {
148                    include_requester = true;
149                }
150                id.0
151            })
152            .collect::<Vec<_>>();
153
154        test_support!(self, {
155            // Only return users that are in a common channel with the requesting user.
156            // Also allow the requesting user to return their own data, even if they aren't
157            // in any channels.
158            let query = "
159                SELECT
160                    users.*
161                FROM
162                    users, channel_memberships
163                WHERE
164                    users.id = ANY ($1) AND
165                    channel_memberships.user_id = users.id AND
166                    channel_memberships.channel_id IN (
167                        SELECT channel_id
168                        FROM channel_memberships
169                        WHERE channel_memberships.user_id = $2
170                    )
171                UNION
172                SELECT
173                    users.*
174                FROM
175                    users
176                WHERE
177                    $3 AND users.id = $2
178            ";
179
180            sqlx::query_as(query)
181                .bind(&ids)
182                .bind(requester_id)
183                .bind(include_requester)
184                .fetch_all(&self.pool)
185                .await
186        })
187    }
188
189    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
190        test_support!(self, {
191            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
192            sqlx::query_as(query)
193                .bind(github_login)
194                .fetch_optional(&self.pool)
195                .await
196        })
197    }
198
199    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
200        test_support!(self, {
201            let query = "UPDATE users SET admin = $1 WHERE id = $2";
202            sqlx::query(query)
203                .bind(is_admin)
204                .bind(id.0)
205                .execute(&self.pool)
206                .await
207                .map(drop)
208        })
209    }
210
211    pub async fn delete_user(&self, id: UserId) -> Result<()> {
212        test_support!(self, {
213            let query = "DELETE FROM users WHERE id = $1;";
214            sqlx::query(query)
215                .bind(id.0)
216                .execute(&self.pool)
217                .await
218                .map(drop)
219        })
220    }
221
222    // access tokens
223
224    pub async fn create_access_token_hash(
225        &self,
226        user_id: UserId,
227        access_token_hash: String,
228    ) -> Result<()> {
229        test_support!(self, {
230            let query = "
231            INSERT INTO access_tokens (user_id, hash)
232            VALUES ($1, $2)
233        ";
234            sqlx::query(query)
235                .bind(user_id.0)
236                .bind(access_token_hash)
237                .execute(&self.pool)
238                .await
239                .map(drop)
240        })
241    }
242
243    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
244        test_support!(self, {
245            let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
246            sqlx::query_scalar(query)
247                .bind(user_id.0)
248                .fetch_all(&self.pool)
249                .await
250        })
251    }
252
253    // orgs
254
255    #[allow(unused)] // Help rust-analyzer
256    #[cfg(any(test, feature = "seed-support"))]
257    pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
258        test_support!(self, {
259            let query = "
260                SELECT *
261                FROM orgs
262                WHERE slug = $1
263            ";
264            sqlx::query_as(query)
265                .bind(slug)
266                .fetch_optional(&self.pool)
267                .await
268        })
269    }
270
271    #[cfg(any(test, feature = "seed-support"))]
272    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
273        test_support!(self, {
274            let query = "
275                INSERT INTO orgs (name, slug)
276                VALUES ($1, $2)
277                RETURNING id
278            ";
279            sqlx::query_scalar(query)
280                .bind(name)
281                .bind(slug)
282                .fetch_one(&self.pool)
283                .await
284                .map(OrgId)
285        })
286    }
287
288    #[cfg(any(test, feature = "seed-support"))]
289    pub async fn add_org_member(
290        &self,
291        org_id: OrgId,
292        user_id: UserId,
293        is_admin: bool,
294    ) -> Result<()> {
295        test_support!(self, {
296            let query = "
297                INSERT INTO org_memberships (org_id, user_id, admin)
298                VALUES ($1, $2, $3)
299                ON CONFLICT DO NOTHING
300            ";
301            sqlx::query(query)
302                .bind(org_id.0)
303                .bind(user_id.0)
304                .bind(is_admin)
305                .execute(&self.pool)
306                .await
307                .map(drop)
308        })
309    }
310
311    // channels
312
313    #[cfg(any(test, feature = "seed-support"))]
314    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
315        test_support!(self, {
316            let query = "
317                INSERT INTO channels (owner_id, owner_is_user, name)
318                VALUES ($1, false, $2)
319                RETURNING id
320            ";
321            sqlx::query_scalar(query)
322                .bind(org_id.0)
323                .bind(name)
324                .fetch_one(&self.pool)
325                .await
326                .map(ChannelId)
327        })
328    }
329
330    #[allow(unused)] // Help rust-analyzer
331    #[cfg(any(test, feature = "seed-support"))]
332    pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
333        test_support!(self, {
334            let query = "
335                SELECT *
336                FROM channels
337                WHERE
338                    channels.owner_is_user = false AND
339                    channels.owner_id = $1
340            ";
341            sqlx::query_as(query)
342                .bind(org_id.0)
343                .fetch_all(&self.pool)
344                .await
345        })
346    }
347
348    pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
349        test_support!(self, {
350            let query = "
351                SELECT
352                    channels.id, channels.name
353                FROM
354                    channel_memberships, channels
355                WHERE
356                    channel_memberships.user_id = $1 AND
357                    channel_memberships.channel_id = channels.id
358            ";
359            sqlx::query_as(query)
360                .bind(user_id.0)
361                .fetch_all(&self.pool)
362                .await
363        })
364    }
365
366    pub async fn can_user_access_channel(
367        &self,
368        user_id: UserId,
369        channel_id: ChannelId,
370    ) -> Result<bool> {
371        test_support!(self, {
372            let query = "
373                SELECT id
374                FROM channel_memberships
375                WHERE user_id = $1 AND channel_id = $2
376                LIMIT 1
377            ";
378            sqlx::query_scalar::<_, i32>(query)
379                .bind(user_id.0)
380                .bind(channel_id.0)
381                .fetch_optional(&self.pool)
382                .await
383                .map(|e| e.is_some())
384        })
385    }
386
387    #[cfg(any(test, feature = "seed-support"))]
388    pub async fn add_channel_member(
389        &self,
390        channel_id: ChannelId,
391        user_id: UserId,
392        is_admin: bool,
393    ) -> Result<()> {
394        test_support!(self, {
395            let query = "
396                INSERT INTO channel_memberships (channel_id, user_id, admin)
397                VALUES ($1, $2, $3)
398                ON CONFLICT DO NOTHING
399            ";
400            sqlx::query(query)
401                .bind(channel_id.0)
402                .bind(user_id.0)
403                .bind(is_admin)
404                .execute(&self.pool)
405                .await
406                .map(drop)
407        })
408    }
409
410    // messages
411
412    pub async fn create_channel_message(
413        &self,
414        channel_id: ChannelId,
415        sender_id: UserId,
416        body: &str,
417        timestamp: OffsetDateTime,
418        nonce: u128,
419    ) -> Result<MessageId> {
420        test_support!(self, {
421            let query = "
422                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
423                VALUES ($1, $2, $3, $4, $5)
424                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
425                RETURNING id
426            ";
427            sqlx::query_scalar(query)
428                .bind(channel_id.0)
429                .bind(sender_id.0)
430                .bind(body)
431                .bind(timestamp)
432                .bind(Uuid::from_u128(nonce))
433                .fetch_one(&self.pool)
434                .await
435                .map(MessageId)
436        })
437    }
438
439    pub async fn get_channel_messages(
440        &self,
441        channel_id: ChannelId,
442        count: usize,
443        before_id: Option<MessageId>,
444    ) -> Result<Vec<ChannelMessage>> {
445        test_support!(self, {
446            let query = r#"
447                SELECT * FROM (
448                    SELECT
449                        id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
450                    FROM
451                        channel_messages
452                    WHERE
453                        channel_id = $1 AND
454                        id < $2
455                    ORDER BY id DESC
456                    LIMIT $3
457                ) as recent_messages
458                ORDER BY id ASC
459            "#;
460            sqlx::query_as(query)
461                .bind(channel_id.0)
462                .bind(before_id.unwrap_or(MessageId::MAX))
463                .bind(count as i64)
464                .fetch_all(&self.pool)
465                .await
466        })
467    }
468}
469
/// Declares a newtype around `i32` for use as a database id.
///
/// The generated type is transparent for both sqlx and serde, so it is
/// encoded exactly like the wrapped integer.
macro_rules! id_type {
    ($id_type:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $id_type(pub i32);

        impl $id_type {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Converts a protocol-level `u64` into an id.
            ///
            /// This is an `as` cast: values that do not fit in an `i32` are
            /// truncated to their low 32 bits rather than rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                Self(raw)
            }

            /// Converts the id into a protocol-level `u64`.
            ///
            /// This is an `as` cast: a negative id is sign-extended.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let raw: i32 = self.0;
                raw as u64
            }
        }
    };
}
493
494id_type!(UserId);
495#[derive(Debug, FromRow, Serialize, PartialEq)]
496pub struct User {
497    pub id: UserId,
498    pub github_login: String,
499    pub admin: bool,
500}
501
502id_type!(OrgId);
503#[derive(FromRow)]
504pub struct Org {
505    pub id: OrgId,
506    pub name: String,
507    pub slug: String,
508}
509
510id_type!(SignupId);
511#[derive(Debug, FromRow, Serialize)]
512pub struct Signup {
513    pub id: SignupId,
514    pub github_login: String,
515    pub email_address: String,
516    pub about: String,
517    pub wants_releases: Option<bool>,
518    pub wants_updates: Option<bool>,
519    pub wants_community: Option<bool>,
520}
521
522id_type!(ChannelId);
523#[derive(Debug, FromRow, Serialize)]
524pub struct Channel {
525    pub id: ChannelId,
526    pub name: String,
527}
528
529id_type!(MessageId);
530#[derive(Debug, FromRow)]
531pub struct ChannelMessage {
532    pub id: MessageId,
533    pub sender_id: UserId,
534    pub body: String,
535    pub sent_at: OffsetDateTime,
536    pub nonce: Uuid,
537}
538
539#[cfg(test)]
540pub mod tests {
541    use super::*;
542    use rand::prelude::*;
543    use sqlx::{
544        migrate::{MigrateDatabase, Migrator},
545        Postgres,
546    };
547    use std::path::Path;
548
549    pub struct TestDb {
550        pub db: Db,
551        pub name: String,
552        pub url: String,
553    }
554
555    impl TestDb {
556        pub fn new() -> Self {
557            // Enable tests to run in parallel by serializing the creation of each test database.
558            lazy_static::lazy_static! {
559                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
560            }
561
562            let mut rng = StdRng::from_entropy();
563            let name = format!("zed-test-{}", rng.gen::<u128>());
564            let url = format!("postgres://postgres@localhost/{}", name);
565            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
566            let db = block_on(async {
567                {
568                    let _lock = DB_CREATION.lock();
569                    Postgres::create_database(&url)
570                        .await
571                        .expect("failed to create test db");
572                }
573                let mut db = Db::new(&url, 5).await.unwrap();
574                db.test_mode = true;
575                let migrator = Migrator::new(migrations_path).await.unwrap();
576                migrator.run(&db.pool).await.unwrap();
577                db
578            });
579
580            Self { db, name, url }
581        }
582
583        pub fn db(&self) -> &Db {
584            &self.db
585        }
586    }
587
588    impl Drop for TestDb {
589        fn drop(&mut self) {
590            block_on(async {
591                let query = "
592                    SELECT pg_terminate_backend(pg_stat_activity.pid)
593                    FROM pg_stat_activity
594                    WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
595                ";
596                sqlx::query(query)
597                    .bind(&self.name)
598                    .execute(&self.db.pool)
599                    .await
600                    .unwrap();
601                self.db.pool.close().await;
602                Postgres::drop_database(&self.url).await.unwrap();
603            });
604        }
605    }
606
607    #[gpui::test]
608    async fn test_get_users_by_ids() {
609        let test_db = TestDb::new();
610        let db = test_db.db();
611
612        let user = db.create_user("user", false).await.unwrap();
613        let friend1 = db.create_user("friend-1", false).await.unwrap();
614        let friend2 = db.create_user("friend-2", false).await.unwrap();
615        let friend3 = db.create_user("friend-3", false).await.unwrap();
616        let stranger = db.create_user("stranger", false).await.unwrap();
617
618        // A user can read their own info, even if they aren't in any channels.
619        assert_eq!(
620            db.get_users_by_ids(
621                user,
622                [user, friend1, friend2, friend3, stranger].iter().copied()
623            )
624            .await
625            .unwrap(),
626            vec![User {
627                id: user,
628                github_login: "user".to_string(),
629                admin: false,
630            },],
631        );
632
633        // A user can read the info of any other user who is in a shared channel
634        // with them.
635        let org = db.create_org("test org", "test-org").await.unwrap();
636        let chan1 = db.create_org_channel(org, "channel-1").await.unwrap();
637        let chan2 = db.create_org_channel(org, "channel-2").await.unwrap();
638        let chan3 = db.create_org_channel(org, "channel-3").await.unwrap();
639
640        db.add_channel_member(chan1, user, false).await.unwrap();
641        db.add_channel_member(chan2, user, false).await.unwrap();
642        db.add_channel_member(chan1, friend1, false).await.unwrap();
643        db.add_channel_member(chan1, friend2, false).await.unwrap();
644        db.add_channel_member(chan2, friend2, false).await.unwrap();
645        db.add_channel_member(chan2, friend3, false).await.unwrap();
646        db.add_channel_member(chan3, stranger, false).await.unwrap();
647
648        assert_eq!(
649            db.get_users_by_ids(
650                user,
651                [user, friend1, friend2, friend3, stranger].iter().copied()
652            )
653            .await
654            .unwrap(),
655            vec![
656                User {
657                    id: user,
658                    github_login: "user".to_string(),
659                    admin: false,
660                },
661                User {
662                    id: friend1,
663                    github_login: "friend-1".to_string(),
664                    admin: false,
665                },
666                User {
667                    id: friend2,
668                    github_login: "friend-2".to_string(),
669                    admin: false,
670                },
671                User {
672                    id: friend3,
673                    github_login: "friend-3".to_string(),
674                    admin: false,
675                }
676            ]
677        );
678
679        // The user's own info is only returned if they request it.
680        assert_eq!(
681            db.get_users_by_ids(user, [friend1].iter().copied())
682                .await
683                .unwrap(),
684            vec![User {
685                id: friend1,
686                github_login: "friend-1".to_string(),
687                admin: false,
688            },]
689        )
690    }
691
692    #[gpui::test]
693    async fn test_recent_channel_messages() {
694        let test_db = TestDb::new();
695        let db = test_db.db();
696        let user = db.create_user("user", false).await.unwrap();
697        let org = db.create_org("org", "org").await.unwrap();
698        let channel = db.create_org_channel(org, "channel").await.unwrap();
699        for i in 0..10 {
700            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
701                .await
702                .unwrap();
703        }
704
705        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
706        assert_eq!(
707            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
708            ["5", "6", "7", "8", "9"]
709        );
710
711        let prev_messages = db
712            .get_channel_messages(channel, 4, Some(messages[0].id))
713            .await
714            .unwrap();
715        assert_eq!(
716            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
717            ["1", "2", "3", "4"]
718        );
719    }
720
721    #[gpui::test]
722    async fn test_channel_message_nonces() {
723        let test_db = TestDb::new();
724        let db = test_db.db();
725        let user = db.create_user("user", false).await.unwrap();
726        let org = db.create_org("org", "org").await.unwrap();
727        let channel = db.create_org_channel(org, "channel").await.unwrap();
728
729        let msg1_id = db
730            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
731            .await
732            .unwrap();
733        let msg2_id = db
734            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
735            .await
736            .unwrap();
737        let msg3_id = db
738            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
739            .await
740            .unwrap();
741        let msg4_id = db
742            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
743            .await
744            .unwrap();
745
746        assert_ne!(msg1_id, msg2_id);
747        assert_eq!(msg1_id, msg3_id);
748        assert_eq!(msg2_id, msg4_id);
749    }
750}