// db.rs

use anyhow::Context;
use async_std::task::{block_on, yield_now};
use serde::Serialize;
use sqlx::{types::Uuid, FromRow, Result};
use std::convert::TryFrom;
use time::OffsetDateTime;

pub use async_sqlx_session::PostgresSessionStore as SessionStore;
pub use sqlx::postgres::PgPoolOptions as DbOptions;

 10macro_rules! test_support {
 11    ($self:ident, { $($token:tt)* }) => {{
 12        let body = async {
 13            $($token)*
 14        };
 15        if $self.test_mode {
 16            yield_now().await;
 17            block_on(body)
 18        } else {
 19            body.await
 20        }
 21    }};
 22}
 23
 24#[derive(Clone)]
 25pub struct Db {
 26    pool: sqlx::PgPool,
 27    test_mode: bool,
 28}
 29
 30impl Db {
 31    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
 32        let pool = DbOptions::new()
 33            .max_connections(max_connections)
 34            .connect(url)
 35            .await
 36            .context("failed to connect to postgres database")?;
 37        Ok(Self {
 38            pool,
 39            test_mode: false,
 40        })
 41    }
 42
 43    // signups
 44
 45    pub async fn create_signup(
 46        &self,
 47        github_login: &str,
 48        email_address: &str,
 49        about: &str,
 50    ) -> Result<SignupId> {
 51        test_support!(self, {
 52            let query = "
 53                INSERT INTO signups (github_login, email_address, about)
 54                VALUES ($1, $2, $3)
 55                RETURNING id
 56            ";
 57            sqlx::query_scalar(query)
 58                .bind(github_login)
 59                .bind(email_address)
 60                .bind(about)
 61                .fetch_one(&self.pool)
 62                .await
 63                .map(SignupId)
 64        })
 65    }
 66
 67    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
 68        test_support!(self, {
 69            let query = "SELECT * FROM signups ORDER BY github_login ASC";
 70            sqlx::query_as(query).fetch_all(&self.pool).await
 71        })
 72    }
 73
 74    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
 75        test_support!(self, {
 76            let query = "DELETE FROM signups WHERE id = $1";
 77            sqlx::query(query)
 78                .bind(id.0)
 79                .execute(&self.pool)
 80                .await
 81                .map(drop)
 82        })
 83    }
 84
 85    // users
 86
 87    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 88        test_support!(self, {
 89            let query = "
 90                INSERT INTO users (github_login, admin)
 91                VALUES ($1, $2)
 92                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 93                RETURNING id
 94            ";
 95            sqlx::query_scalar(query)
 96                .bind(github_login)
 97                .bind(admin)
 98                .fetch_one(&self.pool)
 99                .await
100                .map(UserId)
101        })
102    }
103
104    pub async fn get_all_users(&self) -> Result<Vec<User>> {
105        test_support!(self, {
106            let query = "SELECT * FROM users ORDER BY github_login ASC";
107            sqlx::query_as(query).fetch_all(&self.pool).await
108        })
109    }
110
111    pub async fn get_users_by_ids(
112        &self,
113        requester_id: UserId,
114        ids: impl Iterator<Item = UserId>,
115    ) -> Result<Vec<User>> {
116        let mut include_requester = false;
117        let ids = ids
118            .map(|id| {
119                if id == requester_id {
120                    include_requester = true;
121                }
122                id.0
123            })
124            .collect::<Vec<_>>();
125
126        test_support!(self, {
127            // Only return users that are in a common channel with the requesting user.
128            // Also allow the requesting user to return their own data, even if they aren't
129            // in any channels.
130            let query = "
131                SELECT
132                    users.*
133                FROM
134                    users, channel_memberships
135                WHERE
136                    users.id = ANY ($1) AND
137                    channel_memberships.user_id = users.id AND
138                    channel_memberships.channel_id IN (
139                        SELECT channel_id
140                        FROM channel_memberships
141                        WHERE channel_memberships.user_id = $2
142                    )
143                UNION
144                SELECT
145                    users.*
146                FROM
147                    users
148                WHERE
149                    $3 AND users.id = $2
150            ";
151
152            sqlx::query_as(query)
153                .bind(&ids)
154                .bind(requester_id)
155                .bind(include_requester)
156                .fetch_all(&self.pool)
157                .await
158        })
159    }
160
161    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
162        test_support!(self, {
163            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
164            sqlx::query_as(query)
165                .bind(github_login)
166                .fetch_optional(&self.pool)
167                .await
168        })
169    }
170
171    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
172        test_support!(self, {
173            let query = "UPDATE users SET admin = $1 WHERE id = $2";
174            sqlx::query(query)
175                .bind(is_admin)
176                .bind(id.0)
177                .execute(&self.pool)
178                .await
179                .map(drop)
180        })
181    }
182
183    pub async fn delete_user(&self, id: UserId) -> Result<()> {
184        test_support!(self, {
185            let query = "DELETE FROM users WHERE id = $1;";
186            sqlx::query(query)
187                .bind(id.0)
188                .execute(&self.pool)
189                .await
190                .map(drop)
191        })
192    }
193
194    // access tokens
195
196    pub async fn create_access_token_hash(
197        &self,
198        user_id: UserId,
199        access_token_hash: String,
200    ) -> Result<()> {
201        test_support!(self, {
202            let query = "
203            INSERT INTO access_tokens (user_id, hash)
204            VALUES ($1, $2)
205        ";
206            sqlx::query(query)
207                .bind(user_id.0)
208                .bind(access_token_hash)
209                .execute(&self.pool)
210                .await
211                .map(drop)
212        })
213    }
214
215    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
216        test_support!(self, {
217            let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
218            sqlx::query_scalar(query)
219                .bind(user_id.0)
220                .fetch_all(&self.pool)
221                .await
222        })
223    }
224
225    // orgs
226
227    #[allow(unused)] // Help rust-analyzer
228    #[cfg(any(test, feature = "seed-support"))]
229    pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
230        test_support!(self, {
231            let query = "
232                SELECT *
233                FROM orgs
234                WHERE slug = $1
235            ";
236            sqlx::query_as(query)
237                .bind(slug)
238                .fetch_optional(&self.pool)
239                .await
240        })
241    }
242
243    #[cfg(any(test, feature = "seed-support"))]
244    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
245        test_support!(self, {
246            let query = "
247                INSERT INTO orgs (name, slug)
248                VALUES ($1, $2)
249                RETURNING id
250            ";
251            sqlx::query_scalar(query)
252                .bind(name)
253                .bind(slug)
254                .fetch_one(&self.pool)
255                .await
256                .map(OrgId)
257        })
258    }
259
260    #[cfg(any(test, feature = "seed-support"))]
261    pub async fn add_org_member(
262        &self,
263        org_id: OrgId,
264        user_id: UserId,
265        is_admin: bool,
266    ) -> Result<()> {
267        test_support!(self, {
268            let query = "
269                INSERT INTO org_memberships (org_id, user_id, admin)
270                VALUES ($1, $2, $3)
271                ON CONFLICT DO NOTHING
272            ";
273            sqlx::query(query)
274                .bind(org_id.0)
275                .bind(user_id.0)
276                .bind(is_admin)
277                .execute(&self.pool)
278                .await
279                .map(drop)
280        })
281    }
282
283    // channels
284
285    #[cfg(any(test, feature = "seed-support"))]
286    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
287        test_support!(self, {
288            let query = "
289                INSERT INTO channels (owner_id, owner_is_user, name)
290                VALUES ($1, false, $2)
291                RETURNING id
292            ";
293            sqlx::query_scalar(query)
294                .bind(org_id.0)
295                .bind(name)
296                .fetch_one(&self.pool)
297                .await
298                .map(ChannelId)
299        })
300    }
301
302    #[allow(unused)] // Help rust-analyzer
303    #[cfg(any(test, feature = "seed-support"))]
304    pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
305        test_support!(self, {
306            let query = "
307                SELECT *
308                FROM channels
309                WHERE
310                    channels.owner_is_user = false AND
311                    channels.owner_id = $1
312            ";
313            sqlx::query_as(query)
314                .bind(org_id.0)
315                .fetch_all(&self.pool)
316                .await
317        })
318    }
319
320    pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
321        test_support!(self, {
322            let query = "
323                SELECT
324                    channels.id, channels.name
325                FROM
326                    channel_memberships, channels
327                WHERE
328                    channel_memberships.user_id = $1 AND
329                    channel_memberships.channel_id = channels.id
330            ";
331            sqlx::query_as(query)
332                .bind(user_id.0)
333                .fetch_all(&self.pool)
334                .await
335        })
336    }
337
338    pub async fn can_user_access_channel(
339        &self,
340        user_id: UserId,
341        channel_id: ChannelId,
342    ) -> Result<bool> {
343        test_support!(self, {
344            let query = "
345                SELECT id
346                FROM channel_memberships
347                WHERE user_id = $1 AND channel_id = $2
348                LIMIT 1
349            ";
350            sqlx::query_scalar::<_, i32>(query)
351                .bind(user_id.0)
352                .bind(channel_id.0)
353                .fetch_optional(&self.pool)
354                .await
355                .map(|e| e.is_some())
356        })
357    }
358
359    #[cfg(any(test, feature = "seed-support"))]
360    pub async fn add_channel_member(
361        &self,
362        channel_id: ChannelId,
363        user_id: UserId,
364        is_admin: bool,
365    ) -> Result<()> {
366        test_support!(self, {
367            let query = "
368                INSERT INTO channel_memberships (channel_id, user_id, admin)
369                VALUES ($1, $2, $3)
370                ON CONFLICT DO NOTHING
371            ";
372            sqlx::query(query)
373                .bind(channel_id.0)
374                .bind(user_id.0)
375                .bind(is_admin)
376                .execute(&self.pool)
377                .await
378                .map(drop)
379        })
380    }
381
382    // messages
383
384    pub async fn create_channel_message(
385        &self,
386        channel_id: ChannelId,
387        sender_id: UserId,
388        body: &str,
389        timestamp: OffsetDateTime,
390        nonce: u128,
391    ) -> Result<MessageId> {
392        test_support!(self, {
393            let query = "
394                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
395                VALUES ($1, $2, $3, $4, $5)
396                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
397                RETURNING id
398            ";
399            sqlx::query_scalar(query)
400                .bind(channel_id.0)
401                .bind(sender_id.0)
402                .bind(body)
403                .bind(timestamp)
404                .bind(Uuid::from_u128(nonce))
405                .fetch_one(&self.pool)
406                .await
407                .map(MessageId)
408        })
409    }
410
411    pub async fn get_channel_messages(
412        &self,
413        channel_id: ChannelId,
414        count: usize,
415        before_id: Option<MessageId>,
416    ) -> Result<Vec<ChannelMessage>> {
417        test_support!(self, {
418            let query = r#"
419                SELECT * FROM (
420                    SELECT
421                        id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
422                    FROM
423                        channel_messages
424                    WHERE
425                        channel_id = $1 AND
426                        id < $2
427                    ORDER BY id DESC
428                    LIMIT $3
429                ) as recent_messages
430                ORDER BY id ASC
431            "#;
432            sqlx::query_as(query)
433                .bind(channel_id.0)
434                .bind(before_id.unwrap_or(MessageId::MAX))
435                .bind(count as i64)
436                .fetch_all(&self.pool)
437                .await
438        })
439    }
440}
441
/// Declares `$name` as a `Copy` newtype over an `i32` database key.
///
/// Both sqlx and serde encode the generated type transparently as its inner
/// `i32`. It also gets a `MAX` constant and `u64` conversions
/// (`from_proto` / `to_proto`).
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        // Each helper is `#[allow(unused)]` because not every id type uses all of them.
        impl $name {
            /// The largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from a `u64`.
            ///
            /// NOTE: this is a plain `as` cast, so values above `i32::MAX` are
            /// truncated to their low 32 bits rather than rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Converts the id to a `u64`.
            ///
            /// NOTE: a negative inner value sign-extends to a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let $name(raw) = *self;
                raw as u64
            }
        }
    };
}

