// db.rs

use anyhow::Context;
use async_std::task::{block_on, yield_now};
use serde::Serialize;
use sqlx::{types::Uuid, FromRow, Result};
use std::convert::TryFrom;
use time::OffsetDateTime;

pub use async_sqlx_session::PostgresSessionStore as SessionStore;
pub use sqlx::postgres::PgPoolOptions as DbOptions;
  9
 10macro_rules! test_support {
 11    ($self:ident, { $($token:tt)* }) => {{
 12        let body = async {
 13            $($token)*
 14        };
 15        if $self.test_mode {
 16            yield_now().await;
 17            block_on(body)
 18        } else {
 19            body.await
 20        }
 21    }};
 22}
 23
 24#[derive(Clone)]
 25pub struct Db {
 26    pool: sqlx::PgPool,
 27    test_mode: bool,
 28}
 29
 30impl Db {
 31    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
 32        let pool = DbOptions::new()
 33            .max_connections(max_connections)
 34            .connect(url)
 35            .await
 36            .context("failed to connect to postgres database")?;
 37        Ok(Self {
 38            pool,
 39            test_mode: false,
 40        })
 41    }
 42
 43    // signups
 44
 45    pub async fn create_signup(
 46        &self,
 47        github_login: &str,
 48        email_address: &str,
 49        about: &str,
 50        wants_releases: bool,
 51        wants_updates: bool,
 52        wants_community: bool,
 53    ) -> Result<SignupId> {
 54        test_support!(self, {
 55            let query = "
 56                INSERT INTO signups (
 57                    github_login,
 58                    email_address,
 59                    about,
 60                    wants_releases,
 61                    wants_updates,
 62                    wants_community
 63                )
 64                VALUES ($1, $2, $3, $4, $5, $6)
 65                RETURNING id
 66            ";
 67            sqlx::query_scalar(query)
 68                .bind(github_login)
 69                .bind(email_address)
 70                .bind(about)
 71                .bind(wants_releases)
 72                .bind(wants_updates)
 73                .bind(wants_community)
 74                .fetch_one(&self.pool)
 75                .await
 76                .map(SignupId)
 77        })
 78    }
 79
 80    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
 81        test_support!(self, {
 82            let query = "SELECT * FROM signups ORDER BY github_login ASC";
 83            sqlx::query_as(query).fetch_all(&self.pool).await
 84        })
 85    }
 86
 87    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
 88        test_support!(self, {
 89            let query = "DELETE FROM signups WHERE id = $1";
 90            sqlx::query(query)
 91                .bind(id.0)
 92                .execute(&self.pool)
 93                .await
 94                .map(drop)
 95        })
 96    }
 97
 98    // users
 99
