db.rs

  1use anyhow::Context;
  2use async_std::task::{block_on, yield_now};
  3use serde::Serialize;
  4use sqlx::{FromRow, Result};
  5use time::OffsetDateTime;
  6
  7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
  8pub use sqlx::postgres::PgPoolOptions as DbOptions;
  9
 10macro_rules! test_support {
 11    ($self:ident, { $($token:tt)* }) => {{
 12        let body = async {
 13            $($token)*
 14        };
 15        if $self.test_mode {
 16            yield_now().await;
 17            block_on(body)
 18        } else {
 19            body.await
 20        }
 21    }};
 22}
 23
/// Handle to the server's Postgres database.
///
/// Cloning a `Db` clones the underlying `sqlx::PgPool` handle.
#[derive(Clone)]
pub struct Db {
    /// Pool every query in `impl Db` is executed against.
    pool: sqlx::PgPool,
    /// When true, `test_support!` yields once and then runs each query with
    /// `block_on` instead of awaiting it directly. `Db::new` sets this to false.
    test_mode: bool,
}
 29
 30impl Db {
 31    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
 32        let pool = DbOptions::new()
 33            .max_connections(max_connections)
 34            .connect(url)
 35            .await
 36            .context("failed to connect to postgres database")?;
 37        Ok(Self {
 38            pool,
 39            test_mode: false,
 40        })
 41    }
 42
 43    // signups
 44
 45    pub async fn create_signup(
 46        &self,
 47        github_login: &str,
 48        email_address: &str,
 49        about: &str,
 50    ) -> Result<SignupId> {
 51        test_support!(self, {
 52            let query = "
 53                INSERT INTO signups (github_login, email_address, about)
 54                VALUES ($1, $2, $3)
 55                RETURNING id
 56            ";
 57            sqlx::query_scalar(query)
 58                .bind(github_login)
 59                .bind(email_address)
 60                .bind(about)
 61                .fetch_one(&self.pool)
 62                .await
 63                .map(SignupId)
 64        })
 65    }
 66
 67    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
 68        test_support!(self, {
 69            let query = "SELECT * FROM users ORDER BY github_login ASC";
 70            sqlx::query_as(query).fetch_all(&self.pool).await
 71        })
 72    }
 73
 74    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
 75        test_support!(self, {
 76            let query = "DELETE FROM signups WHERE id = $1";
 77            sqlx::query(query)
 78                .bind(id.0)
 79                .execute(&self.pool)
 80                .await
 81                .map(drop)
 82        })
 83    }
 84
 85    // users
 86
 87    #[allow(unused)] // Help rust-analyzer
 88    #[cfg(any(test, feature = "seed-support"))]
 89    pub async fn get_user(&self, github_login: &str) -> Result<Option<UserId>> {
 90        test_support!(self, {
 91            let query = "
 92                SELECT id
 93                FROM users
 94                WHERE github_login = $1
 95            ";
 96            sqlx::query_scalar(query)
 97                .bind(github_login)
 98                .fetch_optional(&self.pool)
 99                .await
100        })
101    }
102
103    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
104        test_support!(self, {
105            let query = "
106                INSERT INTO users (github_login, admin)
107                VALUES ($1, $2)
108                RETURNING id
109            ";
110            sqlx::query_scalar(query)
111                .bind(github_login)
112                .bind(admin)
113                .fetch_one(&self.pool)
114                .await
115                .map(UserId)
116        })
117    }
118
119    pub async fn get_all_users(&self) -> Result<Vec<User>> {
120        test_support!(self, {
121            let query = "SELECT * FROM users ORDER BY github_login ASC";
122            sqlx::query_as(query).fetch_all(&self.pool).await
123        })
124    }
125
126    pub async fn get_users_by_ids(
127        &self,
128        requester_id: UserId,
129        ids: impl Iterator<Item = UserId>,
130    ) -> Result<Vec<User>> {
131        test_support!(self, {
132            // Only return users that are in a common channel with the requesting user.
133            let query = "
134                SELECT users.*
135                FROM
136                    users, channel_memberships
137                WHERE
138                    users.id = ANY ($1) AND
139                    channel_memberships.user_id = users.id AND
140                    channel_memberships.channel_id IN (
141                        SELECT channel_id
142                        FROM channel_memberships
143                        WHERE channel_memberships.user_id = $2
144                    )
145            ";
146
147            sqlx::query_as(query)
148                .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
149                .bind(requester_id)
150                .fetch_all(&self.pool)
151                .await
152        })
153    }
154
155    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
156        test_support!(self, {
157            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
158            sqlx::query_as(query)
159                .bind(github_login)
160                .fetch_optional(&self.pool)
161                .await
162        })
163    }
164
165    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
166        test_support!(self, {
167            let query = "UPDATE users SET admin = $1 WHERE id = $2";
168            sqlx::query(query)
169                .bind(is_admin)
170                .bind(id.0)
171                .execute(&self.pool)
172                .await
173                .map(drop)
174        })
175    }
176
177    pub async fn delete_user(&self, id: UserId) -> Result<()> {
178        test_support!(self, {
179            let query = "DELETE FROM users WHERE id = $1;";
180            sqlx::query(query)
181                .bind(id.0)
182                .execute(&self.pool)
183                .await
184                .map(drop)
185        })
186    }
187
188    // access tokens
189
190    pub async fn create_access_token_hash(
191        &self,
192        user_id: UserId,
193        access_token_hash: String,
194    ) -> Result<()> {
195        test_support!(self, {
196            let query = "
197            INSERT INTO access_tokens (user_id, hash)
198            VALUES ($1, $2)
199        ";
200            sqlx::query(query)
201                .bind(user_id.0)
202                .bind(access_token_hash)
203                .execute(&self.pool)
204                .await
205                .map(drop)
206        })
207    }
208
209    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
210        test_support!(self, {
211            let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
212            sqlx::query_scalar(query)
213                .bind(user_id.0)
214                .fetch_all(&self.pool)
215                .await
216        })
217    }
218
219    // orgs
220
221    #[allow(unused)] // Help rust-analyzer
222    #[cfg(any(test, feature = "seed-support"))]
223    pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
224        test_support!(self, {
225            let query = "
226                SELECT *
227                FROM orgs
228                WHERE slug = $1
229            ";
230            sqlx::query_as(query)
231                .bind(slug)
232                .fetch_optional(&self.pool)
233                .await
234        })
235    }
236
237    #[cfg(any(test, feature = "seed-support"))]
238    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
239        test_support!(self, {
240            let query = "
241                INSERT INTO orgs (name, slug)
242                VALUES ($1, $2)
243                RETURNING id
244            ";
245            sqlx::query_scalar(query)
246                .bind(name)
247                .bind(slug)
248                .fetch_one(&self.pool)
249                .await
250                .map(OrgId)
251        })
252    }
253
254    #[cfg(any(test, feature = "seed-support"))]
255    pub async fn add_org_member(
256        &self,
257        org_id: OrgId,
258        user_id: UserId,
259        is_admin: bool,
260    ) -> Result<()> {
261        test_support!(self, {
262            let query = "
263                INSERT INTO org_memberships (org_id, user_id, admin)
264                VALUES ($1, $2, $3)
265                ON CONFLICT DO NOTHING
266            ";
267            sqlx::query(query)
268                .bind(org_id.0)
269                .bind(user_id.0)
270                .bind(is_admin)
271                .execute(&self.pool)
272                .await
273                .map(drop)
274        })
275    }
276
277    // channels
278
279    #[cfg(any(test, feature = "seed-support"))]
280    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
281        test_support!(self, {
282            let query = "
283                INSERT INTO channels (owner_id, owner_is_user, name)
284                VALUES ($1, false, $2)
285                RETURNING id
286            ";
287            sqlx::query_scalar(query)
288                .bind(org_id.0)
289                .bind(name)
290                .fetch_one(&self.pool)
291                .await
292                .map(ChannelId)
293        })
294    }
295
296    #[allow(unused)] // Help rust-analyzer
297    #[cfg(any(test, feature = "seed-support"))]
298    pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
299        test_support!(self, {
300            let query = "
301                SELECT *
302                FROM channels
303                WHERE
304                    channels.owner_is_user = false AND
305                    channels.owner_id = $1
306            ";
307            sqlx::query_as(query)
308                .bind(org_id.0)
309                .fetch_all(&self.pool)
310                .await
311        })
312    }
313
314    pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
315        test_support!(self, {
316            let query = "
317                SELECT
318                    channels.id, channels.name
319                FROM
320                    channel_memberships, channels
321                WHERE
322                    channel_memberships.user_id = $1 AND
323                    channel_memberships.channel_id = channels.id
324            ";
325            sqlx::query_as(query)
326                .bind(user_id.0)
327                .fetch_all(&self.pool)
328                .await
329        })
330    }
331
332    pub async fn can_user_access_channel(
333        &self,
334        user_id: UserId,
335        channel_id: ChannelId,
336    ) -> Result<bool> {
337        test_support!(self, {
338            let query = "
339                SELECT id
340                FROM channel_memberships
341                WHERE user_id = $1 AND channel_id = $2
342                LIMIT 1
343            ";
344            sqlx::query_scalar::<_, i32>(query)
345                .bind(user_id.0)
346                .bind(channel_id.0)
347                .fetch_optional(&self.pool)
348                .await
349                .map(|e| e.is_some())
350        })
351    }
352
353    #[cfg(any(test, feature = "seed-support"))]
354    pub async fn add_channel_member(
355        &self,
356        channel_id: ChannelId,
357        user_id: UserId,
358        is_admin: bool,
359    ) -> Result<()> {
360        test_support!(self, {
361            let query = "
362                INSERT INTO channel_memberships (channel_id, user_id, admin)
363                VALUES ($1, $2, $3)
364                ON CONFLICT DO NOTHING
365            ";
366            sqlx::query(query)
367                .bind(channel_id.0)
368                .bind(user_id.0)
369                .bind(is_admin)
370                .execute(&self.pool)
371                .await
372                .map(drop)
373        })
374    }
375
376    // messages
377
378    pub async fn create_channel_message(
379        &self,
380        channel_id: ChannelId,
381        sender_id: UserId,
382        body: &str,
383        timestamp: OffsetDateTime,
384    ) -> Result<MessageId> {
385        test_support!(self, {
386            let query = "
387                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
388                VALUES ($1, $2, $3, $4)
389                RETURNING id
390            ";
391            sqlx::query_scalar(query)
392                .bind(channel_id.0)
393                .bind(sender_id.0)
394                .bind(body)
395                .bind(timestamp)
396                .fetch_one(&self.pool)
397                .await
398                .map(MessageId)
399        })
400    }
401
402    pub async fn get_channel_messages(
403        &self,
404        channel_id: ChannelId,
405        count: usize,
406        before_id: Option<MessageId>,
407    ) -> Result<Vec<ChannelMessage>> {
408        test_support!(self, {
409            let query = r#"
410                SELECT * FROM (
411                    SELECT
412                        id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
413                    FROM
414                        channel_messages
415                    WHERE
416                        channel_id = $1 AND
417                        id < $2
418                    ORDER BY id DESC
419                    LIMIT $3
420                ) as recent_messages
421                ORDER BY id ASC
422            "#;
423            sqlx::query_as(query)
424                .bind(channel_id.0)
425                .bind(before_id.unwrap_or(MessageId::MAX))
426                .bind(count as i64)
427                .fetch_all(&self.pool)
428                .await
429        })
430    }
431}
432
/// Declares a public newtype id `$name(pub i32)`.
///
/// The generated type is `Copy`, hashable and comparable. It is encoded by sqlx and
/// serialized by serde as the bare inner `i32` (both `transparent`). It also gets
/// `MAX`, `from_proto` and `to_proto` helpers.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id, `i32::MAX`.
            // `allow(unused)`: not every generated id type uses every helper.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Converts a `u64` protocol id into this id type.
            ///
            /// NOTE(review): `as i32` keeps only the low 32 bits, so values above
            /// `i32::MAX` wrap (possibly to negative ids) without any error — confirm
            /// callers never pass such values.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts this id into a `u64` protocol id.
            ///
            /// NOTE(review): a negative inner value sign-extends to a very large
            /// `u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }
    };
}
456
id_type!(UserId);
/// A user row, decoded from `SELECT * FROM users` queries in `Db`.
#[derive(Debug, FromRow, Serialize)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// Whether the user is an admin (written by `Db::create_user` / `Db::set_user_is_admin`).
    pub admin: bool,
}
464
465id_type!(OrgId);
466#[derive(FromRow)]
467pub struct Org {
468    pub id: OrgId,
469    pub name: String,
470    pub slug: String,
471}
472
id_type!(SignupId);
/// A signup record; its fields mirror the columns written by `Db::create_signup`.
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
    pub id: SignupId,
    pub github_login: String,
    pub email_address: String,
    pub about: String,
}
481
id_type!(ChannelId);
/// A channel's id and name: the columns `Db::get_accessible_channels` selects.
#[derive(Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
}
488
id_type!(MessageId);
/// A message row as returned by `Db::get_channel_messages`.
#[derive(Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub sender_id: UserId,
    pub body: String,
    /// Selected as `sent_at AT TIME ZONE 'UTC'` by `Db::get_channel_messages`.
    pub sent_at: OffsetDateTime,
}
497
#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

    /// A throwaway Postgres database for one test, dropped again when this value is
    /// dropped.
    pub struct TestDb {
        /// Handle to the database, with `test_mode` enabled.
        pub db: Db,
        /// Database name, of the form `zed-test-<random u128>`.
        pub name: String,
        /// Connection URL for the database on `localhost`.
        pub url: String,
    }

    impl TestDb {
        /// Creates a fresh, randomly named database on the local Postgres server and
        /// runs this crate's `migrations` directory against it.
        ///
        /// # Panics
        ///
        /// Panics if the database cannot be created, connected to or migrated.
        pub fn new() -> Self {
            // Enable tests to run in parallel by serializing the creation of each test database.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                {
                    // The mutex guards no data, so a poisoned lock (another test
                    // panicked while holding it) is safe to keep using.
                    let _lock = DB_CREATION
                        .lock()
                        .unwrap_or_else(|poisoned| poisoned.into_inner());
                    Postgres::create_database(&url)
                        .await
                        .expect("failed to create test db");
                }
                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });

            Self { db, name, url }
        }

        /// Returns the handle to the test database.
        pub fn db(&self) -> &Db {
            &self.db
        }
    }

    impl Drop for TestDb {
        /// Terminates the other backends connected to the test database, closes the
        /// pool, and drops the database.
        fn drop(&mut self) {
            block_on(async {
                // The database name must be a bind parameter (`$1`): the statement
                // is bound with `self.name`, and Postgres rejects binds that have no
                // matching placeholder.
                let query = "
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = $1 AND pid <> pg_backend_pid();
                ";
                sqlx::query(query)
                    .bind(&self.name)
                    .execute(&self.db.pool)
                    .await
                    .unwrap();
                self.db.pool.close().await;
                Postgres::drop_database(&self.url).await.unwrap();
            });
        }
    }

    #[gpui::test]
    async fn test_recent_channel_messages() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();
        for i in 0..10 {
            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc())
                .await
                .unwrap();
        }

        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
        assert_eq!(
            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["5", "6", "7", "8", "9"]
        );

        let prev_messages = db
            .get_channel_messages(channel, 4, Some(messages[0].id))
            .await
            .unwrap();
        assert_eq!(
            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["1", "2", "3", "4"]
        );
    }
}