db.rs

  1use serde::Serialize;
  2use sqlx::{FromRow, Result};
  3
  4pub use async_sqlx_session::PostgresSessionStore as SessionStore;
  5pub use sqlx::postgres::PgPoolOptions as DbOptions;
  6
  7pub struct Db(pub sqlx::PgPool);
  8
  9#[derive(Debug, FromRow, Serialize)]
 10pub struct User {
 11    id: i32,
 12    pub github_login: String,
 13    pub admin: bool,
 14}
 15
 16#[derive(Debug, FromRow, Serialize)]
 17pub struct Signup {
 18    id: i32,
 19    pub github_login: String,
 20    pub email_address: String,
 21    pub about: String,
 22}
 23
 24#[derive(Debug, FromRow, Serialize)]
 25pub struct Channel {
 26    id: i32,
 27    pub name: String,
 28}
 29
 30#[derive(Debug, FromRow)]
 31pub struct ChannelMessage {
 32    id: i32,
 33    sender_id: i32,
 34    body: String,
 35    sent_at: i64,
 36}
 37
 38#[derive(Clone, Copy)]
 39pub struct UserId(pub i32);
 40
 41#[derive(Clone, Copy)]
 42pub struct OrgId(pub i32);
 43
 44#[derive(Clone, Copy)]
 45pub struct ChannelId(pub i32);
 46
 47#[derive(Clone, Copy)]
 48pub struct SignupId(pub i32);
 49
 50#[derive(Clone, Copy)]
 51pub struct MessageId(pub i32);
 52
 53impl Db {
 54    // signups
 55
 56    pub async fn create_signup(
 57        &self,
 58        github_login: &str,
 59        email_address: &str,
 60        about: &str,
 61    ) -> Result<SignupId> {
 62        let query = "
 63            INSERT INTO signups (github_login, email_address, about)
 64            VALUES ($1, $2, $3)
 65            RETURNING id
 66        ";
 67        sqlx::query_scalar(query)
 68            .bind(github_login)
 69            .bind(email_address)
 70            .bind(about)
 71            .fetch_one(&self.0)
 72            .await
 73            .map(SignupId)
 74    }
 75
 76    pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
 77        let query = "SELECT * FROM users ORDER BY github_login ASC";
 78        sqlx::query_as(query).fetch_all(&self.0).await
 79    }
 80
 81    pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
 82        let query = "DELETE FROM signups WHERE id = $1";
 83        sqlx::query(query)
 84            .bind(id.0)
 85            .execute(&self.0)
 86            .await
 87            .map(drop)
 88    }
 89
 90    // users
 91
 92    pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
 93        let query = "
 94            INSERT INTO users (github_login, admin)
 95            VALUES ($1, $2)
 96            RETURNING id
 97        ";
 98        sqlx::query_scalar(query)
 99            .bind(github_login)
