db.rs

  1use anyhow::Context;
  2use async_std::task::{block_on, yield_now};
  3use serde::Serialize;
  4use sqlx::{types::Uuid, FromRow, Result};
  5use time::OffsetDateTime;
  6
  7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
  8pub use sqlx::postgres::PgPoolOptions as DbOptions;
  9
 10macro_rules! test_support {
 11    ($self:ident, { $($token:tt)* }) => {{
 12        let body = async {
 13            $($token)*
 14        };
 15        if $self.test_mode {
 16            yield_now().await;
 17            block_on(body)
 18        } else {
 19            body.await
 20        }
 21    }};
 22}
 23
 24#[derive(Clone)]
 25pub struct Db {
 26    pool: sqlx::PgPool,
 27    test_mode: bool,
 28}
 29
 30impl Db {
 31    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
 32        let pool = DbOptions::new()
 33            .max_connections(max_connections)
 34            .connect(url)
 35            .await
 36            .context("failed to connect to postgres database")?;
 37        Ok(Self {
 38            pool,
 39            test_mode: false,
 40        })
 41    }
 42
 43    // signups
 44
 45    pub async fn create_signup(
 46        &self,
 47        github_login: &str,
 48        email_address: &str,
 49        about: &str,
 50    ) -> Result<SignupId> {
 51        test_support!(self, {
 52            let query = "
 53                INSERT INTO signups (github_login, email_address, about)
 54                VALUES ($1, $2, $3)
 55                RETURNING id
 56            ";
 57            sqlx::query_scalar(query)
 58                .bind(github_login)
 59                .bind(email_address)
 60                .bind(about)
 61                .fetch_one(&self.pool)
 62                .await
 63                .map(SignupId)
 64        })
 65    }
 66
 67    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
 68        test_support!(self, {
 69            let query = "SELECT * FROM signups ORDER BY github_login ASC";
 70            sqlx::query_as(query).fetch_all(&self.pool).await
 71        })
 72    }
 73
 74    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
 75        test_support!(self, {
 76            let query = "DELETE FROM signups WHERE id = $1";
 77            sqlx::query(query)
 78                .bind(id.0)
 79                .execute(&self.pool)
 80                .await
 81                .map(drop)
 82        })
 83    }
 84
 85    // users
 86
 87    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 88        test_support!(self, {
 89            let query = "
 90                INSERT INTO users (github_login, admin)
 91                VALUES ($1, $2)
 92                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
 93                RETURNING id
 94            ";
 95            sqlx::query_scalar(query)
 96                .bind(github_login)
 97                .bind(admin)
 98                .fetch_one(&self.pool)
 99                .await
100                .map(UserId)
101        })
102    }
103
104    pub async fn get_all_users(&self) -> Result<Vec<User>> {
105        test_support!(self, {
106            let query = "SELECT * FROM users ORDER BY github_login ASC";
107            sqlx::query_as(query).fetch_all(&self.pool).await
108        })
109    }
110
111    pub async fn get_users_by_ids(&self, ids: impl Iterator<Item = UserId>) -> Result<Vec<User>> {
112        let ids = ids.map(|id| id.0).collect::<Vec<_>>();
113        test_support!(self, {
114            let query = "
115                SELECT users.*
116                FROM users
117                WHERE users.id = ANY ($1)
118            ";
119
120            sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
121        })
122    }
123
124    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
125        test_support!(self, {
126            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
127            sqlx::query_as(query)
128                .bind(github_login)
129                .fetch_optional(&self.pool)
130                .await
131        })
132    }
133
134    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
135        test_support!(self, {
136            let query = "UPDATE users SET admin = $1 WHERE id = $2";
137            sqlx::query(query)
138                .bind(is_admin)
139                .bind(id.0)
140                .execute(&self.pool)
141                .await
142                .map(drop)
143        })
144    }
145
146    pub async fn delete_user(&self, id: UserId) -> Result<()> {
147        test_support!(self, {
148            let query = "DELETE FROM users WHERE id = $1;";
149            sqlx::query(query)
150                .bind(id.0)
151                .execute(&self.pool)
152                .await
153                .map(drop)
154        })
155    }
156
157    // access tokens
158
159    pub async fn create_access_token_hash(
160        &self,
161        user_id: UserId,
162        access_token_hash: String,
163    ) -> Result<()> {
164        test_support!(self, {
165            let query = "
166            INSERT INTO access_tokens (user_id, hash)
167            VALUES ($1, $2)
168        ";
169            sqlx::query(query)
170                .bind(user_id.0)
171                .bind(access_token_hash)
172                .execute(&self.pool)
173                .await
174                .map(drop)
175        })
176    }
177
178    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
179        test_support!(self, {
180            let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
181            sqlx::query_scalar(query)
182                .bind(user_id.0)
183                .fetch_all(&self.pool)
184                .await
185        })
186    }
187
188    // orgs
189
190    #[allow(unused)] // Help rust-analyzer
191    #[cfg(any(test, feature = "seed-support"))]
192    pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
193        test_support!(self, {
194            let query = "
195                SELECT *
196                FROM orgs
197                WHERE slug = $1
198            ";
199            sqlx::query_as(query)
200                .bind(slug)
201                .fetch_optional(&self.pool)
202                .await
203        })
204    }
205
206    #[cfg(any(test, feature = "seed-support"))]
207    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
208        test_support!(self, {
209            let query = "
210                INSERT INTO orgs (name, slug)
211                VALUES ($1, $2)
212                RETURNING id
213            ";
214            sqlx::query_scalar(query)
215                .bind(name)
216                .bind(slug)
217                .fetch_one(&self.pool)
218                .await
219                .map(OrgId)
220        })
221    }
222
223    #[cfg(any(test, feature = "seed-support"))]
224    pub async fn add_org_member(
225        &self,
226        org_id: OrgId,
227        user_id: UserId,
228        is_admin: bool,
229    ) -> Result<()> {
230        test_support!(self, {
231            let query = "
232                INSERT INTO org_memberships (org_id, user_id, admin)
233                VALUES ($1, $2, $3)
234                ON CONFLICT DO NOTHING
235            ";
236            sqlx::query(query)
237                .bind(org_id.0)
238                .bind(user_id.0)
239                .bind(is_admin)
240                .execute(&self.pool)
241                .await
242                .map(drop)
243        })
244    }
245
246    // channels
247
248    #[cfg(any(test, feature = "seed-support"))]
249    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
250        test_support!(self, {
251            let query = "
252                INSERT INTO channels (owner_id, owner_is_user, name)
253                VALUES ($1, false, $2)
254                RETURNING id
255            ";
256            sqlx::query_scalar(query)
257                .bind(org_id.0)
258                .bind(name)
259                .fetch_one(&self.pool)
260                .await
261                .map(ChannelId)
262        })
263    }
264
265    #[allow(unused)] // Help rust-analyzer
266    #[cfg(any(test, feature = "seed-support"))]
267    pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
268        test_support!(self, {
269            let query = "
270                SELECT *
271                FROM channels
272                WHERE
273                    channels.owner_is_user = false AND
274                    channels.owner_id = $1
275            ";
276            sqlx::query_as(query)
277                .bind(org_id.0)
278                .fetch_all(&self.pool)
279                .await
280        })
281    }
282
283    pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
284        test_support!(self, {
285            let query = "
286                SELECT
287                    channels.id, channels.name
288                FROM
289                    channel_memberships, channels
290                WHERE
291                    channel_memberships.user_id = $1 AND
292                    channel_memberships.channel_id = channels.id
293            ";
294            sqlx::query_as(query)
295                .bind(user_id.0)
296                .fetch_all(&self.pool)
297                .await
298        })
299    }
300
301    pub async fn can_user_access_channel(
302        &self,
303        user_id: UserId,
304        channel_id: ChannelId,
305    ) -> Result<bool> {
306        test_support!(self, {
307            let query = "
308                SELECT id
309                FROM channel_memberships
310                WHERE user_id = $1 AND channel_id = $2
311                LIMIT 1
312            ";
313            sqlx::query_scalar::<_, i32>(query)
314                .bind(user_id.0)
315                .bind(channel_id.0)
316                .fetch_optional(&self.pool)
317                .await
318                .map(|e| e.is_some())
319        })
320    }
321
322    #[cfg(any(test, feature = "seed-support"))]
323    pub async fn add_channel_member(
324        &self,
325        channel_id: ChannelId,
326        user_id: UserId,
327        is_admin: bool,
328    ) -> Result<()> {
329        test_support!(self, {
330            let query = "
331                INSERT INTO channel_memberships (channel_id, user_id, admin)
332                VALUES ($1, $2, $3)
333                ON CONFLICT DO NOTHING
334            ";
335            sqlx::query(query)
336                .bind(channel_id.0)
337                .bind(user_id.0)
338                .bind(is_admin)
339                .execute(&self.pool)
340                .await
341                .map(drop)
342        })
343    }
344
345    // messages
346
347    pub async fn create_channel_message(
348        &self,
349        channel_id: ChannelId,
350        sender_id: UserId,
351        body: &str,
352        timestamp: OffsetDateTime,
353        nonce: u128,
354    ) -> Result<MessageId> {
355        test_support!(self, {
356            let query = "
357                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
358                VALUES ($1, $2, $3, $4, $5)
359                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
360                RETURNING id
361            ";
362            sqlx::query_scalar(query)
363                .bind(channel_id.0)
364                .bind(sender_id.0)
365                .bind(body)
366                .bind(timestamp)
367                .bind(Uuid::from_u128(nonce))
368                .fetch_one(&self.pool)
369                .await
370                .map(MessageId)
371        })
372    }
373
374    pub async fn get_channel_messages(
375        &self,
376        channel_id: ChannelId,
377        count: usize,
378        before_id: Option<MessageId>,
379    ) -> Result<Vec<ChannelMessage>> {
380        test_support!(self, {
381            let query = r#"
382                SELECT * FROM (
383                    SELECT
384                        id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
385                    FROM
386                        channel_messages
387                    WHERE
388                        channel_id = $1 AND
389                        id < $2
390                    ORDER BY id DESC
391                    LIMIT $3
392                ) as recent_messages
393                ORDER BY id ASC
394            "#;
395            sqlx::query_as(query)
396                .bind(channel_id.0)
397                .bind(before_id.unwrap_or(MessageId::MAX))
398                .bind(count as i64)
399                .fetch_all(&self.pool)
400                .await
401        })
402    }
403}
404
/// Declares a public newtype id wrapping an `i32`.
///
/// The generated type is `Copy` and hashable. It maps to the underlying SQL
/// integer via `#[sqlx(transparent)]` and serializes as a bare number via
/// `#[serde(transparent)]`. It also provides:
/// - `MAX`: the id wrapping `i32::MAX`;
/// - `from_proto`: conversion from a `u64` id;
/// - `to_proto`: conversion to a `u64` id.
///
/// Both conversions are plain `as` casts. `from_proto` keeps only the low
/// 32 bits, so values above `i32::MAX` do not round-trip, and `to_proto`
/// sign-extends negative ids.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        // Not every generated id type uses every helper below.
        #[allow(unused)]
        impl $name {
            pub const MAX: Self = $name(i32::MAX);

            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            pub fn to_proto(&self) -> u64 {
                let Self(raw) = *self;
                raw as u64
            }
        }
    };
}
428
429id_type!(UserId);
430#[derive(Debug, FromRow, Serialize, PartialEq)]
431pub struct User {
432    pub id: UserId,
433    pub github_login: String,
434    pub admin: bool,
435}
436
437id_type!(OrgId);
438#[derive(FromRow)]
439pub struct Org {
440    pub id: OrgId,
441    pub name: String,
442    pub slug: String,
443}
444
445id_type!(SignupId);
446#[derive(Debug, FromRow, Serialize)]
447pub struct Signup {
448    pub id: SignupId,
449    pub github_login: String,
450    pub email_address: String,
451    pub about: String,
452}
453
454id_type!(ChannelId);
455#[derive(Debug, FromRow, Serialize)]
456pub struct Channel {
457    pub id: ChannelId,
458    pub name: String,
459}
460
461id_type!(MessageId);
462#[derive(Debug, FromRow)]
463pub struct ChannelMessage {
464    pub id: MessageId,
465    pub sender_id: UserId,
466    pub body: String,
467    pub sent_at: OffsetDateTime,
468    pub nonce: Uuid,
469}
470
#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

    /// A throwaway Postgres database, created and migrated for a single test.
    ///
    /// Dropping a `TestDb` terminates the other sessions connected to the
    /// database, closes the pool, and drops the database.
    pub struct TestDb {
        /// Handle to the test database, with `test_mode` enabled.
        pub db: Db,
        /// Generated database name (`zed-test-<random u128>`).
        pub name: String,
        /// Connection URL of `name` on the local Postgres server.
        pub url: String,
    }

    impl TestDb {
        /// Creates a uniquely named database on `postgres://postgres@localhost`.
        ///
        /// It connects to the database with at most 5 connections in test mode
        /// and runs the migrations in `$CARGO_MANIFEST_DIR/migrations`.
        ///
        /// # Panics
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub fn new() -> Self {
            // Enable tests to run in parallel by serializing the creation of each test database.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                {
                    // Held only for the CREATE DATABASE call.
                    let _lock = DB_CREATION.lock();
                    Postgres::create_database(&url)
                        .await
                        .expect("failed to create test db");
                }
                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });

            Self { db, name, url }
        }

        /// Returns the database handle under test.
        pub fn db(&self) -> &Db {
            &self.db
        }
    }

    impl Drop for TestDb {
        fn drop(&mut self) {
            block_on(async {
                // Terminate every other session on this database so that it can be
                // dropped. The database name must be passed as the `$1` bind parameter;
                // a quoted literal here would never match this database.
                let query = "
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = $1 AND pid <> pg_backend_pid();
                ";
                sqlx::query(query)
                    .bind(&self.name)
                    .execute(&self.db.pool)
                    .await
                    .unwrap();
                self.db.pool.close().await;
                Postgres::drop_database(&self.url).await.unwrap();
            });
        }
    }

    /// `get_users_by_ids` returns every requested user, ordered by id.
    #[gpui::test]
    async fn test_get_users_by_ids() {
        let test_db = TestDb::new();
        let db = test_db.db();

        let user = db.create_user("user", false).await.unwrap();
        let friend1 = db.create_user("friend-1", false).await.unwrap();
        let friend2 = db.create_user("friend-2", false).await.unwrap();
        let friend3 = db.create_user("friend-3", false).await.unwrap();

        assert_eq!(
            db.get_users_by_ids([user, friend1, friend2, friend3].iter().copied())
                .await
                .unwrap(),
            vec![
                User {
                    id: user,
                    github_login: "user".to_string(),
                    admin: false,
                },
                User {
                    id: friend1,
                    github_login: "friend-1".to_string(),
                    admin: false,
                },
                User {
                    id: friend2,
                    github_login: "friend-2".to_string(),
                    admin: false,
                },
                User {
                    id: friend3,
                    github_login: "friend-3".to_string(),
                    admin: false,
                }
            ]
        );
    }

    /// `get_channel_messages` pages backwards from the newest message and
    /// returns each page in ascending order.
    #[gpui::test]
    async fn test_recent_channel_messages() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();
        for i in 0..10 {
            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
                .await
                .unwrap();
        }

        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
        assert_eq!(
            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["5", "6", "7", "8", "9"]
        );

        let prev_messages = db
            .get_channel_messages(channel, 4, Some(messages[0].id))
            .await
            .unwrap();
        assert_eq!(
            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["1", "2", "3", "4"]
        );
    }

    /// Re-sending a message with an existing nonce returns the original id.
    #[gpui::test]
    async fn test_channel_message_nonces() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();

        let msg1_id = db
            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg2_id = db
            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();
        let msg3_id = db
            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
            .await
            .unwrap();
        let msg4_id = db
            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
            .await
            .unwrap();

        assert_ne!(msg1_id, msg2_id);
        assert_eq!(msg1_id, msg3_id);
        assert_eq!(msg2_id, msg4_id);
    }
}