1use anyhow::Context;
2use async_std::task::{block_on, yield_now};
3use serde::Serialize;
4use sqlx::{FromRow, Result};
5use time::OffsetDateTime;
6
7pub use async_sqlx_session::PostgresSessionStore as SessionStore;
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9
/// Wraps a block of database code in an `async` block and runs it.
///
/// `$self` names a value with a boolean `test_mode` field (in this file, the `&Db`
/// receiver of each query method).
/// - When `test_mode` is false, the block is simply awaited.
/// - When `test_mode` is true, the macro first yields once via `yield_now()` and then
///   drives the block to completion with `block_on` instead of awaiting it.
///   NOTE(review): presumably this is so queries complete outside the test scheduler; confirm.
macro_rules! test_support {
    ($self:ident, { $($body:tt)* }) => {{
        let operation = async {
            $($body)*
        };
        if !$self.test_mode {
            operation.await
        } else {
            yield_now().await;
            block_on(operation)
        }
    }};
}
23
24#[derive(Clone)]
25pub struct Db {
26 pool: sqlx::PgPool,
27 test_mode: bool,
28}
29
30impl Db {
31 pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
32 let pool = DbOptions::new()
33 .max_connections(max_connections)
34 .connect(url)
35 .await
36 .context("failed to connect to postgres database")?;
37 Ok(Self {
38 pool,
39 test_mode: false,
40 })
41 }
42
43 // signups
44
45 pub async fn create_signup(
46 &self,
47 github_login: &str,
48 email_address: &str,
49 about: &str,
50 ) -> Result<SignupId> {
51 test_support!(self, {
52 let query = "
53 INSERT INTO signups (github_login, email_address, about)
54 VALUES ($1, $2, $3)
55 RETURNING id
56 ";
57 sqlx::query_scalar(query)
58 .bind(github_login)
59 .bind(email_address)
60 .bind(about)
61 .fetch_one(&self.pool)
62 .await
63 .map(SignupId)
64 })
65 }
66
67 pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
68 test_support!(self, {
69 let query = "SELECT * FROM signups ORDER BY github_login ASC";
70 sqlx::query_as(query).fetch_all(&self.pool).await
71 })
72 }
73
74 pub async fn delete_signup(&self, id: SignupId) -> Result<()> {
75 test_support!(self, {
76 let query = "DELETE FROM signups WHERE id = $1";
77 sqlx::query(query)
78 .bind(id.0)
79 .execute(&self.pool)
80 .await
81 .map(drop)
82 })
83 }
84
85 // users
86
87 #[allow(unused)] // Help rust-analyzer
88 #[cfg(any(test, feature = "seed-support"))]
89 pub async fn get_user(&self, github_login: &str) -> Result<Option<UserId>> {
90 test_support!(self, {
91 let query = "
92 SELECT id
93 FROM users
94 WHERE github_login = $1
95 ";
96 sqlx::query_scalar(query)
97 .bind(github_login)
98 .fetch_optional(&self.pool)
99 .await
100 })
101 }
102
103 pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
104 test_support!(self, {
105 let query = "
106 INSERT INTO users (github_login, admin)
107 VALUES ($1, $2)
108 RETURNING id
109 ";
110 sqlx::query_scalar(query)
111 .bind(github_login)
112 .bind(admin)
113 .fetch_one(&self.pool)
114 .await
115 .map(UserId)
116 })
117 }
118
119 pub async fn get_all_users(&self) -> Result<Vec<User>> {
120 test_support!(self, {
121 let query = "SELECT * FROM users ORDER BY github_login ASC";
122 sqlx::query_as(query).fetch_all(&self.pool).await
123 })
124 }
125
126 pub async fn get_users_by_ids(
127 &self,
128 requester_id: UserId,
129 ids: impl Iterator<Item = UserId>,
130 ) -> Result<Vec<User>> {
131 test_support!(self, {
132 // Only return users that are in a common channel with the requesting user.
133 let query = "
134 SELECT users.*
135 FROM
136 users LEFT JOIN channel_memberships
137 ON
138 channel_memberships.user_id = users.id
139 WHERE
140 users.id = $2 OR
141 (
142 users.id = ANY ($1) AND
143 channel_memberships.channel_id IN (
144 SELECT channel_id
145 FROM channel_memberships
146 WHERE channel_memberships.user_id = $2
147 )
148 )
149 ";
150
151 sqlx::query_as(query)
152 .bind(&ids.map(|id| id.0).collect::<Vec<_>>())
153 .bind(requester_id)
154 .fetch_all(&self.pool)
155 .await
156 })
157 }
158
159 pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
160 test_support!(self, {
161 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
162 sqlx::query_as(query)
163 .bind(github_login)
164 .fetch_optional(&self.pool)
165 .await
166 })
167 }
168
169 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
170 test_support!(self, {
171 let query = "UPDATE users SET admin = $1 WHERE id = $2";
172 sqlx::query(query)
173 .bind(is_admin)
174 .bind(id.0)
175 .execute(&self.pool)
176 .await
177 .map(drop)
178 })
179 }
180
181 pub async fn delete_user(&self, id: UserId) -> Result<()> {
182 test_support!(self, {
183 let query = "DELETE FROM users WHERE id = $1;";
184 sqlx::query(query)
185 .bind(id.0)
186 .execute(&self.pool)
187 .await
188 .map(drop)
189 })
190 }
191
192 // access tokens
193
194 pub async fn create_access_token_hash(
195 &self,
196 user_id: UserId,
197 access_token_hash: String,
198 ) -> Result<()> {
199 test_support!(self, {
200 let query = "
201 INSERT INTO access_tokens (user_id, hash)
202 VALUES ($1, $2)
203 ";
204 sqlx::query(query)
205 .bind(user_id.0)
206 .bind(access_token_hash)
207 .execute(&self.pool)
208 .await
209 .map(drop)
210 })
211 }
212
213 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
214 test_support!(self, {
215 let query = "SELECT hash FROM access_tokens WHERE user_id = $1";
216 sqlx::query_scalar(query)
217 .bind(user_id.0)
218 .fetch_all(&self.pool)
219 .await
220 })
221 }
222
223 // orgs
224
225 #[allow(unused)] // Help rust-analyzer
226 #[cfg(any(test, feature = "seed-support"))]
227 pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
228 test_support!(self, {
229 let query = "
230 SELECT *
231 FROM orgs
232 WHERE slug = $1
233 ";
234 sqlx::query_as(query)
235 .bind(slug)
236 .fetch_optional(&self.pool)
237 .await
238 })
239 }
240
241 #[cfg(any(test, feature = "seed-support"))]
242 pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
243 test_support!(self, {
244 let query = "
245 INSERT INTO orgs (name, slug)
246 VALUES ($1, $2)
247 RETURNING id
248 ";
249 sqlx::query_scalar(query)
250 .bind(name)
251 .bind(slug)
252 .fetch_one(&self.pool)
253 .await
254 .map(OrgId)
255 })
256 }
257
258 #[cfg(any(test, feature = "seed-support"))]
259 pub async fn add_org_member(
260 &self,
261 org_id: OrgId,
262 user_id: UserId,
263 is_admin: bool,
264 ) -> Result<()> {
265 test_support!(self, {
266 let query = "
267 INSERT INTO org_memberships (org_id, user_id, admin)
268 VALUES ($1, $2, $3)
269 ON CONFLICT DO NOTHING
270 ";
271 sqlx::query(query)
272 .bind(org_id.0)
273 .bind(user_id.0)
274 .bind(is_admin)
275 .execute(&self.pool)
276 .await
277 .map(drop)
278 })
279 }
280
281 // channels
282
283 #[cfg(any(test, feature = "seed-support"))]
284 pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
285 test_support!(self, {
286 let query = "
287 INSERT INTO channels (owner_id, owner_is_user, name)
288 VALUES ($1, false, $2)
289 RETURNING id
290 ";
291 sqlx::query_scalar(query)
292 .bind(org_id.0)
293 .bind(name)
294 .fetch_one(&self.pool)
295 .await
296 .map(ChannelId)
297 })
298 }
299
300 #[allow(unused)] // Help rust-analyzer
301 #[cfg(any(test, feature = "seed-support"))]
302 pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
303 test_support!(self, {
304 let query = "
305 SELECT *
306 FROM channels
307 WHERE
308 channels.owner_is_user = false AND
309 channels.owner_id = $1
310 ";
311 sqlx::query_as(query)
312 .bind(org_id.0)
313 .fetch_all(&self.pool)
314 .await
315 })
316 }
317
318 pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
319 test_support!(self, {
320 let query = "
321 SELECT
322 channels.id, channels.name
323 FROM
324 channel_memberships, channels
325 WHERE
326 channel_memberships.user_id = $1 AND
327 channel_memberships.channel_id = channels.id
328 ";
329 sqlx::query_as(query)
330 .bind(user_id.0)
331 .fetch_all(&self.pool)
332 .await
333 })
334 }
335
336 pub async fn can_user_access_channel(
337 &self,
338 user_id: UserId,
339 channel_id: ChannelId,
340 ) -> Result<bool> {
341 test_support!(self, {
342 let query = "
343 SELECT id
344 FROM channel_memberships
345 WHERE user_id = $1 AND channel_id = $2
346 LIMIT 1
347 ";
348 sqlx::query_scalar::<_, i32>(query)
349 .bind(user_id.0)
350 .bind(channel_id.0)
351 .fetch_optional(&self.pool)
352 .await
353 .map(|e| e.is_some())
354 })
355 }
356
357 #[cfg(any(test, feature = "seed-support"))]
358 pub async fn add_channel_member(
359 &self,
360 channel_id: ChannelId,
361 user_id: UserId,
362 is_admin: bool,
363 ) -> Result<()> {
364 test_support!(self, {
365 let query = "
366 INSERT INTO channel_memberships (channel_id, user_id, admin)
367 VALUES ($1, $2, $3)
368 ON CONFLICT DO NOTHING
369 ";
370 sqlx::query(query)
371 .bind(channel_id.0)
372 .bind(user_id.0)
373 .bind(is_admin)
374 .execute(&self.pool)
375 .await
376 .map(drop)
377 })
378 }
379
380 // messages
381
382 pub async fn create_channel_message(
383 &self,
384 channel_id: ChannelId,
385 sender_id: UserId,
386 body: &str,
387 timestamp: OffsetDateTime,
388 ) -> Result<MessageId> {
389 test_support!(self, {
390 let query = "
391 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at)
392 VALUES ($1, $2, $3, $4)
393 RETURNING id
394 ";
395 sqlx::query_scalar(query)
396 .bind(channel_id.0)
397 .bind(sender_id.0)
398 .bind(body)
399 .bind(timestamp)
400 .fetch_one(&self.pool)
401 .await
402 .map(MessageId)
403 })
404 }
405
406 pub async fn get_channel_messages(
407 &self,
408 channel_id: ChannelId,
409 count: usize,
410 before_id: Option<MessageId>,
411 ) -> Result<Vec<ChannelMessage>> {
412 test_support!(self, {
413 let query = r#"
414 SELECT * FROM (
415 SELECT
416 id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at
417 FROM
418 channel_messages
419 WHERE
420 channel_id = $1 AND
421 id < $2
422 ORDER BY id DESC
423 LIMIT $3
424 ) as recent_messages
425 ORDER BY id ASC
426 "#;
427 sqlx::query_as(query)
428 .bind(channel_id.0)
429 .bind(before_id.unwrap_or(MessageId::MAX))
430 .bind(count as i64)
431 .fetch_all(&self.pool)
432 .await
433 })
434 }
435}
436
/// Declares `$name`, a `Copy` newtype around an `i32` database id.
///
/// The type is `transparent` for both sqlx and serde, so it is encoded, decoded and
/// serialized exactly like the wrapped `i32`.
///
/// `from_proto`/`to_proto` convert to and from `u64` with plain `as` casts, with no range
/// checking. NOTE(review): the `u64` side is presumably the RPC protocol's id type; confirm.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Serialize, sqlx::Type)]
        #[serde(transparent)]
        #[sqlx(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The id wrapping `i32::MAX`, the largest representable value.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from a `u64` via an `as` cast (no range check).
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Returns this id as a `u64` via an `as` cast (no range check).
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let raw = self.0;
                raw as u64
            }
        }
    };
}
460
461id_type!(UserId);
462#[derive(Debug, FromRow, Serialize, PartialEq)]
463pub struct User {
464 pub id: UserId,
465 pub github_login: String,
466 pub admin: bool,
467}
468
469id_type!(OrgId);
470#[derive(FromRow)]
471pub struct Org {
472 pub id: OrgId,
473 pub name: String,
474 pub slug: String,
475}
476
477id_type!(SignupId);
478#[derive(Debug, FromRow, Serialize)]
479pub struct Signup {
480 pub id: SignupId,
481 pub github_login: String,
482 pub email_address: String,
483 pub about: String,
484}
485
486id_type!(ChannelId);
487#[derive(Debug, FromRow, Serialize)]
488pub struct Channel {
489 pub id: ChannelId,
490 pub name: String,
491}
492
493id_type!(MessageId);
494#[derive(Debug, FromRow)]
495pub struct ChannelMessage {
496 pub id: MessageId,
497 pub sender_id: UserId,
498 pub body: String,
499 pub sent_at: OffsetDateTime,
500}
501
#[cfg(test)]
pub mod tests {
    use super::*;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::path::Path;