100            .bind(admin)
101            .fetch_one(&self.0)
102            .await
103            .map(UserId)
104    }
105
106    pub async fn get_all_users(&self) -> Result<Vec<User>> {
107        let query = "SELECT * FROM users ORDER BY github_login ASC";
108        sqlx::query_as(query).fetch_all(&self.0).await
109    }
110
111    pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
112        let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
113        sqlx::query_as(query)
114            .bind(github_login)
115            .fetch_optional(&self.0)
116            .await
117    }
118
119    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
120        let query = "UPDATE users SET admin = $1 WHERE id = $2";
121        sqlx::query(query)
122            .bind(is_admin)
123            .bind(id.0)
124            .execute(&self.0)
125            .await
126            .map(drop)
127    }
128
129    pub async fn delete_user(&self, id: UserId) -> Result<()> {
130        let query = "DELETE FROM users WHERE id = $1;";
131        sqlx::query(query)
132            .bind(id.0)
133            .execute(&self.0)
134            .await
135            .map(drop)
136    }
137
138    // access tokens
139
140    pub async fn create_access_token_hash(
141        &self,
142        user_id: UserId,
143        access_token_hash: String,
144    ) -> Result<()> {
145        let query = "
146            INSERT INTO access_tokens (user_id, hash)
147            VALUES ($1, $2)
148        ";
149        sqlx::query(query)
150            .bind(user_id.0 as i32)
151            .bind(access_token_hash)
152            .execute(&self.0)
153            .await
154            .map(drop)
155    }
156
157    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
158        let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
159        sqlx::query_scalar::<_, String>(query)
160            .bind(user_id.0 as i32)
161            .fetch_all(&self.0)
162            .await
163    }
164
165    // orgs
166
167    #[cfg(test)]
168    pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
169        let query = "
170            INSERT INTO orgs (name, slug)
171            VALUES ($1, $2)
172            RETURNING id
173        ";
174        sqlx::query_scalar(query)
175            .bind(name)
176            .bind(slug)
177            .fetch_one(&self.0)
178            .await
179            .map(OrgId)
180    }
181
182    #[cfg(test)]
183    pub async fn add_org_member(&self, org_id: OrgId, user_id: UserId) -> Result<()> {
184        let query = "
185            INSERT INTO org_memberships (org_id, user_id)
186            VALUES ($1, $2)
187        ";
188        sqlx::query(query)
189            .bind(org_id.0)
190            .bind(user_id.0)
191            .execute(&self.0)
192            .await
193            .map(drop)
194    }
195
196    // channels
197
198    #[cfg(test)]
199    pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
200        let query = "
201            INSERT INTO channels (owner_id, owner_is_user, name)
202            VALUES ($1, false, $2)
203            RETURNING id
204        ";
205        sqlx::query_scalar(query)
206            .bind(org_id.0)
207            .bind(name)
208            .fetch_one(&self.0)
209            .await
210            .map(ChannelId)
211    }
212
213    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
214        let query = "
215            SELECT
216                channels.id, channels.name
217            FROM
218                channel_memberships, channels
219            WHERE
220                channel_memberships.user_id = $1 AND
221                channel_memberships.channel_id = channels.id
222        ";
223        sqlx::query_as(query)
224            .bind(user_id.0)
225            .fetch_all(&self.0)
226            .await
227    }
228
229    pub async fn can_user_access_channel(
230        &self,
231        user_id: UserId,
232        channel_id: ChannelId,
233    ) -> Result<bool> {
234        let query = "
235            SELECT id
236            FROM channel_memberships
237            WHERE user_id = $1 AND channel_id = $2
238            LIMIT 1
239        ";
240        sqlx::query_scalar::<_, i32>(query)
241            .bind(user_id.0)
242            .bind(channel_id.0)
243            .fetch_optional(&self.0)
244            .await
245            .map(|e| e.is_some())
246    }
247
248    #[cfg(test)]
249    pub async fn add_channel_member(
250        &self,
251        channel_id: ChannelId,
252        user_id: UserId,
253        is_admin: bool,
254    ) -> Result<()> {
255        let query = "
256            INSERT INTO channel_memberships (channel_id, user_id, admin)
257            VALUES ($1, $2, $3)
258        ";
259        sqlx::query(query)
260            .bind(channel_id.0)
261            .bind(user_id.0)
262            .bind(is_admin)
263            .execute(&self.0)
264            .await
265            .map(drop)
266    }
267
268    // messages
269
270    pub async fn create_channel_message(
271        &self,
272        channel_id: ChannelId,
273        sender_id: UserId,
274        body: &str,
275    ) -> Result<MessageId> {
276        let query = "
277            INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
278            VALUES ($1, $2, $3, NOW()::timestamp)
279            RETURNING id
280        ";
281        sqlx::query_scalar(query)
282            .bind(channel_id.0)
283            .bind(sender_id.0)
284            .bind(body)
285            .fetch_one(&self.0)
286            .await
287            .map(MessageId)
288    }
289
290    pub async fn get_recent_channel_messages(
291        &self,
292        channel_id: ChannelId,
293        count: usize,
294    ) -> Result<Vec<ChannelMessage>> {
295        let query = "
296            SELECT id, sender_id, body, sent_at
297            FROM channel_messages
298            WHERE channel_id = $1
299            LIMIT $2
300        ";
301        sqlx::query_as(query)
302            .bind(channel_id.0)
303            .bind(count as i64)
304            .fetch_all(&self.0)
305            .await
306    }
307}
308
309impl std::ops::Deref for Db {
310    type Target = sqlx::PgPool;
311
312    fn deref(&self) -> &Self::Target {
313        &self.0
314    }
315}
316
317impl Channel {
318    pub fn id(&self) -> ChannelId {
319        ChannelId(self.id)
320    }
321}
322
323impl User {
324    pub fn id(&self) -> UserId {
325        UserId(self.id)
326    }
327}