466id_type!(UserId);
467#[derive(Debug, FromRow, Serialize, PartialEq)]
468pub struct User {
469    pub id: UserId,
470    pub github_login: String,
471    pub admin: bool,
472}
473
474id_type!(OrgId);
475#[derive(FromRow)]
476pub struct Org {
477    pub id: OrgId,
478    pub name: String,
479    pub slug: String,
480}
481
482id_type!(SignupId);
483#[derive(Debug, FromRow, Serialize)]
484pub struct Signup {
485    pub id: SignupId,
486    pub github_login: String,
487    pub email_address: String,
488    pub about: String,
489}
490
491id_type!(ChannelId);
492#[derive(Debug, FromRow, Serialize)]
493pub struct Channel {
494    pub id: ChannelId,
495    pub name: String,
496}
497
498id_type!(MessageId);
499#[derive(Debug, FromRow)]
500pub struct ChannelMessage {
501    pub id: MessageId,
502    pub sender_id: UserId,
503    pub body: String,
504    pub sent_at: OffsetDateTime,
505    pub nonce: Uuid,
506}
507
#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

    /// A throwaway, fully migrated Postgres database for a single test.
    ///
    /// Created on the local server by `TestDb::new`. Dropping the value drops
    /// the database.
    pub struct TestDb {
        pub db: Db,
        pub name: String,
        pub url: String,
    }

    impl TestDb {
        /// Creates a randomly named database at `postgres://postgres@localhost`,
        /// runs the migrations in `$CARGO_MANIFEST_DIR/migrations`, and wraps it
        /// in a `Db` with `test_mode` enabled.
        ///
        /// # Panics
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub fn new() -> Self {
            // Enable tests to run in parallel by serializing the creation of each test database.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                {
                    // A poisoned lock only means another test panicked while creating
                    // its database. The guarded `()` cannot be inconsistent, so keep
                    // serializing regardless.
                    let _lock = DB_CREATION
                        .lock()
                        .unwrap_or_else(|poisoned| poisoned.into_inner());
                    Postgres::create_database(&url)
                        .await
                        .expect("failed to create test db");
                }
                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });

            Self { db, name, url }
        }

        /// Borrows the test database handle.
        pub fn db(&self) -> &Db {
            &self.db
        }
    }

    impl Drop for TestDb {
        fn drop(&mut self) {
            block_on(async {
                // Terminate every other session on this test database (the query runs
                // on a pool connection to it, hence `current_database()`), so the
                // `DROP DATABASE` below is not blocked by lingering connections.
                let query = "
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
                ";
                sqlx::query(query).execute(&self.db.pool).await.unwrap();
                self.db.pool.close().await;
                Postgres::drop_database(&self.url).await.unwrap();
            });
        }
    }

    #[gpui::test]
    async fn test_get_users_by_ids() {
        let test_db = TestDb::new();
        let db = test_db.db();

        let user = db.create_user("user", false).await.unwrap();
        let friend1 = db.create_user("friend-1", false).await.unwrap();
        let friend2 = db.create_user("friend-2", false).await.unwrap();
        let friend3 = db.create_user("friend-3", false).await.unwrap();
        let stranger = db.create_user("stranger", false).await.unwrap();

        // A user can read their own info, even if they aren't in any channels.
        assert_eq!(
            db.get_users_by_ids(
                user,
                [user, friend1, friend2, friend3, stranger].iter().copied()
            )
            .await
            .unwrap(),
            vec![User {
                id: user,
                github_login: "user".to_string(),
                admin: false,
            },],
        );

        // A user can read the info of any other user who is in a shared channel
        // with them.
        let org = db.create_org("test org", "test-org").await.unwrap();
        let chan1 = db.create_org_channel(org, "channel-1").await.unwrap();
        let chan2 = db.create_org_channel(org, "channel-2").await.unwrap();
        let chan3 = db.create_org_channel(org, "channel-3").await.unwrap();

        db.add_channel_member(chan1, user, false).await.unwrap();
        db.add_channel_member(chan2, user, false).await.unwrap();
        db.add_channel_member(chan1, friend1, false).await.unwrap();
        db.add_channel_member(chan1, friend2, false).await.unwrap();
        db.add_channel_member(chan2, friend2, false).await.unwrap();
        db.add_channel_member(chan2, friend3, false).await.unwrap();
        db.add_channel_member(chan3, stranger, false).await.unwrap();

        assert_eq!(
            db.get_users_by_ids(
                user,
                [user, friend1, friend2, friend3, stranger].iter().copied()
            )
            .await
            .unwrap(),
            vec![
                User {
                    id: user,
                    github_login: "user".to_string(),
                    admin: false,
                },
                User {
                    id: friend1,
                    github_login: "friend-1".to_string(),
                    admin: false,
                },
                User {
                    id: friend2,
                    github_login: "friend-2".to_string(),
                    admin: false,
                },
                User {
                    id: friend3,
                    github_login: "friend-3".to_string(),
                    admin: false,
                }
            ]
        );

        // The user's own info is only returned if they request it.
        assert_eq!(
            db.get_users_by_ids(user, [friend1].iter().copied())
                .await
                .unwrap(),
            vec![User {
                id: friend1,
                github_login: "friend-1".to_string(),
                admin: false,
            },]
        )
    }

    #[gpui::test]
    async fn test_recent_channel_messages() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();
        for i in 0..10 {
            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
                .await
                .unwrap();
        }

        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
        assert_eq!(
            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["5", "6", "7", "8", "9"]
        );

        let prev_messages = db
            .get_channel_messages(channel, 4, Some(messages[0].id))
            .await
            .unwrap();
        assert_eq!(
            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["1", "2", "3", "4"]
        );
    }

    #[gpui::test]
    async fn test_channel_message_nonces() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();

        let msg1_id = db
            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg2_id = db
            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();
        let msg3_id = db
            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg4_id = db
            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();

        assert_ne!(msg1_id, msg2_id);
        assert_eq!(msg1_id, msg3_id);
        assert_eq!(msg2_id, msg4_id);
    }
}