db.rs

  1use anyhow::Context;
  2use async_std::task::{block_on, yield_now};
  3use serde::Serialize;
  4use sqlx::{FromRow, Result};
  5use time::OffsetDateTime;
  6
  7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
  8pub use sqlx::postgres::PgPoolOptions as DbOptions;
  9
 10macro_rules! test_support {
 11    ($self:ident, { $($token:tt)* }) => {{
 12        let body = async {
 13            $($token)*
 14        };
 15        if $self.test_mode {
 16            yield_now().await;
 17            block_on(body)
 18        } else {
 19            body.await
 20        }
 21    }};
 22}
 23
 24pub struct Db {
 25    db: sqlx::PgPool,
 26    test_mode: bool,
 27}
 28
 29#[derive(Debug, FromRow, Serialize)]
 30pub struct User {
 31    pub id: UserId,
 32    pub github_login: String,
 33    pub admin: bool,
 34}
 35
 36#[derive(Debug, FromRow, Serialize)]
 37pub struct Signup {
 38    pub id: SignupId,
 39    pub github_login: String,
 40    pub email_address: String,
 41    pub about: String,
 42}
 43
 44#[derive(Debug, FromRow, Serialize)]
 45pub struct Channel {
 46    pub id: ChannelId,
 47    pub name: String,
 48}
 49
 50#[derive(Debug, FromRow)]
 51pub struct ChannelMessage {
 52    pub id: MessageId,
 53    pub sender_id: UserId,
 54    pub body: String,
 55    pub sent_at: OffsetDateTime,
 56}
 57
 58impl Db {
 59    pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
 60        let db = DbOptions::new()
 61            .max_connections(max_connections)
 62            .connect(url)
 63            .await
 64            .context("failed to connect to postgres database")?;
 65        Ok(Self {
 66            db,
 67            test_mode: false,
 68        })
 69    }
 70
 71    #[cfg(test)]
 72    pub fn test(url: &str, max_connections: u32) -> Self {
 73        let mut db = block_on(Self::new(url, max_connections)).unwrap();
 74        db.test_mode = true;
 75        db
 76    }
 77
 78    #[cfg(test)]
 79    pub fn migrate(&self, path: &std::path::Path) {
 80        block_on(async {
 81            let migrator = sqlx::migrate::Migrator::new(path).await.unwrap();
 82            migrator.run(&self.db).await.unwrap();
 83        });
 84    }
 85
 86    // signups
 87
 88    pub async fn create_signup(
 89        &self,
 90        github_login: &str,
 91        email_address: &str,
 92        about: &str,
 93    ) -> Result<SignupId> {
 94        test_support!(self, {
 95            let query = "
 96                INSERT INTO signups (github_login, email_address, about)
 97                VALUES ($1, $2, $3)
 98                RETURNING id
 99            ";
100            sqlx::query_scalar(query)
101                .bind(github_login)
102                .bind(email_address)
103                .bind(about)
104                .fetch_one(&self.db)
105                .await
106                .map(SignupId)
107        })
108    }
109
110    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
111        test_support!(self, {
112            let query = "SELECT * FROM users ORDER BY github_login ASC";
113            sqlx::query_as(query).fetch_all(&self.db).await
114        })
115    }
116
117    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
118        test_support!(self, {
119            let query = "DELETE FROM signups WHERE id = $1";
120            sqlx::query(query)
121                .bind(id.0)
122                .execute(&self.db)
123                .await
124                .map(drop)
125        })
126    }
127
128    // users
129
130    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
131        test_support!(self, {
132            let query = "
133                INSERT INTO users (github_login, admin)
134                VALUES ($1, $2)
135                RETURNING id
136            ";
137            sqlx::query_scalar(query)
138                .bind(github_login)
139                .bind(admin)
140                .fetch_one(&self.db)
141                .await
142                .map(UserId)
143        })
144    }
145
146    pub async fn get_all_users(&self) -> Result<Vec<User>> {
147        test_support!(self, {
148            let query = "SELECT * FROM users ORDER BY github_login ASC";
149            sqlx::query_as(query).fetch_all(&self.db).await
150        })
151    }
152
153    pub async fn get_users_by_ids(
154        &self,
155        requester_id: UserId,
156        ids: impl Iterator<Item = UserId>,
157    ) -> Result<Vec<User>> {
158        test_support!(self, {
159            // Only return users that are in a common channel with the requesting user.
160            let query = "
161                SELECT users.*
162                FROM
163                    users, channel_memberships
164                WHERE
165                    users.id = ANY ($1) AND
166                    channel_memberships.user_id = users.id AND
167                    channel_memberships.channel_id IN (
168                        SELECT channel_id
169                        FROM channel_memberships
170                        WHERE channel_memberships.user_id = $2
171                    )
172            ";
173
174            sqlx::query_as(query)
175                .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
176                .bind(requester_id)
177                .fetch_all(&self.db)
178                .await
179        })
180    }
181
182    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
183        test_support!(self, {
184            let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
185            sqlx::query_as(query)
186                .bind(github_login)
187                .fetch_optional(&self.db)
188                .await
189        })
190    }
191
192    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
193        test_support!(self, {
194            let query = "UPDATE users SET admin = $1 WHERE id = $2";
195            sqlx::query(query)
196                .bind(is_admin)
197                .bind(id.0)
198                .execute(&self.db)
199                .await
200                .map(drop)
201        })
202    }
203
204    pub async fn delete_user(&self, id: UserId) -> Result<()> {
205        test_support!(self, {
206            let query = "DELETE FROM users WHERE id = $1;";
207            sqlx::query(query)
208                .bind(id.0)
209                .execute(&self.db)
210                .await
211                .map(drop)
212        })
213    }
214
215    // access tokens
216
217    pub async fn create_access_token_hash(
218        &self,
219        user_id: UserId,
220        access_token_hash: String,
221    ) -> Result<()> {
222        test_support!(self, {
223            let query = "
224            INSERT INTO access_tokens (user_id, hash)
225            VALUES ($1, $2)
226        ";
227            sqlx::query(query)
228                .bind(user_id.0)
229                .bind(access_token_hash)
230                .execute(&self.db)
231                .await
232                .map(drop)
233        })
234    }
235
236    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
237        test_support!(self, {
238            let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
239            sqlx::query_scalar(query)
240                .bind(user_id.0)
241                .fetch_all(&self.db)
242                .await
243        })
244    }
245
246    // orgs
247
248    #[cfg(test)]
249    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
250        test_support!(self, {
251            let query = "
252                INSERT INTO orgs (name, slug)
253                VALUES ($1, $2)
254                RETURNING id
255            ";
256            sqlx::query_scalar(query)
257                .bind(name)
258                .bind(slug)
259                .fetch_one(&self.db)
260                .await
261                .map(OrgId)
262        })
263    }
264
265    #[cfg(test)]
266    pub async fn add_org_member(
267        &self,
268        org_id: OrgId,
269        user_id: UserId,
270        is_admin: bool,
271    ) -> Result<()> {
272        test_support!(self, {
273            let query = "
274                INSERT INTO org_memberships (org_id, user_id, admin)
275                VALUES ($1, $2, $3)
276            ";
277            sqlx::query(query)
278                .bind(org_id.0)
279                .bind(user_id.0)
280                .bind(is_admin)
281                .execute(&self.db)
282                .await
283                .map(drop)
284        })
285    }
286
287    // channels
288
289    #[cfg(test)]
290    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
291        test_support!(self, {
292            let query = "
293                INSERT INTO channels (owner_id, owner_is_user, name)
294                VALUES ($1, false, $2)
295                RETURNING id
296            ";
297            sqlx::query_scalar(query)
298                .bind(org_id.0)
299                .bind(name)
300                .fetch_one(&self.db)
301                .await
302                .map(ChannelId)
303        })
304    }
305
306    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
307        test_support!(self, {
308            let query = "
309                SELECT
310                    channels.id, channels.name
311                FROM
312                    channel_memberships, channels
313                WHERE
314                    channel_memberships.user_id = $1 AND
315                    channel_memberships.channel_id = channels.id
316            ";
317            sqlx::query_as(query)
318                .bind(user_id.0)
319                .fetch_all(&self.db)
320                .await
321        })
322    }
323
324    pub async fn can_user_access_channel(
325        &self,
326        user_id: UserId,
327        channel_id: ChannelId,
328    ) -> Result<bool> {
329        test_support!(self, {
330            let query = "
331                SELECT id
332                FROM channel_memberships
333                WHERE user_id = $1 AND channel_id = $2
334                LIMIT 1
335            ";
336            sqlx::query_scalar::<_, i32>(query)
337                .bind(user_id.0)
338                .bind(channel_id.0)
339                .fetch_optional(&self.db)
340                .await
341                .map(|e| e.is_some())
342        })
343    }
344
345    #[cfg(test)]
346    pub async fn add_channel_member(
347        &self,
348        channel_id: ChannelId,
349        user_id: UserId,
350        is_admin: bool,
351    ) -> Result<()> {
352        test_support!(self, {
353            let query = "
354                INSERT INTO channel_memberships (channel_id, user_id, admin)
355                VALUES ($1, $2, $3)
356            ";
357            sqlx::query(query)
358                .bind(channel_id.0)
359                .bind(user_id.0)
360                .bind(is_admin)
361                .execute(&self.db)
362                .await
363                .map(drop)
364        })
365    }
366
367    // messages
368
369    pub async fn create_channel_message(
370        &self,
371        channel_id: ChannelId,
372        sender_id: UserId,
373        body: &str,
374        timestamp: OffsetDateTime,
375    ) -> Result<MessageId> {
376        test_support!(self, {
377            let query = "
378                INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
379                VALUES ($1, $2, $3, $4)
380                RETURNING id
381            ";
382            sqlx::query_scalar(query)
383                .bind(channel_id.0)
384                .bind(sender_id.0)
385                .bind(body)
386                .bind(timestamp)
387                .fetch_one(&self.db)
388                .await
389                .map(MessageId)
390        })
391    }
392
393    pub async fn get_recent_channel_messages(
394        &self,
395        channel_id: ChannelId,
396        count: usize,
397    ) -> Result<Vec<ChannelMessage>> {
398        test_support!(self, {
399            let query = r#"
400                SELECT
401                    id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
402                FROM
403                    channel_messages
404                WHERE
405                    channel_id = $1
406                LIMIT $2
407            "#;
408            sqlx::query_as(query)
409                .bind(channel_id.0)
410                .bind(count as i64)
411                .fetch_all(&self.db)
412                .await
413        })
414    }
415
416    #[cfg(test)]
417    pub async fn close(&self, db_name: &str) {
418        test_support!(self, {
419            let query = "
420                SELECT pg_terminate_backend(pg_stat_activity.pid)
421                FROM pg_stat_activity
422                WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
423            ";
424            sqlx::query(query)
425                .bind(db_name)
426                .execute(&self.db)
427                .await
428                .unwrap();
429            self.db.close().await;
430        })
431    }
432}
433
434macro_rules! id_type {
435    ($name:ident) => {
436        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
437        #[sqlx(transparent)]
438        #[serde(transparent)]
439        pub struct $name(pub i32);
440
441        impl $name {
442            #[allow(unused)]
443            pub fn from_proto(value: u64) -> Self {
444                Self(value as i32)
445            }
446
447            #[allow(unused)]
448            pub fn to_proto(&self) -> u64 {
449                self.0 as u64
450            }
451        }
452    };
453}
454
455id_type!(UserId);
456id_type!(OrgId);
457id_type!(ChannelId);
458id_type!(SignupId);
459id_type!(MessageId);