db.rs

  1use anyhow::Context;
  2use async_std::task::{block_on, yield_now};
  3use serde::Serialize;
  4use sqlx::{FromRow, Result};
  5use time::OffsetDateTime;
  6
  7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
  8pub use sqlx::postgres::PgPoolOptions as DbOptions;
  9
 10macro_rules! test_support {
 11    ($self:ident, { $($token:tt)* }) => {{
 12        let body = async {
 13            $($token)*
 14        };
 15        if $self.test_mode {
 16            yield_now().await;
 17            block_on(body)
 18        } else {
 19            body.await
 20        }
 21    }};
 22}
 23
 24#[derive(Clone)]
 25pub struct Db {
 26    pool: sqlx::PgPool,
 27    test_mode: bool,
 28}
 29
 30#[derive(Debug, FromRow, Serialize)]
 31pub struct User {
 32    pub id: UserId,
 33    pub github_login: String,
 34    pub admin: bool,
 35}
 36
 37#[derive(Debug, FromRow, Serialize)]
 38pub struct Signup {
 39    pub id: SignupId,
 40    pub github_login: String,
 41    pub email_address: String,
 42    pub about: String,
 43}
 44
 45#[derive(Debug, FromRow, Serialize)]
 46pub struct Channel {
 47    pub id: ChannelId,
 48    pub name: String,
 49}
 50
 51#[derive(Debug, FromRow)]
 52pub struct ChannelMessage {
 53    pub id: MessageId,
 54    pub sender_id: UserId,
 55    pub body: String,
 56    pub sent_at: OffsetDateTime,
 57}
 58
 59impl Db {
 60    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
 61        let pool = DbOptions::new()
 62            .max_connections(max_connections)
 63            .connect(url)
 64            .await
 65            .context("failed to connect to postgres database")?;
 66        Ok(Self {
 67            pool,
 68            test_mode: false,
 69        })
 70    }
 71
 72    // signups
 73
 74    pub async fn create_signup(
 75        &self,
 76        github_login: &str,
 77        email_address: &str,
 78        about: &str,
 79    ) -> Result<SignupId> {
 80        test_support!(self, {
 81            let query = "
 82                INSERT INTO signups (github_login, email_address, about)
 83                VALUES ($1, $2, $3)
 84                RETURNING id
 85            ";
 86            sqlx::query_scalar(query)
 87                .bind(github_login)
 88                .bind(email_address)
 89                .bind(about)
 90                .fetch_one(&self.pool)
 91                .await
 92                .map(SignupId)
 93        })
 94    }
 95
 96    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
 97        test_support!(self, {
 98            let query = "SELECT * FROM users ORDER BY github_login ASC";
 99            sqlx::query_as(query).fetch_all(&self.pool).await
100        })
101    }
102
103    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
104        test_support!(self, {
105            let query = "DELETE FROM signups WHERE id = $1";
106            sqlx::query(query)
107                .bind(id.0)
108                .execute(&self.pool)
109                .await
110                .map(drop)
111        })
112    }
113
114    // users
115
116    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
117        test_support!(self, {
118            let query = "
119                INSERT INTO users (github_login, admin)
120                VALUES ($1, $2)
121                RETURNING id
122            ";
123            sqlx::query_scalar(query)
124                .bind(github_login)
125                .bind(admin)
126                .fetch_one(&self.pool)
127                .await
128                .map(UserId)
129        })
130    }
131
132    pub async fn get_all_users(&self) -> Result<Vec<User>> {
133        test_support!(self, {
134            let query = "SELECT * FROM users ORDER BY github_login ASC";
135            sqlx::query_as(query).fetch_all(&self.pool).await
136        })
137    }
138
139    pub async fn get_users_by_ids(
140        &self,
141        requester_id: UserId,
142        ids: impl Iterator<Item = UserId>,
143    ) -> Result<Vec<User>> {
144        test_support!(self, {
145            // Only return users that are in a common channel with the requesting user.
146            let query = "
147                SELECT users.*
148                FROM
149                    users, channel_memberships
150                WHERE
151                    users.id = ANY ($1) AND
152                    channel_memberships.user_id = users.id AND
153                    channel_memberships.channel_id IN (
154                        SELECT channel_id
155                        FROM channel_memberships
156                        WHERE channel_memberships.user_id = $2
157                    )
158            ";
159
160            sqlx::query_as(query)
161                .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
162                .bind(requester_id)
163                .fetch_all(&self.pool)
164                .await
165        })
166    }
167
168    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
169        test_support!(self, {
170            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
171            sqlx::query_as(query)
172                .bind(github_login)
173                .fetch_optional(&self.pool)
174                .await
175        })
176    }
177
178    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
179        test_support!(self, {
180            let query = "UPDATE users SET admin = $1 WHERE id = $2";
181            sqlx::query(query)
182                .bind(is_admin)
183                .bind(id.0)
184                .execute(&self.pool)
185                .await
186                .map(drop)
187        })
188    }
189
190    pub async fn delete_user(&self, id: UserId) -> Result<()> {
191        test_support!(self, {
192            let query = "DELETE FROM users WHERE id = $1;";
193            sqlx::query(query)
194                .bind(id.0)
195                .execute(&self.pool)
196                .await
197                .map(drop)
198        })
199    }
200
201    // access tokens
202
203    pub async fn create_access_token_hash(
204        &self,
205        user_id: UserId,
206        access_token_hash: String,
207    ) -> Result<()> {
208        test_support!(self, {
209            let query = "
210            INSERT INTO access_tokens (user_id, hash)
211            VALUES ($1, $2)
212        ";
213            sqlx::query(query)
214                .bind(user_id.0)
215                .bind(access_token_hash)
216                .execute(&self.pool)
217                .await
218                .map(drop)
219        })
220    }
221
222    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
223        test_support!(self, {
224            let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
225            sqlx::query_scalar(query)
226                .bind(user_id.0)
227                .fetch_all(&self.pool)
228                .await
229        })
230    }
231
232    // orgs
233
234    #[cfg(test)]
235    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
236        test_support!(self, {
237            let query = "
238                INSERT INTO orgs (name, slug)
239                VALUES ($1, $2)
240                RETURNING id
241            ";
242            sqlx::query_scalar(query)
243                .bind(name)
244                .bind(slug)
245                .fetch_one(&self.pool)
246                .await
247                .map(OrgId)
248        })
249    }
250
251    #[cfg(test)]
252    pub async fn add_org_member(
253        &self,
254        org_id: OrgId,
255        user_id: UserId,
256        is_admin: bool,
257    ) -> Result<()> {
258        test_support!(self, {
259            let query = "
260                INSERT INTO org_memberships (org_id, user_id, admin)
261                VALUES ($1, $2, $3)
262            ";
263            sqlx::query(query)
264                .bind(org_id.0)
265                .bind(user_id.0)
266                .bind(is_admin)
267                .execute(&self.pool)
268                .await
269                .map(drop)
270        })
271    }
272
273    // channels
274
275    #[cfg(test)]
276    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
277        test_support!(self, {
278            let query = "
279                INSERT INTO channels (owner_id, owner_is_user, name)
280                VALUES ($1, false, $2)
281                RETURNING id
282            ";
283            sqlx::query_scalar(query)
284                .bind(org_id.0)
285                .bind(name)
286                .fetch_one(&self.pool)
287                .await
288                .map(ChannelId)
289        })
290    }
291
292    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
293        test_support!(self, {
294            let query = "
295                SELECT
296                    channels.id, channels.name
297                FROM
298                    channel_memberships, channels
299                WHERE
300                    channel_memberships.user_id = $1 AND
301                    channel_memberships.channel_id = channels.id
302            ";
303            sqlx::query_as(query)
304                .bind(user_id.0)
305                .fetch_all(&self.pool)
306                .await
307        })
308    }
309
310    pub async fn can_user_access_channel(
311        &self,
312        user_id: UserId,
313        channel_id: ChannelId,
314    ) -> Result<bool> {
315        test_support!(self, {
316            let query = "
317                SELECT id
318                FROM channel_memberships
319                WHERE user_id = $1 AND channel_id = $2
320                LIMIT 1
321            ";
322            sqlx::query_scalar::<_, i32>(query)
323                .bind(user_id.0)
324                .bind(channel_id.0)
325                .fetch_optional(&self.pool)
326                .await
327                .map(|e| e.is_some())
328        })
329    }
330
331    #[cfg(test)]
332    pub async fn add_channel_member(
333        &self,
334        channel_id: ChannelId,
335        user_id: UserId,
336        is_admin: bool,
337    ) -> Result<()> {
338        test_support!(self, {
339            let query = "
340                INSERT INTO channel_memberships (channel_id, user_id, admin)
341                VALUES ($1, $2, $3)
342            ";
343            sqlx::query(query)
344                .bind(channel_id.0)
345                .bind(user_id.0)
346                .bind(is_admin)
347                .execute(&self.pool)
348                .await
349                .map(drop)
350        })
351    }
352
353    // messages
354
355    pub async fn create_channel_message(
356        &self,
357        channel_id: ChannelId,
358        sender_id: UserId,
359        body: &str,
360        timestamp: OffsetDateTime,
361    ) -> Result<MessageId> {
362        test_support!(self, {
363            let query = "
364                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
365                VALUES ($1, $2, $3, $4)
366                RETURNING id
367            ";
368            sqlx::query_scalar(query)
369                .bind(channel_id.0)
370                .bind(sender_id.0)
371                .bind(body)
372                .bind(timestamp)
373                .fetch_one(&self.pool)
374                .await
375                .map(MessageId)
376        })
377    }
378
379    pub async fn get_recent_channel_messages(
380        &self,
381        channel_id: ChannelId,
382        count: usize,
383        before_id: Option<MessageId>,
384    ) -> Result<Vec<ChannelMessage>> {
385        test_support!(self, {
386            let query = r#"
387                SELECT * FROM (
388                    SELECT
389                        id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
390                    FROM
391                        channel_messages
392                    WHERE
393                        channel_id = $1 AND
394                        id < $2
395                    ORDER BY id DESC
396                    LIMIT $3
397                ) as recent_messages
398                ORDER BY id ASC
399            "#;
400            sqlx::query_as(query)
401                .bind(channel_id.0)
402                .bind(before_id.unwrap_or(MessageId::MAX))
403                .bind(count as i64)
404                .fetch_all(&self.pool)
405                .await
406        })
407    }
408}
409
410macro_rules! id_type {
411    ($name:ident) => {
412        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
413        #[sqlx(transparent)]
414        #[serde(transparent)]
415        pub struct $name(pub i32);
416
417        impl $name {
418            #[allow(unused)]
419            pub const MAX: Self = Self(i32::MAX);
420
421            #[allow(unused)]
422            pub fn from_proto(value: u64) -> Self {
423                Self(value as i32)
424            }
425
426            #[allow(unused)]
427            pub fn to_proto(&self) -> u64 {
428                self.0 as u64
429            }
430        }
431    };
432}
433
434id_type!(UserId);
435id_type!(OrgId);
436id_type!(ChannelId);
437id_type!(SignupId);
438id_type!(MessageId);
439
440#[cfg(test)]
441pub mod tests {
442    use super::*;
443    use rand::prelude::*;
444    use sqlx::{
445        migrate::{MigrateDatabase, Migrator},
446        Postgres,
447    };
448    use std::path::Path;
449
450    pub struct TestDb {
451        pub db: Db,
452        pub name: String,
453        pub url: String,
454    }
455
456    impl TestDb {
457        pub fn new() -> Self {
458            // Enable tests to run in parallel by serializing the creation of each test database.
459            lazy_static::lazy_static! {
460                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
461            }
462
463            let mut rng = StdRng::from_entropy();
464            let name = format!("zed-test-{}", rng.gen::<u128>());
465            let url = format!("postgres://postgres@localhost/{}", name);
466            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
467            let db = block_on(async {
468                {
469                    let _lock = DB_CREATION.lock();
470                    Postgres::create_database(&url)
471                        .await
472                        .expect("failed to create test db");
473                }
474                let mut db = Db::new(&url, 5).await.unwrap();
475                db.test_mode = true;
476                let migrator = Migrator::new(migrations_path).await.unwrap();
477                migrator.run(&db.pool).await.unwrap();
478                db
479            });
480
481            Self { db, name, url }
482        }
483
484        pub fn db(&self) -> &Db {
485            &self.db
486        }
487    }
488
489    impl Drop for TestDb {
490        fn drop(&mut self) {
491            block_on(async {
492                let query = "
493                    SELECT pg_terminate_backend(pg_stat_activity.pid)
494                    FROM pg_stat_activity
495                    WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
496                ";
497                sqlx::query(query)
498                    .bind(&self.name)
499                    .execute(&self.db.pool)
500                    .await
501                    .unwrap();
502                self.db.pool.close().await;
503                Postgres::drop_database(&self.url).await.unwrap();
504            });
505        }
506    }
507
508    #[gpui::test]
509    async fn test_recent_channel_messages() {
510        let test_db = TestDb::new();
511        let db = test_db.db();
512        let user = db.create_user("user", false).await.unwrap();
513        let org = db.create_org("org", "org").await.unwrap();
514        let channel = db.create_org_channel(org, "channel").await.unwrap();
515        for i in 0..10 {
516            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc())
517                .await
518                .unwrap();
519        }
520
521        let messages = db
522            .get_recent_channel_messages(channel, 5, None)
523            .await
524            .unwrap();
525        assert_eq!(
526            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
527            ["5", "6", "7", "8", "9"]
528        );
529
530        let prev_messages = db
531            .get_recent_channel_messages(channel, 4, Some(messages[0].id))
532            .await
533            .unwrap();
534        assert_eq!(
535            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
536            ["1", "2", "3", "4"]
537        );
538    }
539}