db.rs

  1use anyhow::Context;
  2use async_std::task::{block_on, yield_now};
  3use serde::Serialize;
  4use sqlx::{FromRow, Result};
  5use time::OffsetDateTime;
  6
  7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
  8pub use sqlx::postgres::PgPoolOptions as DbOptions;
  9
 10macro_rules! test_support {
 11    ($self:ident, { $($token:tt)* }) => {{
 12        let body = async {
 13            $($token)*
 14        };
 15        if $self.test_mode {
 16            yield_now().await;
 17            block_on(body)
 18        } else {
 19            body.await
 20        }
 21    }};
 22}
 23
/// Handle to the Postgres database used by this module's query methods.
pub struct Db {
    /// Connection pool every query is executed against.
    db: sqlx::PgPool,
    /// When `true`, queries run through `test_support!`'s test-mode path
    /// (yield once, then `block_on`). Set to `true` only by the test helpers.
    test_mode: bool,
}
 28
/// A row of the `users` table (decoded from `SELECT * FROM users ...`).
#[derive(Debug, FromRow, Serialize)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// Whether the user is an administrator (see `Db::set_user_is_admin`).
    pub admin: bool,
}
 35
/// A signup record; its fields mirror the columns `Db::create_signup` writes
/// into the `signups` table.
#[derive(Debug, FromRow, Serialize)]
pub struct Signup {
    pub id: SignupId,
    pub github_login: String,
    pub email_address: String,
    /// Free-form text supplied with the signup.
    pub about: String,
}
 43
/// A channel as returned by `Db::get_channels_for_user`
/// (`channels.id`, `channels.name`).
#[derive(Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
}
 49
/// A message posted to a channel, as returned by
/// `Db::get_recent_channel_messages`.
#[derive(Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; read back with `AT TIME ZONE 'UTC'` applied.
    pub sent_at: OffsetDateTime,
}
 57
 58impl Db {
 59    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
 60        let db = DbOptions::new()
 61            .max_connections(max_connections)
 62            .connect(url)
 63            .await
 64            .context("failed to connect to postgres database")?;
 65        Ok(Self {
 66            db,
 67            test_mode: false,
 68        })
 69    }
 70
 71    // signups
 72
 73    pub async fn create_signup(
 74        &self,
 75        github_login: &str,
 76        email_address: &str,
 77        about: &str,
 78    ) -> Result<SignupId> {
 79        test_support!(self, {
 80            let query = "
 81                INSERT INTO signups (github_login, email_address, about)
 82                VALUES ($1, $2, $3)
 83                RETURNING id
 84            ";
 85            sqlx::query_scalar(query)
 86                .bind(github_login)
 87                .bind(email_address)
 88                .bind(about)
 89                .fetch_one(&self.db)
 90                .await
 91                .map(SignupId)
 92        })
 93    }
 94
 95    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
 96        test_support!(self, {
 97            let query = "SELECT * FROM users ORDER BY github_login ASC";
 98            sqlx::query_as(query).fetch_all(&self.db).await
 99        })
