1use anyhow::Context;
2use async_std::task::{block_on, yield_now};
3use serde::Serialize;
4use sqlx::{FromRow, Result};
5use time::OffsetDateTime;
6
7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9
/// Wraps the statements of a `Db` method in an `async` block and evaluates it.
///
/// `$self` must name a value with a boolean `test_mode` field. When that flag is
/// clear, the async block is simply awaited. When it is set, the expansion first
/// awaits `yield_now()` and then runs the async block to completion with
/// `block_on`. Either way the macro evaluates to the value produced by the
/// wrapped statements.
///
/// The expansion contains `.await`, so it can only be used inside an async
/// context.
macro_rules! test_support {
    ($self:ident, { $($stmts:tt)* }) => {{
        let query_future = async { $($stmts)* };
        if !$self.test_mode {
            query_future.await
        } else {
            yield_now().await;
            block_on(query_future)
        }
    }};
}
23
24pub struct Db {
25 db: sqlx::PgPool,
26 test_mode: bool,
27}
28
29#[derive(Debug, FromRow, Serialize)]
30pub struct User {
31 pub id: UserId,
32 pub github_login: String,
33 pub admin: bool,
34}
35
36#[derive(Debug, FromRow, Serialize)]
37pub struct Signup {
38 pub id: SignupId,
39 pub github_login: String,
40 pub email_address: String,
41 pub about: String,
42}
43
44#[derive(Debug, FromRow, Serialize)]
45pub struct Channel {
46 pub id: ChannelId,
47 pub name: String,
48}
49
50#[derive(Debug, FromRow)]
51pub struct ChannelMessage {
52 pub id: MessageId,
53 pub sender_id: UserId,
54 pub body: String,
55 pub sent_at: OffsetDateTime,
56}
57
58impl Db {
59 pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
60 let db = DbOptions::new()
61 .max_connections(max_connections)
62 .connect(url)
63 .await
64 .context("failed to connect to postgres database")?;
65 Ok(Self {
66 db,
67 test_mode: false,
68 })
69 }
70
71 #[cfg(test)]
72 pub fn test(url: &str, max_connections: u32) -> Self {
73 let mut db = block_on(Self::new(url, max_connections)).unwrap();
74 db.test_mode = true;
75 db
76 }
77
78 #[cfg(test)]
79 pub fn migrate(&self, path: &std::path::Path) {
80 block_on(async {
81 let migrator = sqlx::migrate::Migrator::new(path).await.unwrap();
82 migrator.run(&self.db).await.unwrap();
83 });
84 }
85
86 // signups
87
88 pub async fn create_signup(
89 &self,
90 github_login: &str,
91 email_address: &str,
92 about: &str,
93 ) -> Result<SignupId> {
94 test_support!(self, {
95 let query = "
96 INSERT INTO signups (github_login, email_address, about)
97 VALUES ($1, $2, $3)
98 RETURNING id
99 ";
100 sqlx::query_scalar(query)
101 .bind(github_login)
102 .bind(email_address)
103 .bind(about)
104 .fetch_one(&self.db)
105 .await
106 .map(SignupId)
107 })
108 }
109
110 pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
111 test_support!(self, {
112 let query = "SELECT * FROM users ORDER BY github_login ASC";
113 sqlx::query_as(query).fetch_all(&self.db).await
114 })
115 }
116
117 pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
118 test_support!(self, {
119 let query = "DELETE FROM signups WHERE id = $1";
120 sqlx::query(query)
121 .bind(id.0)
122 .execute(&self.db)
123 .await
124 .map(drop)
125 })
126 }
127
128 // users
129
130 pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
131 test_support!(self, {
132 let query = "
133 INSERT INTO users (github_login, admin)
134 VALUES ($1, $2)
135 RETURNING id
136 ";
137 sqlx::query_scalar(query)
138 .bind(github_login)
139 .bind(admin)
140 .fetch_one(&self.db)
141 .await
142 .map(UserId)
143 })
144 }
145
146 pub async fn get_all_users(&self) -> Result<Vec<User>> {
147 test_support!(self, {
148 let query = "SELECT * FROM users ORDER BY github_login ASC";
149 sqlx::query_as(query).fetch_all(&self.db).await
150 })
151 }
152
153 pub async fn get_users_by_ids(
154 &self,
155 requester_id: UserId,
156 ids: impl Iterator<Item = UserId>,
157 ) -> Result<Vec<User>> {
158 test_support!(self, {
159 // Only return users that are in a common channel with the requesting user.
160 let query = "
161 SELECT users.*
162 FROM
163 users, channel_memberships
164 WHERE
165 users.id = ANY ($1) AND
166 channel_memberships.user_id = users.id AND
167 channel_memberships.channel_id IN (
168 SELECT channel_id
169 FROM channel_memberships
170 WHERE channel_memberships.user_id = $2
171 )
172 ";
173
174 sqlx::query_as(query)
175 .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
176 .bind(requester_id)
177 .fetch_all(&self.db)
178 .await
179 })
180 }
181
182 pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
183 test_support!(self, {
184 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
185 sqlx::query_as(query)
186 .bind(github_login)
187 .fetch_optional(&self.db)
188 .await
189 })
190 }
191
192 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
193 test_support!(self, {
194 let query = "UPDATE users SET admin = $1 WHERE id = $2";
195 sqlx::query(query)
196 .bind(is_admin)
197 .bind(id.0)
198 .execute(&self.db)
199 .await
200 .map(drop)
201 })
202 }
203
204 pub async fn delete_user(&self, id: UserId) -> Result<()> {
205 test_support!(self, {
206 let query = "DELETE FROM users WHERE id = $1;";
207 sqlx::query(query)
208 .bind(id.0)
209 .execute(&self.db)
210 .await
211 .map(drop)
212 })
213 }
214
215 // access tokens
216
217 pub async fn create_access_token_hash(
218 &self,
219 user_id: UserId,
220 access_token_hash: String,
221 ) -> Result<()> {
222 test_support!(self, {
223 let query = "
224 INSERT INTO access_tokens (user_id, hash)
225 VALUES ($1, $2)
226 ";
227 sqlx::query(query)
228 .bind(user_id.0)
229 .bind(access_token_hash)
230 .execute(&self.db)
231 .await
232 .map(drop)
233 })
234 }
235
236 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
237 test_support!(self, {
238 let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
239 sqlx::query_scalar(query)
240 .bind(user_id.0)
241 .fetch_all(&self.db)
242 .await
243 })
244 }
245
246 // orgs
247
248 #[cfg(test)]
249 pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
250 test_support!(self, {
251 let query = "
252 INSERT INTO orgs (name, slug)
253 VALUES ($1, $2)
254 RETURNING id
255 ";
256 sqlx::query_scalar(query)
257 .bind(name)
258 .bind(slug)
259 .fetch_one(&self.db)
260 .await
261 .map(OrgId)
262 })
263 }
264
265 #[cfg(test)]
266 pub async fn add_org_member(
267 &self,
268 org_id: OrgId,
269 user_id: UserId,
270 is_admin: bool,
271 ) -> Result<()> {
272 test_support!(self, {
273 let query = "
274 INSERT INTO org_memberships (org_id, user_id, admin)
275 VALUES ($1, $2, $3)
276 ";
277 sqlx::query(query)
278 .bind(org_id.0)
279 .bind(user_id.0)
280 .bind(is_admin)
281 .execute(&self.db)
282 .await
283 .map(drop)
284 })
285 }
286
287 // channels
288
289 #[cfg(test)]
290 pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
291 test_support!(self, {
292 let query = "
293 INSERT INTO channels (owner_id, owner_is_user, name)
294 VALUES ($1, false, $2)
295 RETURNING id
296 ";
297 sqlx::query_scalar(query)
298 .bind(org_id.0)
299 .bind(name)
300 .fetch_one(&self.db)
301 .await
302 .map(ChannelId)
303 })
304 }
305
306 pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
307 test_support!(self, {
308 let query = "
309 SELECT
310 channels.id, channels.name
311 FROM
312 channel_memberships, channels
313 WHERE
314 channel_memberships.user_id = $1 AND
315 channel_memberships.channel_id = channels.id
316 ";
317 sqlx::query_as(query)
318 .bind(user_id.0)
319 .fetch_all(&self.db)
320 .await
321 })
322 }
323
324 pub async fn can_user_access_channel(
325 &self,
326 user_id: UserId,
327 channel_id: ChannelId,
328 ) -> Result<bool> {
329 test_support!(self, {
330 let query = "
331 SELECT id
332 FROM channel_memberships
333 WHERE user_id = $1 AND channel_id = $2
334 LIMIT 1
335 ";
336 sqlx::query_scalar::<_, i32>(query)
337 .bind(user_id.0)
338 .bind(channel_id.0)
339 .fetch_optional(&self.db)
340 .await
341 .map(|e| e.is_some())
342 })
343 }
344
345 #[cfg(test)]
346 pub async fn add_channel_member(
347 &self,
348 channel_id: ChannelId,
349 user_id: UserId,
350 is_admin: bool,
351 ) -> Result<()> {
352 test_support!(self, {
353 let query = "
354 INSERT INTO channel_memberships (channel_id, user_id, admin)
355 VALUES ($1, $2, $3)
356 ";
357 sqlx::query(query)
358 .bind(channel_id.0)
359 .bind(user_id.0)
360 .bind(is_admin)
361 .execute(&self.db)
362 .await
363 .map(drop)
364 })
365 }
366
367 // messages
368
369 pub async fn create_channel_message(
370 &self,
371 channel_id: ChannelId,
372 sender_id: UserId,
373 body: &str,
374 timestamp: OffsetDateTime,
375 ) -> Result<MessageId> {
376 test_support!(self, {
377 let query = "
378 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
379 VALUES ($1, $2, $3, $4)
380 RETURNING id
381 ";
382 sqlx::query_scalar(query)
383 .bind(channel_id.0)
384 .bind(sender_id.0)
385 .bind(body)
386 .bind(timestamp)
387 .fetch_one(&self.db)
388 .await
389 .map(MessageId)
390 })
391 }
392
393 pub async fn get_recent_channel_messages(
394 &self,
395 channel_id: ChannelId,
396 count: usize,
397 ) -> Result<Vec<ChannelMessage>> {
398 test_support!(self, {
399 let query = r#"
400 SELECT
401 id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
402 FROM
403 channel_messages
404 WHERE
405 channel_id = $1
406 LIMIT $2
407 "#;
408 sqlx::query_as(query)
409 .bind(channel_id.0)
410 .bind(count as i64)
411 .fetch_all(&self.db)
412 .await
413 })
414 }
415
416 #[cfg(test)]
417 pub async fn close(&self, db_name: &str) {
418 test_support!(self, {
419 let query = "
420 SELECT pg_terminate_backend(pg_stat_activity.pid)
421 FROM pg_stat_activity
422 WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
423 ";
424 sqlx::query(query)
425 .bind(db_name)
426 .execute(&self.db)
427 .await
428 .unwrap();
429 self.db.close().await;
430 })
431 }
432}
433
434macro_rules! id_type {
435 ($name:ident) => {
436 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, sqlx::Type, Serialize)]
437 #[sqlx(transparent)]
438 #[serde(transparent)]
439 pub struct $name(pub i32);
440
441 impl $name {
442 #[allow(unused)]
443 pub fn from_proto(value: u64) -> Self {
444 Self(value as i32)
445 }
446
447 #[allow(unused)]
448 pub fn to_proto(&self) -> u64 {
449 self.0 as u64
450 }
451 }
452 };
453}
454
455id_type!(UserId);
456id_type!(OrgId);
457id_type!(ChannelId);
458id_type!(SignupId);
459id_type!(MessageId);