    /// A uniquely named, fully migrated database on the local Postgres server.
    ///
    /// The database is dropped when this value is dropped.
    pub struct TestDb {
        pub db: Db,
        pub name: String,
        pub url: String,
    }

    impl TestDb {
        /// Creates a database named `zed-test-<random u128>` on `localhost` and runs the
        /// crate's `migrations` directory against it. The contained `Db` has `test_mode`
        /// enabled.
        ///
        /// # Panics
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub fn new() -> Self {
            // Enable tests to run in parallel by serializing the creation of each test database.
            lazy_static::lazy_static! {
                static ref DB_CREATION: std::sync::Mutex<()> = std::sync::Mutex::new(());
            }

            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                {
                    // The guard (inside the returned `Result`) is held until this scope ends.
                    let _lock = DB_CREATION.lock();
                    Postgres::create_database(&url)
                        .await
                        .expect("failed to create test db");
                }
                let mut db = Db::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });

            Self { db, name, url }
        }

        /// Borrows the database handle.
        pub fn db(&self) -> &Db {
            &self.db
        }
    }

    impl Drop for TestDb {
        fn drop(&mut self) {
            block_on(async {
                // Terminate every other session connected to this test database, because
                // DROP DATABASE fails while other connections are open. The name must be a
                // real bind parameter (`$1`). A quoted `'{}'` is a string literal: it matches
                // no database and leaves the bound value without a placeholder, so Postgres
                // rejects the query.
                let query = "
                    SELECT pg_terminate_backend(pg_stat_activity.pid)
                    FROM pg_stat_activity
                    WHERE pg_stat_activity.datname = $1 AND pid <> pg_backend_pid();
                ";
                sqlx::query(query)
                    .bind(&self.name)
                    .execute(&self.db.pool)
                    .await
                    .unwrap();
                self.db.pool.close().await;
                Postgres::drop_database(&self.url).await.unwrap();
            });
        }
    }

    #[gpui::test]
    async fn test_get_users_by_ids() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user_id = db.create_user("user", false).await.unwrap();
        assert_eq!(
            db.get_users_by_ids(user_id, Some(user_id).iter().copied())
                .await
                .unwrap(),
            vec![User {
                id: user_id,
                github_login: "user".to_string(),
                admin: false,
            }]
        )
    }

    #[gpui::test]
    async fn test_recent_channel_messages() {
        let test_db = TestDb::new();
        let db = test_db.db();
        let user = db.create_user("user", false).await.unwrap();
        let org = db.create_org("org", "org").await.unwrap();
        let channel = db.create_org_channel(org, "channel").await.unwrap();
        for i in 0..10 {
            db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc())
                .await
                .unwrap();
        }

        let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
        assert_eq!(
            messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["5", "6", "7", "8", "9"]
        );

        let prev_messages = db
            .get_channel_messages(channel, 4, Some(messages[0].id))
            .await
            .unwrap();
        assert_eq!(
            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
            ["1", "2", "3", "4"]
        );
    }
}