100    }
101
102    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
103        test_support!(self, {
104            let query = "DELETE FROM signups WHERE id = $1";
105            sqlx::query(query)
106                .bind(id.0)
107                .execute(&self.db)
108                .await
109                .map(drop)
110        })
111    }
112
113    // users
114
115    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
116        test_support!(self, {
117            let query = "
118                INSERT INTO users (github_login, admin)
119                VALUES ($1, $2)
120                RETURNING id
121            ";
122            sqlx::query_scalar(query)
123                .bind(github_login)
124                .bind(admin)
125                .fetch_one(&self.db)
126                .await
127                .map(UserId)
128        })
129    }
130
131    pub async fn get_all_users(&self) -> Result<Vec<User>> {
132        test_support!(self, {
133            let query = "SELECT * FROM users ORDER BY github_login ASC";
134            sqlx::query_as(query).fetch_all(&self.db).await
135        })
136    }
137
138    pub async fn get_users_by_ids(
139        &self,
140        requester_id: UserId,
141        ids: impl Iterator<Item = UserId>,
142    ) -> Result<Vec<User>> {
143        test_support!(self, {
144            // Only return users that are in a common channel with the requesting user.
145            let query = "
146                SELECT users.*
147                FROM
148                    users, channel_memberships
149                WHERE
150                    users.id = ANY ($1) AND
151                    channel_memberships.user_id = users.id AND
152                    channel_memberships.channel_id IN (
153                        SELECT channel_id
154                        FROM channel_memberships
155                        WHERE channel_memberships.user_id = $2
156                    )
157            ";
158
159            sqlx::query_as(query)
160                .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
161                .bind(requester_id)
162                .fetch_all(&self.db)
163                .await
164        })
165    }
166
167    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
168        test_support!(self, {
169            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
170            sqlx::query_as(query)
171                .bind(github_login)
172                .fetch_optional(&self.db)
173                .await
174        })
175    }
176
177    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
178        test_support!(self, {
179            let query = "UPDATE users SET admin = $1 WHERE id = $2";
180            sqlx::query(query)
181                .bind(is_admin)
182                .bind(id.0)
183                .execute(&self.db)
184                .await
185                .map(drop)
186        })
187    }
188
189    pub async fn delete_user(&self, id: UserId) -> Result<()> {
190        test_support!(self, {
191            let query = "DELETE FROM users WHERE id = $1;";
192            sqlx::query(query)
193                .bind(id.0)
194                .execute(&self.db)
195                .await
196                .map(drop)
197        })
198    }
199
200    // access tokens
201
202    pub async fn create_access_token_hash(
203        &self,
204        user_id: UserId,
205        access_token_hash: String,
206    ) -> Result<()> {
207        test_support!(self, {
208            let query = "
209            INSERT INTO access_tokens (user_id, hash)
210            VALUES ($1, $2)
211        ";
212            sqlx::query(query)
213                .bind(user_id.0)
214                .bind(access_token_hash)
215                .execute(&self.db)
216                .await
217                .map(drop)
218        })
219    }
220
221    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
222        test_support!(self, {
223            let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
224            sqlx::query_scalar(query)
225                .bind(user_id.0)
226                .fetch_all(&self.db)
227                .await
228        })
229    }
230
231    // orgs
232
233    #[cfg(test)]
234    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
235        test_support!(self, {
236            let query = "
237                INSERT INTO orgs (name, slug)
238                VALUES ($1, $2)
239                RETURNING id
240            ";
241            sqlx::query_scalar(query)
242                .bind(name)
243                .bind(slug)
244                .fetch_one(&self.db)
245                .await
246                .map(OrgId)
247        })
248    }
249
250    #[cfg(test)]
251    pub async fn add_org_member(
252        &self,
253        org_id: OrgId,
254        user_id: UserId,
255        is_admin: bool,
256    ) -> Result<()> {
257        test_support!(self, {
258            let query = "
259                INSERT INTO org_memberships (org_id, user_id, admin)
260                VALUES ($1, $2, $3)
261            ";
262            sqlx::query(query)
263                .bind(org_id.0)
264                .bind(user_id.0)
265                .bind(is_admin)
266                .execute(&self.db)
267                .await
268                .map(drop)
269        })
270    }
271
272    // channels
273
274    #[cfg(test)]
275    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
276        test_support!(self, {
277            let query = "
278                INSERT INTO channels (owner_id, owner_is_user, name)
279                VALUES ($1, false, $2)
280                RETURNING id
281            ";
282            sqlx::query_scalar(query)
283                .bind(org_id.0)
284                .bind(name)
285                .fetch_one(&self.db)
286                .await
287                .map(ChannelId)
288        })
289    }
290
291    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
292        test_support!(self, {
293            let query = "
294                SELECT
295                    channels.id, channels.name
296                FROM
297                    channel_memberships, channels
298                WHERE
299                    channel_memberships.user_id = $1 AND
300                    channel_memberships.channel_id = channels.id
301            ";
302            sqlx::query_as(query)
303                .bind(user_id.0)
304                .fetch_all(&self.db)
305                .await
306        })
307    }
308
309    pub async fn can_user_access_channel(
310        &self,
311        user_id: UserId,
312        channel_id: ChannelId,
313    ) -> Result<bool> {
314        test_support!(self, {
315            let query = "
316                SELECT id
317                FROM channel_memberships
318                WHERE user_id = $1 AND channel_id = $2
319                LIMIT 1
320            ";
321            sqlx::query_scalar::<_, i32>(query)
322                .bind(user_id.0)
323                .bind(channel_id.0)
324                .fetch_optional(&self.db)
325                .await
326                .map(|e| e.is_some())
327        })
328    }
329
330    #[cfg(test)]
331    pub async fn add_channel_member(
332        &self,
333        channel_id: ChannelId,
334        user_id: UserId,
335        is_admin: bool,
336    ) -> Result<()> {
337        test_support!(self, {
338            let query = "
339                INSERT INTO channel_memberships (channel_id, user_id, admin)
340                VALUES ($1, $2, $3)
341            ";
342            sqlx::query(query)
343                .bind(channel_id.0)
344                .bind(user_id.0)
345                .bind(is_admin)
346                .execute(&self.db)
347                .await
348                .map(drop)
349        })
350    }
351
352    // messages
353
354    pub async fn create_channel_message(
355        &self,
356        channel_id: ChannelId,
357        sender_id: UserId,
358        body: &str,
359        timestamp: OffsetDateTime,
360    ) -> Result<MessageId> {
361        test_support!(self, {
362            let query = "
363                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
364                VALUES ($1, $2, $3, $4)
365                RETURNING id
366            ";
367            sqlx::query_scalar(query)
368                .bind(channel_id.0)
369                .bind(sender_id.0)
370                .bind(body)
371                .bind(timestamp)
372                .fetch_one(&self.db)
373                .await
374                .map(MessageId)
375        })
376    }
377
378    pub async fn get_recent_channel_messages(
379        &self,
380        channel_id: ChannelId,
381        count: usize,
382    ) -> Result<Vec<ChannelMessage>> {
383        test_support!(self, {
384            let query = r#"
385                SELECT
386                    id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
387                FROM
388                    channel_messages
389                WHERE
390                    channel_id = $1
391                LIMIT $2
392            "#;
393            sqlx::query_as(query)
394                .bind(channel_id.0)
395                .bind(count as i64)
396                .fetch_all(&self.db)
397                .await
398        })
399    }
400
401    #[cfg(test)]
402    pub async fn close(&self, db_name: &str) {
403        test_support!(self, {
404            let query = "
405                SELECT pg_terminate_backend(pg_stat_activity.pid)
406                FROM pg_stat_activity
407                WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
408            ";
409            sqlx::query(query)
410                .bind(db_name)
411                .execute(&self.db)
412                .await
413                .unwrap();
414            self.db.close().await;
415        })
416    }
417}
418
/// Declares a newtype wrapper around an `i32` database id.
///
/// The generated type is `Copy` and hashable. It encodes/decodes in SQL as the
/// bare `i32` (`#[sqlx(transparent)]`) and serializes as the bare integer
/// (`#[serde(transparent)]`).
macro_rules! id_type {
    ($id:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $id(pub i32);

        impl $id {
            /// Builds an id from its protocol (`u64`) representation.
            ///
            /// NOTE(review): this is an `as` cast, so values above `i32::MAX`
            /// are truncated to their low 32 bits rather than rejected.
            // `allow(unused)`: presumably not every id type uses both conversions.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $id(value as i32)
            }

            /// Returns the id in its protocol (`u64`) representation.
            ///
            /// NOTE(review): a negative inner value sign-extends through the `as` cast.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let $id(raw) = *self;
                raw as u64
            }
        }
    };
}
439
440id_type!(UserId);
441id_type!(OrgId);
442id_type!(ChannelId);
443id_type!(SignupId);
444id_type!(MessageId);
445
#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

    /// A uniquely named, throwaway Postgres database for a single test.
    /// The database is dropped when this value is dropped.
    pub struct TestDb {
        pub name: String,
        pub url: String,
    }

    impl TestDb {
        /// Creates a fresh database with a random name and runs the crate's
        /// `migrations` directory against it. Returns the handle with a `Db`
        /// connected to it in test mode.
        ///
        /// # Panics
        ///
        /// Panics if creating, connecting to, or migrating the database fails.
        pub fn new() -> (Self, Db) {
            // Tests run in parallel, so database creation is serialized through this lock.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let name = format!("zed-test-{}", StdRng::from_entropy().gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));

            let db = block_on(async {
                // Hold the creation lock only while the database itself is created.
                let creation_guard = DB_CREATION.lock();
                Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                drop(creation_guard);

                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                Migrator::new(migrations_path)
                    .await
                    .unwrap()
                    .run(&db.db)
                    .await
                    .unwrap();
                db
            });

            (Self { name, url }, db)
        }
    }

    impl Drop for TestDb {
        /// Drops the test database, panicking on failure.
        /// NOTE(review): presumably this fails while connections remain open,
        /// which `Db::close` is meant to prevent — confirm.
        fn drop(&mut self) {
            let drop_database = Postgres::drop_database(&self.url);
            block_on(drop_database).unwrap();
        }
    }
}