db.rs

  1use anyhow::Context;
  2use async_std::task::{block_on, yield_now};
  3use serde::Serialize;
  4use sqlx::{types::Uuid, FromRow, Result};
  5use time::OffsetDateTime;
  6
  7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
  8pub use sqlx::postgres::PgPoolOptions as DbOptions;
  9
/// Wraps a block of statements in an `async` block and drives it differently
/// depending on `$self.test_mode`.
///
/// `$self` must name a value with a `test_mode: bool` field (every call site in
/// this file passes `self: &Db`). The block is wrapped in a plain `async { .. }`
/// (not `async move`), so it borrows from the enclosing method's scope.
///
/// - Test mode: the calling task first yields once via `yield_now()`. The body
///   is then run to completion with `block_on`, i.e. synchronously on the
///   current thread instead of being polled by the caller's executor.
/// - Otherwise: the body is simply `.await`ed.
///
/// NOTE(review): presumably test mode exists so that a deterministic test
/// executor never has to poll real database I/O itself; confirm with the test
/// harness before relying on this.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };
        if $self.test_mode {
            // Give the caller's executor a scheduling point before blocking.
            yield_now().await;
            block_on(body)
        } else {
            body.await
        }
    }};
}
 23
 24#[derive(Clone)]
 25pub struct Db {
 26    pool: sqlx::PgPool,
 27    test_mode: bool,
 28}
 29
 30impl Db {
 31    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
 32        let pool = DbOptions::new()
 33            .max_connections(max_connections)
 34            .connect(url)
 35            .await
 36            .context("failed to connect to postgres database")?;
 37        Ok(Self {
 38            pool,
 39            test_mode: false,
 40        })
 41    }
 42
 43    // signups
 44
 45    pub async fn create_signup(
 46        &self,
 47        github_login: &str,
 48        email_address: &str,
 49        about: &str,
 50    ) -> Result<SignupId> {
 51        test_support!(self, {
 52            let query = "
 53                INSERT INTO signups (github_login, email_address, about)
 54                VALUES ($1, $2, $3)
 55                RETURNING id
 56            ";
 57            sqlx::query_scalar(query)
 58                .bind(github_login)
 59                .bind(email_address)
 60                .bind(about)
 61                .fetch_one(&self.pool)
 62                .await
 63                .map(SignupId)
 64        })
 65    }
 66
 67    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
 68        test_support!(self, {
 69            let query = "SELECT * FROM signups ORDER BY github_login ASC";
 70            sqlx::query_as(query).fetch_all(&self.pool).await
 71        })
 72    }
 73
 74    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
 75        test_support!(self, {
 76            let query = "DELETE FROM signups WHERE id = $1";
 77            sqlx::query(query)
 78                .bind(id.0)
 79                .execute(&self.pool)
 80                .await
 81                .map(drop)
 82        })
 83    }
 84
 85    // users
 86
 87    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 88        test_support!(self, {
 89            let query = "
 90                INSERT INTO users (github_login, admin)
 91                VALUES ($1, $2)
 92                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 93                RETURNING id
 94            ";
 95            sqlx::query_scalar(query)
 96                .bind(github_login)
 97                .bind(admin)
 98                .fetch_one(&self.pool)
 99                .await
100                .map(UserId)
101        })
102    }
103
104    pub async fn get_all_users(&self) -> Result<Vec<User>> {
105        test_support!(self, {
106            let query = "SELECT * FROM users ORDER BY github_login ASC";
107            sqlx::query_as(query).fetch_all(&self.pool).await
108        })
109    }
110
111    pub async fn get_users_by_ids(
112        &self,
113        ids: impl IntoIterator<Item = UserId>,
114    ) -> Result<Vec<User>> {
115        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
116        test_support!(self, {
117            let query = "
118                SELECT users.*
119                FROM users
120                WHERE users.id = ANY ($1)
121            ";
122
123            sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
124        })
125    }
126
127    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
128        test_support!(self, {
129            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
130            sqlx::query_as(query)
131                .bind(github_login)
132                .fetch_optional(&self.pool)
133                .await
134        })
135    }
136
137    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
138        test_support!(self, {
139            let query = "UPDATE users SET admin = $1 WHERE id = $2";
140            sqlx::query(query)
141                .bind(is_admin)
142                .bind(id.0)
143                .execute(&self.pool)
144                .await
145                .map(drop)
146        })
147    }
148
149    pub async fn delete_user(&self, id: UserId) -> Result<()> {
150        test_support!(self, {
151            let query = "DELETE FROM users WHERE id = $1;";
152            sqlx::query(query)
153                .bind(id.0)
154                .execute(&self.pool)
155                .await
156                .map(drop)
157        })
158    }
159
160    // access tokens
161
162    pub async fn create_access_token_hash(
163        &self,
164        user_id: UserId,
165        access_token_hash: &str,
166        max_access_token_count: usize,
167    ) -> Result<()> {
168        test_support!(self, {
169            let insert_query = "
170                INSERT INTO access_tokens (user_id, hash)
171                VALUES ($1, $2);
172            ";
173            let cleanup_query = "
174                DELETE FROM access_tokens
175                WHERE id IN (
176                    SELECT id from access_tokens
177                    WHERE user_id = $1
178                    ORDER BY id DESC
179                    OFFSET $3
180                )
181            ";
182
183            let mut tx = self.pool.begin().await?;
184            sqlx::query(insert_query)
185                .bind(user_id.0)
186                .bind(access_token_hash)
187                .execute(&mut tx)
188                .await?;
189            sqlx::query(cleanup_query)
190                .bind(user_id.0)
191                .bind(access_token_hash)
192                .bind(max_access_token_count as u32)
193                .execute(&mut tx)
194                .await?;
195            tx.commit().await
196        })
197    }
198
199    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
200        test_support!(self, {
201            let query = "
202                SELECT hash
203                FROM access_tokens
204                WHERE user_id = $1
205                ORDER BY id DESC
206            ";
207            sqlx::query_scalar(query)
208                .bind(user_id.0)
209                .fetch_all(&self.pool)
210                .await
211        })
212    }
213
214    // orgs
215
216    #[allow(unused)] // Help rust-analyzer
217    #[cfg(any(test, feature = "seed-support"))]
218    pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
219        test_support!(self, {
220            let query = "
221                SELECT *
222                FROM orgs
223                WHERE slug = $1
224            ";
225            sqlx::query_as(query)
226                .bind(slug)
227                .fetch_optional(&self.pool)
228                .await
229        })
230    }
231
232    #[cfg(any(test, feature = "seed-support"))]
233    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
234        test_support!(self, {
235            let query = "
236                INSERT INTO orgs (name, slug)
237                VALUES ($1, $2)
238                RETURNING id
239            ";
240            sqlx::query_scalar(query)
241                .bind(name)
242                .bind(slug)
243                .fetch_one(&self.pool)
244                .await
245                .map(OrgId)
246        })
247    }
248
249    #[cfg(any(test, feature = "seed-support"))]
250    pub async fn add_org_member(
251        &self,
252        org_id: OrgId,
253        user_id: UserId,
254        is_admin: bool,
255    ) -> Result<()> {
256        test_support!(self, {
257            let query = "
258                INSERT INTO org_memberships (org_id, user_id, admin)
259                VALUES ($1, $2, $3)
260                ON CONFLICT DO NOTHING
261            ";
262            sqlx::query(query)
263                .bind(org_id.0)
264                .bind(user_id.0)
265                .bind(is_admin)
266                .execute(&self.pool)
267                .await
268                .map(drop)
269        })
270    }
271
272    // channels
273
274    #[cfg(any(test, feature = "seed-support"))]
275    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
276        test_support!(self, {
277            let query = "
278                INSERT INTO channels (owner_id, owner_is_user, name)
279                VALUES ($1, false, $2)
280                RETURNING id
281            ";
282            sqlx::query_scalar(query)
283                .bind(org_id.0)
284                .bind(name)
285                .fetch_one(&self.pool)
286                .await
287                .map(ChannelId)
288        })
289    }
290
291    #[allow(unused)] // Help rust-analyzer
292    #[cfg(any(test, feature = "seed-support"))]
293    pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
294        test_support!(self, {
295            let query = "
296                SELECT *
297                FROM channels
298                WHERE
299                    channels.owner_is_user = false AND
300                    channels.owner_id = $1
301            ";
302            sqlx::query_as(query)
303                .bind(org_id.0)
304                .fetch_all(&self.pool)
305                .await
306        })
307    }
308
309    pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
310        test_support!(self, {
311            let query = "
312                SELECT
313                    channels.id, channels.name
314                FROM
315                    channel_memberships, channels
316                WHERE
317                    channel_memberships.user_id = $1 AND
318                    channel_memberships.channel_id = channels.id
319            ";
320            sqlx::query_as(query)
321                .bind(user_id.0)
322                .fetch_all(&self.pool)
323                .await
324        })
325    }
326
327    pub async fn can_user_access_channel(
328        &self,
329        user_id: UserId,
330        channel_id: ChannelId,
331    ) -> Result<bool> {
332        test_support!(self, {
333            let query = "
334                SELECT id
335                FROM channel_memberships
336                WHERE user_id = $1 AND channel_id = $2
337                LIMIT 1
338            ";
339            sqlx::query_scalar::<_, i32>(query)
340                .bind(user_id.0)
341                .bind(channel_id.0)
342                .fetch_optional(&self.pool)
343                .await
344                .map(|e| e.is_some())
345        })
346    }
347
348    #[cfg(any(test, feature = "seed-support"))]
349    pub async fn add_channel_member(
350        &self,
351        channel_id: ChannelId,
352        user_id: UserId,
353        is_admin: bool,
354    ) -> Result<()> {
355        test_support!(self, {
356            let query = "
357                INSERT INTO channel_memberships (channel_id, user_id, admin)
358                VALUES ($1, $2, $3)
359                ON CONFLICT DO NOTHING
360            ";
361            sqlx::query(query)
362                .bind(channel_id.0)
363                .bind(user_id.0)
364                .bind(is_admin)
365                .execute(&self.pool)
366                .await
367                .map(drop)
368        })
369    }
370
371    // messages
372
373    pub async fn create_channel_message(
374        &self,
375        channel_id: ChannelId,
376        sender_id: UserId,
377        body: &str,
378        timestamp: OffsetDateTime,
379        nonce: u128,
380    ) -> Result<MessageId> {
381        test_support!(self, {
382            let query = "
383                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
384                VALUES ($1, $2, $3, $4, $5)
385                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
386                RETURNING id
387            ";
388            sqlx::query_scalar(query)
389                .bind(channel_id.0)
390                .bind(sender_id.0)
391                .bind(body)
392                .bind(timestamp)
393                .bind(Uuid::from_u128(nonce))
394                .fetch_one(&self.pool)
395                .await
396                .map(MessageId)
397        })
398    }
399
400    pub async fn get_channel_messages(
401        &self,
402        channel_id: ChannelId,
403        count: usize,
404        before_id: Option<MessageId>,
405    ) -> Result<Vec<ChannelMessage>> {
406        test_support!(self, {
407            let query = r#"
408                SELECT * FROM (
409                    SELECT
410                        id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
411                    FROM
412                        channel_messages
413                    WHERE
414                        channel_id = $1 AND
415                        id < $2
416                    ORDER BY id DESC
417                    LIMIT $3
418                ) as recent_messages
419                ORDER BY id ASC
420            "#;
421            sqlx::query_as(query)
422                .bind(channel_id.0)
423                .bind(before_id.unwrap_or(MessageId::MAX))
424                .bind(count as i64)
425                .fetch_all(&self.pool)
426                .await
427        })
428    }
429}
430
/// Declares `pub struct $id_name(pub i32)`, a row-id newtype.
///
/// `#[sqlx(transparent)]` makes the type bind and decode as its inner `i32`.
/// `#[serde(transparent)]` makes it serialize as that bare integer.
macro_rules! id_type {
    ($id_name:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $id_name(pub i32);

        // `allow(unused)`: not every id type uses every helper.
        impl $id_name {
            /// Largest id value (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = $id_name(i32::MAX);

            /// Converts a protocol (`u64`) id into this type.
            ///
            /// NOTE(review): the `as i32` cast silently truncates values above
            /// `i32::MAX`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $id_name(value as i32)
            }

            /// Converts this id into its protocol (`u64`) form.
            ///
            /// NOTE(review): negative ids sign-extend under `as u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let raw: i32 = self.0;
                raw as u64
            }
        }
    };
}
454
455id_type!(UserId);
456#[derive(Debug, FromRow, Serialize, PartialEq)]
457pub struct User {
458    pub id: UserId,
459    pub github_login: String,
460    pub admin: bool,
461}
462
463id_type!(OrgId);
464#[derive(FromRow)]
465pub struct Org {
466    pub id: OrgId,
467    pub name: String,
468    pub slug: String,
469}
470
471id_type!(SignupId);
472#[derive(Debug, FromRow, Serialize)]
473pub struct Signup {
474    pub id: SignupId,
475    pub github_login: String,
476    pub email_address: String,
477    pub about: String,
478}
479
480id_type!(ChannelId);
481#[derive(Debug, FromRow, Serialize)]
482pub struct Channel {
483    pub id: ChannelId,
484    pub name: String,
485}
486
487id_type!(MessageId);
488#[derive(Debug, FromRow)]
489pub struct ChannelMessage {
490    pub id: MessageId,
491    pub sender_id: UserId,
492    pub body: String,
493    pub sent_at: OffsetDateTime,
494    pub nonce: Uuid,
495}
496
#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

    /// A throwaway Postgres database, created and migrated by `TestDb::new`
    /// and dropped again when this value is dropped.
    pub struct TestDb {
        pub db: Db,
        pub name: String,
        pub url: String,
    }

    impl TestDb {
        /// Creates a randomly named database on `localhost`, runs the crate's
        /// `migrations` against it, and returns it with `test_mode` enabled.
        ///
        /// # Panics
        /// Panics if the database cannot be created, connected to, or migrated.
        pub fn new() -> Self {
            // Enable tests to run in parallel by serializing the creation of each test database.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                {
                    let _lock = DB_CREATION.lock();
                    Postgres::create_database(&url)
                        .await
                        .expect("failed to create test db");
                }
                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });

            Self { db, name, url }
        }

        /// Borrows the underlying `Db`.
        pub fn db(&self) -> &Db {
            &self.db
        }
    }

    impl Drop for TestDb {
        /// Terminates other sessions on the test database, closes the pool,
        /// and drops the database.
        fn drop(&mut self) {
            block_on(async {
                // Postgres refuses to drop a database that still has other
                // sessions connected, so terminate them first. The name must
                // be passed as the `$1` bind parameter: a quoted `'{}'` is a
                // plain string literal and would never match.
                let query = "
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = $1 AND pid <> pg_backend_pid();
                ";
                sqlx::query(query)
                    .bind(&self.name)
                    .execute(&self.db.pool)
                    .await
                    .unwrap();
                self.db.pool.close().await;
                Postgres::drop_database(&self.url).await.unwrap();
            });
        }
    }

    #[gpui::test]
    async fn test_get_users_by_ids() {
        let test_db = TestDb::new();
        let db = test_db.db();

        let user = db.create_user("user", false).await.unwrap();
        let friend1 = db.create_user("friend-1", false).await.unwrap();
        let friend2 = db.create_user("friend-2", false).await.unwrap();
        let friend3 = db.create_user("friend-3", false).await.unwrap();

        assert_eq!(
            db.get_users_by_ids([user, friend1, friend2, friend3])
                .await
                .unwrap(),
            vec![
                User {
                    id: user,
                    github_login: "user".to_string(),
                    admin: false,
                },
                User {
                    id: friend1,
                    github_login: "friend-1".to_string(),
                    admin: false,
                },
                User {
                    id: friend2,
                    github_login: "friend-2".to_string(),
                    admin: false,
                },
                User {
                    id: friend3,
                    github_login: "friend-3".to_string(),
                    admin: false,
                }
            ]
        );
    }

    #[gpui::test]
    async fn test_recent_channel_messages() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();
        for i in 0..10 {
            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
                .await
                .unwrap();
        }

        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
        assert_eq!(
            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["5", "6", "7", "8", "9"]
        );

        let prev_messages = db
            .get_channel_messages(channel, 4, Some(messages[0].id))
            .await
            .unwrap();
        assert_eq!(
            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["1", "2", "3", "4"]
        );
    }

    #[gpui::test]
    async fn test_channel_message_nonces() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();

        let msg1_id = db
            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg2_id = db
            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();
        let msg3_id = db
            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg4_id = db
            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();

        assert_ne!(msg1_id, msg2_id);
        assert_eq!(msg1_id, msg3_id);
        assert_eq!(msg2_id, msg4_id);
    }

    #[gpui::test]
    async fn test_create_access_tokens() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("the-user", false).await.unwrap();

        db.create_access_token_hash(user, "h1", 3).await.unwrap();
        db.create_access_token_hash(user, "h2", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h3", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
        );

        db.create_access_token_hash(user, "h4", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
        );

        db.create_access_token_hash(user, "h5", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
        );
    }
}