db.rs

  1use serde::Serialize;
  2use sqlx::{FromRow, Result};
  3use time::OffsetDateTime;
  4
  5pub use async_sqlx_session::PostgresSessionStore as SessionStore;
  6pub use sqlx::postgres::PgPoolOptions as DbOptions;
  7
  8pub struct Db(pub sqlx::PgPool);
  9
 10#[derive(Debug, FromRow, Serialize)]
 11pub struct User {
 12    pub id: UserId,
 13    pub github_login: String,
 14    pub admin: bool,
 15}
 16
 17#[derive(Debug, FromRow, Serialize)]
 18pub struct Signup {
 19    pub id: SignupId,
 20    pub github_login: String,
 21    pub email_address: String,
 22    pub about: String,
 23}
 24
 25#[derive(Debug, FromRow, Serialize)]
 26pub struct Channel {
 27    pub id: ChannelId,
 28    pub name: String,
 29}
 30
 31#[derive(Debug, FromRow)]
 32pub struct ChannelMessage {
 33    pub id: MessageId,
 34    pub sender_id: UserId,
 35    pub body: String,
 36    pub sent_at: OffsetDateTime,
 37}
 38
 39impl Db {
 40    // signups
 41
 42    pub async fn create_signup(
 43        &self,
 44        github_login: &str,
 45        email_address: &str,
 46        about: &str,
 47    ) -> Result<SignupId> {
 48        let query = "
 49            INSERT INTO signups (github_login, email_address, about)
 50            VALUES ($1, $2, $3)
 51            RETURNING id
 52        ";
 53        sqlx::query_scalar(query)
 54            .bind(github_login)
 55            .bind(email_address)
 56            .bind(about)
 57            .fetch_one(&self.0)
 58            .await
 59            .map(SignupId)
 60    }
 61
 62    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
 63        let query = "SELECT * FROM users ORDER BY github_login ASC";
 64        sqlx::query_as(query).fetch_all(&self.0).await
 65    }
 66
 67    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
 68        let query = "DELETE FROM signups WHERE id = $1";
 69        sqlx::query(query)
 70            .bind(id.0)
 71            .execute(&self.0)
 72            .await
 73            .map(drop)
 74    }
 75
 76    // users
 77
 78    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 79        let query = "
 80            INSERT INTO users (github_login, admin)
 81            VALUES ($1, $2)
 82            RETURNING id
 83        ";
 84        sqlx::query_scalar(query)
 85            .bind(github_login)
 86            .bind(admin)
 87            .fetch_one(&self.0)
 88            .await
 89            .map(UserId)
 90    }
 91
 92    pub async fn get_all_users(&self) -> Result<Vec<User>> {
 93        let query = "SELECT * FROM users ORDER BY github_login ASC";
 94        sqlx::query_as(query).fetch_all(&self.0).await
 95    }
 96
 97    pub async fn get_users_by_ids(
 98        &self,
 99        requester_id: UserId,
100        ids: impl Iterator<Item = UserId>,
101    ) -> Result<Vec<User>> {
102        // Only return users that are in a common channel with the requesting user.
103        let query = "
104            SELECT users.*
105            FROM
106                users, channel_memberships
107            WHERE
108                users.id IN $1 AND
109                channel_memberships.user_id = users.id AND
110                channel_memberships.channel_id IN (
111                    SELECT channel_id
112                    FROM channel_memberships
113                    WHERE channel_memberships.user_id = $2
114                )
115        ";
116
117        sqlx::query_as(query)
118            .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
119            .bind(requester_id)
120            .fetch_all(&self.0)
121            .await
122    }
123
124    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
125        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
126        sqlx::query_as(query)
127            .bind(github_login)
128            .fetch_optional(&self.0)
129            .await
130    }
131
132    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
133        let query = "UPDATE users SET admin = $1 WHERE id = $2";
134        sqlx::query(query)
135            .bind(is_admin)
136            .bind(id.0)
137            .execute(&self.0)
138            .await
139            .map(drop)
140    }
141
142    pub async fn delete_user(&self, id: UserId) -> Result<()> {
143        let query = "DELETE FROM users WHERE id = $1;";
144        sqlx::query(query)
145            .bind(id.0)
146            .execute(&self.0)
147            .await
148            .map(drop)
149    }
150
151    // access tokens
152
153    pub async fn create_access_token_hash(
154        &self,
155        user_id: UserId,
156        access_token_hash: String,
157    ) -> Result<()> {
158        let query = "
159            INSERT INTO access_tokens (user_id, hash)
160            VALUES ($1, $2)
161        ";
162        sqlx::query(query)
163            .bind(user_id.0)
164            .bind(access_token_hash)
165            .execute(&self.0)
166            .await
167            .map(drop)
168    }
169
170    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
171        let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
172        sqlx::query_scalar(query)
173            .bind(user_id.0)
174            .fetch_all(&self.0)
175            .await
176    }
177
178    // orgs
179
180    #[cfg(test)]
181    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
182        let query = "
183            INSERT INTO orgs (name, slug)
184            VALUES ($1, $2)
185            RETURNING id
186        ";
187        sqlx::query_scalar(query)
188            .bind(name)
189            .bind(slug)
190            .fetch_one(&self.0)
191            .await
192            .map(OrgId)
193    }
194
195    #[cfg(test)]
196    pub async fn add_org_member(
197        &self,
198        org_id: OrgId,
199        user_id: UserId,
200        is_admin: bool,
201    ) -> Result<()> {
202        let query = "
203            INSERT INTO org_memberships (org_id, user_id, admin)
204            VALUES ($1, $2, $3)
205        ";
206        sqlx::query(query)
207            .bind(org_id.0)
208            .bind(user_id.0)
209            .bind(is_admin)
210            .execute(&self.0)
211            .await
212            .map(drop)
213    }
214
215    // channels
216
217    #[cfg(test)]
218    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
219        let query = "
220            INSERT INTO channels (owner_id, owner_is_user, name)
221            VALUES ($1, false, $2)
222            RETURNING id
223        ";
224        sqlx::query_scalar(query)
225            .bind(org_id.0)
226            .bind(name)
227            .fetch_one(&self.0)
228            .await
229            .map(ChannelId)
230    }
231
232    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
233        let query = "
234            SELECT
235                channels.id, channels.name
236            FROM
237                channel_memberships, channels
238            WHERE
239                channel_memberships.user_id = $1 AND
240                channel_memberships.channel_id = channels.id
241        ";
242        sqlx::query_as(query)
243            .bind(user_id.0)
244            .fetch_all(&self.0)
245            .await
246    }
247
248    pub async fn can_user_access_channel(
249        &self,
250        user_id: UserId,
251        channel_id: ChannelId,
252    ) -> Result<bool> {
253        let query = "
254            SELECT id
255            FROM channel_memberships
256            WHERE user_id = $1 AND channel_id = $2
257            LIMIT 1
258        ";
259        sqlx::query_scalar::<_, i32>(query)
260            .bind(user_id.0)
261            .bind(channel_id.0)
262            .fetch_optional(&self.0)
263            .await
264            .map(|e| e.is_some())
265    }
266
267    #[cfg(test)]
268    pub async fn add_channel_member(
269        &self,
270        channel_id: ChannelId,
271        user_id: UserId,
272        is_admin: bool,
273    ) -> Result<()> {
274        let query = "
275            INSERT INTO channel_memberships (channel_id, user_id, admin)
276            VALUES ($1, $2, $3)
277        ";
278        sqlx::query(query)
279            .bind(channel_id.0)
280            .bind(user_id.0)
281            .bind(is_admin)
282            .execute(&self.0)
283            .await
284            .map(drop)
285    }
286
287    // messages
288
289    pub async fn create_channel_message(
290        &self,
291        channel_id: ChannelId,
292        sender_id: UserId,
293        body: &str,
294        timestamp: OffsetDateTime,
295    ) -> Result<MessageId> {
296        let query = "
297            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
298            VALUES ($1, $2, $3, $4)
299            RETURNING id
300        ";
301        sqlx::query_scalar(query)
302            .bind(channel_id.0)
303            .bind(sender_id.0)
304            .bind(body)
305            .bind(timestamp)
306            .fetch_one(&self.0)
307            .await
308            .map(MessageId)
309    }
310
311    pub async fn get_recent_channel_messages(
312        &self,
313        channel_id: ChannelId,
314        count: usize,
315    ) -> Result<Vec<ChannelMessage>> {
316        let query = r#"
317            SELECT
318                id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
319            FROM
320                channel_messages
321            WHERE
322                channel_id = $1
323            LIMIT $2
324        "#;
325        sqlx::query_as(query)
326            .bind(channel_id.0)
327            .bind(count as i64)
328            .fetch_all(&self.0)
329            .await
330    }
331}
332
333impl std::ops::Deref for Db {
334    type Target = sqlx::PgPool;
335
336    fn deref(&self) -> &Self::Target {
337        &self.0
338    }
339}
340
/// Declares a `pub` newtype around an `i32` database id.
///
/// The generated type is `Copy` and hashable. It is stored in SQL as its inner
/// `i32` (`#[sqlx(transparent)]`) and serialized as a bare number
/// (`#[serde(transparent)]`).
macro_rules! id_type {
    ($type_name:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $type_name(pub i32);

        impl $type_name {
            /// Builds an id from its `u64` wire representation.
            ///
            /// NOTE: keeps only the low 32 bits of `value`, reinterpreted as `i32`;
            /// out-of-range inputs are not rejected.
            // Presumably not every generated id type uses this helper.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let low_bits = value as u32;
                Self(low_bits as i32)
            }

            /// Converts the id to its `u64` wire representation.
            ///
            /// NOTE: a negative inner value is sign-extended, producing a very large `u64`.
            // Presumably not every generated id type uses this helper.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let widened = i64::from(self.0);
                widened as u64
            }
        }
    };
}
361
362id_type!(UserId);
363id_type!(OrgId);
364id_type!(ChannelId);
365id_type!(SignupId);
366id_type!(MessageId);