100    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
101        test_support!(self, {
102            let query = "
103                INSERT INTO users (github_login, admin)
104                VALUES ($1, $2)
105                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
106                RETURNING id
107            ";
108            sqlx::query_scalar(query)
109                .bind(github_login)
110                .bind(admin)
111                .fetch_one(&self.pool)
112                .await
113                .map(UserId)
114        })
115    }
116
117    pub async fn get_all_users(&self) -> Result<Vec<User>> {
118        test_support!(self, {
119            let query = "SELECT * FROM users ORDER BY github_login ASC";
120            sqlx::query_as(query).fetch_all(&self.pool).await
121        })
122    }
123
124    pub async fn get_users_by_ids(
125        &self,
126        ids: impl IntoIterator<Item = UserId>,
127    ) -> Result<Vec<User>> {
128        let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
129        test_support!(self, {
130            let query = "
131                SELECT users.*
132                FROM users
133                WHERE users.id = ANY ($1)
134            ";
135
136            sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
137        })
138    }
139
140    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
141        test_support!(self, {
142            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
143            sqlx::query_as(query)
144                .bind(github_login)
145                .fetch_optional(&self.pool)
146                .await
147        })
148    }
149
150    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
151        test_support!(self, {
152            let query = "UPDATE users SET admin = $1 WHERE id = $2";
153            sqlx::query(query)
154                .bind(is_admin)
155                .bind(id.0)
156                .execute(&self.pool)
157                .await
158                .map(drop)
159        })
160    }
161
162    pub async fn delete_user(&self, id: UserId) -> Result<()> {
163        test_support!(self, {
164            let query = "DELETE FROM users WHERE id = $1;";
165            sqlx::query(query)
166                .bind(id.0)
167                .execute(&self.pool)
168                .await
169                .map(drop)
170        })
171    }
172
173    // access tokens
174
175    pub async fn create_access_token_hash(
176        &self,
177        user_id: UserId,
178        access_token_hash: String,
179    ) -> Result<()> {
180        test_support!(self, {
181            let query = "
182            INSERT INTO access_tokens (user_id, hash)
183            VALUES ($1, $2)
184        ";
185            sqlx::query(query)
186                .bind(user_id.0)
187                .bind(access_token_hash)
188                .execute(&self.pool)
189                .await
190                .map(drop)
191        })
192    }
193
194    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
195        test_support!(self, {
196            let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
197            sqlx::query_scalar(query)
198                .bind(user_id.0)
199                .fetch_all(&self.pool)
200                .await
201        })
202    }
203
204    // orgs
205
206    #[allow(unused)] // Help rust-analyzer
207    #[cfg(any(test, feature = "seed-support"))]
208    pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
209        test_support!(self, {
210            let query = "
211                SELECT *
212                FROM orgs
213                WHERE slug = $1
214            ";
215            sqlx::query_as(query)
216                .bind(slug)
217                .fetch_optional(&self.pool)
218                .await
219        })
220    }
221
222    #[cfg(any(test, feature = "seed-support"))]
223    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
224        test_support!(self, {
225            let query = "
226                INSERT INTO orgs (name, slug)
227                VALUES ($1, $2)
228                RETURNING id
229            ";
230            sqlx::query_scalar(query)
231                .bind(name)
232                .bind(slug)
233                .fetch_one(&self.pool)
234                .await
235                .map(OrgId)
236        })
237    }
238
239    #[cfg(any(test, feature = "seed-support"))]
240    pub async fn add_org_member(
241        &self,
242        org_id: OrgId,
243        user_id: UserId,
244        is_admin: bool,
245    ) -> Result<()> {
246        test_support!(self, {
247            let query = "
248                INSERT INTO org_memberships (org_id, user_id, admin)
249                VALUES ($1, $2, $3)
250                ON CONFLICT DO NOTHING
251            ";
252            sqlx::query(query)
253                .bind(org_id.0)
254                .bind(user_id.0)
255                .bind(is_admin)
256                .execute(&self.pool)
257                .await
258                .map(drop)
259        })
260    }
261
262    // channels
263
264    #[cfg(any(test, feature = "seed-support"))]
265    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
266        test_support!(self, {
267            let query = "
268                INSERT INTO channels (owner_id, owner_is_user, name)
269                VALUES ($1, false, $2)
270                RETURNING id
271            ";
272            sqlx::query_scalar(query)
273                .bind(org_id.0)
274                .bind(name)
275                .fetch_one(&self.pool)
276                .await
277                .map(ChannelId)
278        })
279    }
280
281    #[allow(unused)] // Help rust-analyzer
282    #[cfg(any(test, feature = "seed-support"))]
283    pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
284        test_support!(self, {
285            let query = "
286                SELECT *
287                FROM channels
288                WHERE
289                    channels.owner_is_user = false AND
290                    channels.owner_id = $1
291            ";
292            sqlx::query_as(query)
293                .bind(org_id.0)
294                .fetch_all(&self.pool)
295                .await
296        })
297    }
298
299    pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
300        test_support!(self, {
301            let query = "
302                SELECT
303                    channels.id, channels.name
304                FROM
305                    channel_memberships, channels
306                WHERE
307                    channel_memberships.user_id = $1 AND
308                    channel_memberships.channel_id = channels.id
309            ";
310            sqlx::query_as(query)
311                .bind(user_id.0)
312                .fetch_all(&self.pool)
313                .await
314        })
315    }
316
317    pub async fn can_user_access_channel(
318        &self,
319        user_id: UserId,
320        channel_id: ChannelId,
321    ) -> Result<bool> {
322        test_support!(self, {
323            let query = "
324                SELECT id
325                FROM channel_memberships
326                WHERE user_id = $1 AND channel_id = $2
327                LIMIT 1
328            ";
329            sqlx::query_scalar::<_, i32>(query)
330                .bind(user_id.0)
331                .bind(channel_id.0)
332                .fetch_optional(&self.pool)
333                .await
334                .map(|e| e.is_some())
335        })
336    }
337
338    #[cfg(any(test, feature = "seed-support"))]
339    pub async fn add_channel_member(
340        &self,
341        channel_id: ChannelId,
342        user_id: UserId,
343        is_admin: bool,
344    ) -> Result<()> {
345        test_support!(self, {
346            let query = "
347                INSERT INTO channel_memberships (channel_id, user_id, admin)
348                VALUES ($1, $2, $3)
349                ON CONFLICT DO NOTHING
350            ";
351            sqlx::query(query)
352                .bind(channel_id.0)
353                .bind(user_id.0)
354                .bind(is_admin)
355                .execute(&self.pool)
356                .await
357                .map(drop)
358        })
359    }
360
361    // messages
362
363    pub async fn create_channel_message(
364        &self,
365        channel_id: ChannelId,
366        sender_id: UserId,
367        body: &str,
368        timestamp: OffsetDateTime,
369        nonce: u128,
370    ) -> Result<MessageId> {
371        test_support!(self, {
372            let query = "
373                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
374                VALUES ($1, $2, $3, $4, $5)
375                ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
376                RETURNING id
377            ";
378            sqlx::query_scalar(query)
379                .bind(channel_id.0)
380                .bind(sender_id.0)
381                .bind(body)
382                .bind(timestamp)
383                .bind(Uuid::from_u128(nonce))
384                .fetch_one(&self.pool)
385                .await
386                .map(MessageId)
387        })
388    }
389
390    pub async fn get_channel_messages(
391        &self,
392        channel_id: ChannelId,
393        count: usize,
394        before_id: Option<MessageId>,
395    ) -> Result<Vec<ChannelMessage>> {
396        test_support!(self, {
397            let query = r#"
398                SELECT * FROM (
399                    SELECT
400                        id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
401                    FROM
402                        channel_messages
403                    WHERE
404                        channel_id = $1 AND
405                        id < $2
406                    ORDER BY id DESC
407                    LIMIT $3
408                ) as recent_messages
409                ORDER BY id ASC
410            "#;
411            sqlx::query_as(query)
412                .bind(channel_id.0)
413                .bind(before_id.unwrap_or(MessageId::MAX))
414                .bind(count as i64)
415                .fetch_all(&self.pool)
416                .await
417        })
418    }
419}
420
421macro_rules! id_type {
422    ($name:ident) => {
423        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
424        #[sqlx(transparent)]
425        #[serde(transparent)]
426        pub struct $name(pub i32);
427
428        impl $name {
429            #[allow(unused)]
430            pub const MAX: Self = Self(i32::MAX);
431
432            #[allow(unused)]
433            pub fn from_proto(value: u64) -> Self {
434                Self(value as i32)
435            }
436
437            #[allow(unused)]
438            pub fn to_proto(&self) -> u64 {
439                self.0 as u64
440            }
441        }
442    };
443}
444
445id_type!(UserId);
446#[derive(Debug, FromRow, Serialize, PartialEq)]
447pub struct User {
448    pub id: UserId,
449    pub github_login: String,
450    pub admin: bool,
451}
452
453id_type!(OrgId);
454#[derive(FromRow)]
455pub struct Org {
456    pub id: OrgId,
457    pub name: String,
458    pub slug: String,
459}
460
461id_type!(SignupId);
462#[derive(Debug, FromRow, Serialize)]
463pub struct Signup {
464    pub id: SignupId,
465    pub github_login: String,
466    pub email_address: String,
467    pub about: String,
468    pub wants_releases: Option<bool>,
469    pub wants_updates: Option<bool>,
470    pub wants_community: Option<bool>,
471}
472
473id_type!(ChannelId);
474#[derive(Debug, FromRow, Serialize)]
475pub struct Channel {
476    pub id: ChannelId,
477    pub name: String,
478}
479
480id_type!(MessageId);
481#[derive(Debug, FromRow)]
482pub struct ChannelMessage {
483    pub id: MessageId,
484    pub sender_id: UserId,
485    pub body: String,
486    pub sent_at: OffsetDateTime,
487    pub nonce: Uuid,
488}
489
490#[cfg(test)]
491pub mod tests {
492    use super::*;
493    use rand::prelude::*;
494    use sqlx::{
495        migrate::{MigrateDatabase, Migrator},
496        Postgres,
497    };
498    use std::path::Path;
499
500    pub struct TestDb {
501        pub db: Db,
502        pub name: String,
503        pub url: String,
504    }
505
506    impl TestDb {
507        pub fn new() -> Self {
508            // Enable tests to run in parallel by serializing the creation of each test database.
509            lazy_static::lazy_static! {
510                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
511            }
512
513            let mut rng = StdRng::from_entropy();
514            let name = format!("zed-test-{}", rng.gen::<u128>());
515            let url = format!("postgres://postgres@localhost/{}", name);
516            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
517            let db = block_on(async {
518                {
519                    let _lock = DB_CREATION.lock();
520                    Postgres::create_database(&url)
521                        .await
522                        .expect("failed to create test db");
523                }
524                let mut db = Db::new(&url, 5).await.unwrap();
525                db.test_mode = true;
526                let migrator = Migrator::new(migrations_path).await.unwrap();
527                migrator.run(&db.pool).await.unwrap();
528                db
529            });
530
531            Self { db, name, url }
532        }
533
534        pub fn db(&self) -> &Db {
535            &self.db
536        }
537    }
538
539    impl Drop for TestDb {
540        fn drop(&mut self) {
541            block_on(async {
542                let query = "
543                    SELECT pg_terminate_backend(pg_stat_activity.pid)
544                    FROM pg_stat_activity
545                    WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
546                ";
547                sqlx::query(query)
548                    .bind(&self.name)
549                    .execute(&self.db.pool)
550                    .await
551                    .unwrap();
552                self.db.pool.close().await;
553                Postgres::drop_database(&self.url).await.unwrap();
554            });
555        }
556    }
557
558    #[gpui::test]
559    async fn test_get_users_by_ids() {
560        let test_db = TestDb::new();
561        let db = test_db.db();
562
563        let user = db.create_user("user", false).await.unwrap();
564        let friend1 = db.create_user("friend-1", false).await.unwrap();
565        let friend2 = db.create_user("friend-2", false).await.unwrap();
566        let friend3 = db.create_user("friend-3", false).await.unwrap();
567
568        assert_eq!(
569            db.get_users_by_ids([user, friend1, friend2, friend3])
570                .await
571                .unwrap(),
572            vec![
573                User {
574                    id: user,
575                    github_login: "user".to_string(),
576                    admin: false,
577                },
578                User {
579                    id: friend1,
580                    github_login: "friend-1".to_string(),
581                    admin: false,
582                },
583                User {
584                    id: friend2,
585                    github_login: "friend-2".to_string(),
586                    admin: false,
587                },
588                User {
589                    id: friend3,
590                    github_login: "friend-3".to_string(),
591                    admin: false,
592                }
593            ]
594        );
595    }
596
597    #[gpui::test]
598    async fn test_recent_channel_messages() {
599        let test_db = TestDb::new();
600        let db = test_db.db();
601        let user = db.create_user("user", false).await.unwrap();
602        let org = db.create_org("org", "org").await.unwrap();
603        let channel = db.create_org_channel(org, "channel").await.unwrap();
604        for i in 0..10 {
605            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
606                .await
607                .unwrap();
608        }
609
610        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
611        assert_eq!(
612            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
613            ["5", "6", "7", "8", "9"]
614        );
615
616        let prev_messages = db
617            .get_channel_messages(channel, 4, Some(messages[0].id))
618            .await
619            .unwrap();
620        assert_eq!(
621            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
622            ["1", "2", "3", "4"]
623        );
624    }
625
626    #[gpui::test]
627    async fn test_channel_message_nonces() {
628        let test_db = TestDb::new();
629        let db = test_db.db();
630        let user = db.create_user("user", false).await.unwrap();
631        let org = db.create_org("org", "org").await.unwrap();
632        let channel = db.create_org_channel(org, "channel").await.unwrap();
633
634        let msg1_id = db
635            .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
636            .await
637            .unwrap();
638        let msg2_id = db
639            .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
640            .await
641            .unwrap();
642        let msg3_id = db
643            .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
644            .await
645            .unwrap();
646        let msg4_id = db
647            .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
648            .await
649            .unwrap();
650
651        assert_ne!(msg1_id, msg2_id);
652        assert_eq!(msg1_id, msg3_id);
653        assert_eq!(msg2_id, msg4_id);
654    }
655}