1use anyhow::Context;
2use anyhow::Result;
3use async_trait::async_trait;
4use serde::Serialize;
5pub use sqlx::postgres::PgPoolOptions as DbOptions;
6use sqlx::{types::Uuid, FromRow};
7use time::OffsetDateTime;
8
9#[async_trait]
10pub trait Db: Send + Sync {
11 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
12 async fn get_all_users(&self) -> Result<Vec<User>>;
13 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
14 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
15 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
16 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
17 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
18 async fn destroy_user(&self, id: UserId) -> Result<()>;
19 async fn create_access_token_hash(
20 &self,
21 user_id: UserId,
22 access_token_hash: &str,
23 max_access_token_count: usize,
24 ) -> Result<()>;
25 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
26 #[cfg(any(test, feature = "seed-support"))]
27 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
28 #[cfg(any(test, feature = "seed-support"))]
29 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
30 #[cfg(any(test, feature = "seed-support"))]
31 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
32 #[cfg(any(test, feature = "seed-support"))]
33 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
34 #[cfg(any(test, feature = "seed-support"))]
35 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
36 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
37 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
38 -> Result<bool>;
39 #[cfg(any(test, feature = "seed-support"))]
40 async fn add_channel_member(
41 &self,
42 channel_id: ChannelId,
43 user_id: UserId,
44 is_admin: bool,
45 ) -> Result<()>;
46 async fn create_channel_message(
47 &self,
48 channel_id: ChannelId,
49 sender_id: UserId,
50 body: &str,
51 timestamp: OffsetDateTime,
52 nonce: u128,
53 ) -> Result<MessageId>;
54 async fn get_channel_messages(
55 &self,
56 channel_id: ChannelId,
57 count: usize,
58 before_id: Option<MessageId>,
59 ) -> Result<Vec<ChannelMessage>>;
60 #[cfg(test)]
61 async fn teardown(&self, url: &str);
62}
63
64pub struct PostgresDb {
65 pool: sqlx::PgPool,
66}
67
68impl PostgresDb {
69 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
70 let pool = DbOptions::new()
71 .max_connections(max_connections)
72 .connect(&url)
73 .await
74 .context("failed to connect to postgres database")?;
75 Ok(Self { pool })
76 }
77}
78
79#[async_trait]
80impl Db for PostgresDb {
81 // users
82
83 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
84 let query = "
85 INSERT INTO users (github_login, admin)
86 VALUES ($1, $2)
87 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
88 RETURNING id
89 ";
90 Ok(sqlx::query_scalar(query)
91 .bind(github_login)
92 .bind(admin)
93 .fetch_one(&self.pool)
94 .await
95 .map(UserId)?)
96 }
97
98 async fn get_all_users(&self) -> Result<Vec<User>> {
99 let query = "SELECT * FROM users ORDER BY github_login ASC";
100 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
101 }
102
103 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
104 let like_string = fuzzy_like_string(name_query);
105 let query = "
106 SELECT users.*
107 FROM users
108 WHERE github_login like $1
109 ORDER BY github_login <-> $2
110 LIMIT $3
111 ";
112 Ok(sqlx::query_as(query)
113 .bind(like_string)
114 .bind(name_query)
115 .bind(limit)
116 .fetch_all(&self.pool)
117 .await?)
118 }
119
120 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
121 let users = self.get_users_by_ids(vec![id]).await?;
122 Ok(users.into_iter().next())
123 }
124
125 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
126 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
127 let query = "
128 SELECT users.*
129 FROM users
130 WHERE users.id = ANY ($1)
131 ";
132 Ok(sqlx::query_as(query)
133 .bind(&ids)
134 .fetch_all(&self.pool)
135 .await?)
136 }
137
138 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
139 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
140 Ok(sqlx::query_as(query)
141 .bind(github_login)
142 .fetch_optional(&self.pool)
143 .await?)
144 }
145
146 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
147 let query = "UPDATE users SET admin = $1 WHERE id = $2";
148 Ok(sqlx::query(query)
149 .bind(is_admin)
150 .bind(id.0)
151 .execute(&self.pool)
152 .await
153 .map(drop)?)
154 }
155
156 async fn destroy_user(&self, id: UserId) -> Result<()> {
157 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
158 sqlx::query(query)
159 .bind(id.0)
160 .execute(&self.pool)
161 .await
162 .map(drop)?;
163 let query = "DELETE FROM users WHERE id = $1;";
164 Ok(sqlx::query(query)
165 .bind(id.0)
166 .execute(&self.pool)
167 .await
168 .map(drop)?)
169 }
170
171 // access tokens
172
173 async fn create_access_token_hash(
174 &self,
175 user_id: UserId,
176 access_token_hash: &str,
177 max_access_token_count: usize,
178 ) -> Result<()> {
179 let insert_query = "
180 INSERT INTO access_tokens (user_id, hash)
181 VALUES ($1, $2);
182 ";
183 let cleanup_query = "
184 DELETE FROM access_tokens
185 WHERE id IN (
186 SELECT id from access_tokens
187 WHERE user_id = $1
188 ORDER BY id DESC
189 OFFSET $3
190 )
191 ";
192
193 let mut tx = self.pool.begin().await?;
194 sqlx::query(insert_query)
195 .bind(user_id.0)
196 .bind(access_token_hash)
197 .execute(&mut tx)
198 .await?;
199 sqlx::query(cleanup_query)
200 .bind(user_id.0)
201 .bind(access_token_hash)
202 .bind(max_access_token_count as u32)
203 .execute(&mut tx)
204 .await?;
205 Ok(tx.commit().await?)
206 }
207
208 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
209 let query = "
210 SELECT hash
211 FROM access_tokens
212 WHERE user_id = $1
213 ORDER BY id DESC
214 ";
215 Ok(sqlx::query_scalar(query)
216 .bind(user_id.0)
217 .fetch_all(&self.pool)
218 .await?)
219 }
220
221 // orgs
222
/// Fetches the org with the given slug, or `None` if there is none.
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
    let org: Option<Org> = sqlx::query_as(
        "
        SELECT *
        FROM orgs
        WHERE slug = $1
        ",
    )
    .bind(slug)
    .fetch_optional(&self.pool)
    .await?;
    Ok(org)
}
236
/// Inserts an org row and returns the id generated for it.
#[cfg(any(test, feature = "seed-support"))]
async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
    let id = sqlx::query_scalar::<_, i32>(
        "
        INSERT INTO orgs (name, slug)
        VALUES ($1, $2)
        RETURNING id
        ",
    )
    .bind(name)
    .bind(slug)
    .fetch_one(&self.pool)
    .await?;
    Ok(OrgId(id))
}
251
/// Adds `user_id` to `org_id` with the given admin flag.
///
/// A conflicting existing row is left as it is (`ON CONFLICT DO NOTHING`).
#[cfg(any(test, feature = "seed-support"))]
async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
    sqlx::query(
        "
        INSERT INTO org_memberships (org_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
        ",
    )
    .bind(org_id.0)
    .bind(user_id.0)
    .bind(is_admin)
    .execute(&self.pool)
    .await?;
    Ok(())
}
267
268 // channels
269
/// Inserts a channel owned by the org `org_id` (`owner_is_user = false`)
/// and returns the id generated for it.
#[cfg(any(test, feature = "seed-support"))]
async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
    let id = sqlx::query_scalar::<_, i32>(
        "
        INSERT INTO channels (owner_id, owner_is_user, name)
        VALUES ($1, false, $2)
        RETURNING id
        ",
    )
    .bind(org_id.0)
    .bind(name)
    .fetch_one(&self.pool)
    .await?;
    Ok(ChannelId(id))
}
284
/// Returns the channels whose owner is the org `org_id`, i.e. rows with
/// `owner_is_user = false` and a matching `owner_id`.
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
    let channels: Vec<Channel> = sqlx::query_as(
        "
        SELECT *
        FROM channels
        WHERE
            channels.owner_is_user = false AND
            channels.owner_id = $1
        ",
    )
    .bind(org_id.0)
    .fetch_all(&self.pool)
    .await?;
    Ok(channels)
}
300
301 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
302 let query = "
303 SELECT
304 channels.*
305 FROM
306 channel_memberships, channels
307 WHERE
308 channel_memberships.user_id = $1 AND
309 channel_memberships.channel_id = channels.id
310 ";
311 Ok(sqlx::query_as(query)
312 .bind(user_id.0)
313 .fetch_all(&self.pool)
314 .await?)
315 }
316
317 async fn can_user_access_channel(
318 &self,
319 user_id: UserId,
320 channel_id: ChannelId,
321 ) -> Result<bool> {
322 let query = "
323 SELECT id
324 FROM channel_memberships
325 WHERE user_id = $1 AND channel_id = $2
326 LIMIT 1
327 ";
328 Ok(sqlx::query_scalar::<_, i32>(query)
329 .bind(user_id.0)
330 .bind(channel_id.0)
331 .fetch_optional(&self.pool)
332 .await
333 .map(|e| e.is_some())?)
334 }
335
/// Adds `user_id` to `channel_id` with the given admin flag.
///
/// A conflicting existing row is left as it is (`ON CONFLICT DO NOTHING`).
#[cfg(any(test, feature = "seed-support"))]
async fn add_channel_member(
    &self,
    channel_id: ChannelId,
    user_id: UserId,
    is_admin: bool,
) -> Result<()> {
    sqlx::query(
        "
        INSERT INTO channel_memberships (channel_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
        ",
    )
    .bind(channel_id.0)
    .bind(user_id.0)
    .bind(is_admin)
    .execute(&self.pool)
    .await?;
    Ok(())
}
356
357 // messages
358
359 async fn create_channel_message(
360 &self,
361 channel_id: ChannelId,
362 sender_id: UserId,
363 body: &str,
364 timestamp: OffsetDateTime,
365 nonce: u128,
366 ) -> Result<MessageId> {
367 let query = "
368 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
369 VALUES ($1, $2, $3, $4, $5)
370 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
371 RETURNING id
372 ";
373 Ok(sqlx::query_scalar(query)
374 .bind(channel_id.0)
375 .bind(sender_id.0)
376 .bind(body)
377 .bind(timestamp)
378 .bind(Uuid::from_u128(nonce))
379 .fetch_one(&self.pool)
380 .await
381 .map(MessageId)?)
382 }
383
384 async fn get_channel_messages(
385 &self,
386 channel_id: ChannelId,
387 count: usize,
388 before_id: Option<MessageId>,
389 ) -> Result<Vec<ChannelMessage>> {
390 let query = r#"
391 SELECT * FROM (
392 SELECT
393 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
394 FROM
395 channel_messages
396 WHERE
397 channel_id = $1 AND
398 id < $2
399 ORDER BY id DESC
400 LIMIT $3
401 ) as recent_messages
402 ORDER BY id ASC
403 "#;
404 Ok(sqlx::query_as(query)
405 .bind(channel_id.0)
406 .bind(before_id.unwrap_or(MessageId::MAX))
407 .bind(count as i64)
408 .fetch_all(&self.pool)
409 .await?)
410 }
411
/// Test-only cleanup: terminates every other backend connected to the
/// current database, closes the pool, then drops the database at `url`.
///
/// Failures of the terminate query and of the drop are passed to
/// `log_err` and otherwise ignored.
#[cfg(test)]
async fn teardown(&self, url: &str) {
    use sqlx::migrate::MigrateDatabase;
    use util::ResultExt;

    let terminate_other_backends = "
        SELECT pg_terminate_backend(pg_stat_activity.pid)
        FROM pg_stat_activity
        WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
    ";
    sqlx::query(terminate_other_backends)
        .execute(&self.pool)
        .await
        .log_err();
    self.pool.close().await;
    sqlx::Postgres::drop_database(url).await.log_err();
}
427}
428
/// Declares a public newtype around `i32` for use as a typed row id.
///
/// The generated type is transparent for both sqlx and serde, offers
/// conversions to and from a `u64` wire value, and displays as the wrapped
/// integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest id this type can represent.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from a `u64`; the `as` cast keeps only the low 32 bits.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                Self(raw)
            }

            /// Converts the id to a `u64` with an `as` cast (negative ids
            /// sign-extend).
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let Self(raw) = *self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
460
461id_type!(UserId);
462#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
463pub struct User {
464 pub id: UserId,
465 pub github_login: String,
466 pub admin: bool,
467}
468
469id_type!(OrgId);
470#[derive(FromRow)]
471pub struct Org {
472 pub id: OrgId,
473 pub name: String,
474 pub slug: String,
475}
476
477id_type!(ChannelId);
478#[derive(Clone, Debug, FromRow, Serialize)]
479pub struct Channel {
480 pub id: ChannelId,
481 pub name: String,
482 pub owner_id: i32,
483 pub owner_is_user: bool,
484}
485
486id_type!(MessageId);
487#[derive(Clone, Debug, FromRow)]
488pub struct ChannelMessage {
489 pub id: MessageId,
490 pub channel_id: ChannelId,
491 pub sender_id: UserId,
492 pub body: String,
493 pub sent_at: OffsetDateTime,
494 pub nonce: Uuid,
495}
496
/// Builds a SQL `LIKE` pattern that matches any text containing the
/// alphanumeric characters of `string` in the same order.
///
/// Every alphanumeric character is preceded by `%`, all other characters
/// are dropped, and a final `%` is appended, so `"x y"` becomes `"%x%y%"`
/// and the empty string becomes `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    // At most one '%' per input byte, plus the trailing '%'.
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    string
        .chars()
        .filter(|ch| ch.is_alphanumeric())
        .for_each(|ch| {
            pattern.push('%');
            pattern.push(ch);
        });
    pattern.push('%');
    pattern
}
508
509#[cfg(test)]
510pub mod tests {
511 use super::*;
512 use anyhow::anyhow;
513 use collections::BTreeMap;
514 use gpui::executor::Background;
515 use lazy_static::lazy_static;
516 use parking_lot::Mutex;
517 use rand::prelude::*;
518 use sqlx::{
519 migrate::{MigrateDatabase, Migrator},
520 Postgres,
521 };
522 use std::{path::Path, sync::Arc};
523 use util::post_inc;
524
525 #[tokio::test(flavor = "multi_thread")]
526 async fn test_get_users_by_ids() {
527 for test_db in [
528 TestDb::postgres().await,
529 TestDb::fake(Arc::new(gpui::executor::Background::new())),
530 ] {
531 let db = test_db.db();
532
533 let user = db.create_user("user", false).await.unwrap();
534 let friend1 = db.create_user("friend-1", false).await.unwrap();
535 let friend2 = db.create_user("friend-2", false).await.unwrap();
536 let friend3 = db.create_user("friend-3", false).await.unwrap();
537
538 assert_eq!(
539 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
540 .await
541 .unwrap(),
542 vec![
543 User {
544 id: user,
545 github_login: "user".to_string(),
546 admin: false,
547 },
548 User {
549 id: friend1,
550 github_login: "friend-1".to_string(),
551 admin: false,
552 },
553 User {
554 id: friend2,
555 github_login: "friend-2".to_string(),
556 admin: false,
557 },
558 User {
559 id: friend3,
560 github_login: "friend-3".to_string(),
561 admin: false,
562 }
563 ]
564 );
565 }
566 }
567
568 #[tokio::test(flavor = "multi_thread")]
569 async fn test_recent_channel_messages() {
570 for test_db in [
571 TestDb::postgres().await,
572 TestDb::fake(Arc::new(gpui::executor::Background::new())),
573 ] {
574 let db = test_db.db();
575 let user = db.create_user("user", false).await.unwrap();
576 let org = db.create_org("org", "org").await.unwrap();
577 let channel = db.create_org_channel(org, "channel").await.unwrap();
578 for i in 0..10 {
579 db.create_channel_message(
580 channel,
581 user,
582 &i.to_string(),
583 OffsetDateTime::now_utc(),
584 i,
585 )
586 .await
587 .unwrap();
588 }
589
590 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
591 assert_eq!(
592 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
593 ["5", "6", "7", "8", "9"]
594 );
595
596 let prev_messages = db
597 .get_channel_messages(channel, 4, Some(messages[0].id))
598 .await
599 .unwrap();
600 assert_eq!(
601 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
602 ["1", "2", "3", "4"]
603 );
604 }
605 }
606
607 #[tokio::test(flavor = "multi_thread")]
608 async fn test_channel_message_nonces() {
609 for test_db in [
610 TestDb::postgres().await,
611 TestDb::fake(Arc::new(gpui::executor::Background::new())),
612 ] {
613 let db = test_db.db();
614 let user = db.create_user("user", false).await.unwrap();
615 let org = db.create_org("org", "org").await.unwrap();
616 let channel = db.create_org_channel(org, "channel").await.unwrap();
617
618 let msg1_id = db
619 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
620 .await
621 .unwrap();
622 let msg2_id = db
623 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
624 .await
625 .unwrap();
626 let msg3_id = db
627 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
628 .await
629 .unwrap();
630 let msg4_id = db
631 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
632 .await
633 .unwrap();
634
635 assert_ne!(msg1_id, msg2_id);
636 assert_eq!(msg1_id, msg3_id);
637 assert_eq!(msg2_id, msg4_id);
638 }
639 }
640
641 #[tokio::test(flavor = "multi_thread")]
642 async fn test_create_access_tokens() {
643 let test_db = TestDb::postgres().await;
644 let db = test_db.db();
645 let user = db.create_user("the-user", false).await.unwrap();
646
647 db.create_access_token_hash(user, "h1", 3).await.unwrap();
648 db.create_access_token_hash(user, "h2", 3).await.unwrap();
649 assert_eq!(
650 db.get_access_token_hashes(user).await.unwrap(),
651 &["h2".to_string(), "h1".to_string()]
652 );
653
654 db.create_access_token_hash(user, "h3", 3).await.unwrap();
655 assert_eq!(
656 db.get_access_token_hashes(user).await.unwrap(),
657 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
658 );
659
660 db.create_access_token_hash(user, "h4", 3).await.unwrap();
661 assert_eq!(
662 db.get_access_token_hashes(user).await.unwrap(),
663 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
664 );
665
666 db.create_access_token_hash(user, "h5", 3).await.unwrap();
667 assert_eq!(
668 db.get_access_token_hashes(user).await.unwrap(),
669 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
670 );
671 }
672
/// `fuzzy_like_string` keeps alphanumeric characters and wraps them in `%`.
#[test]
fn test_fuzzy_like_string() {
    let cases = [("abcd", "%a%b%c%d%"), ("x y", "%x%y%"), (" z ", "%z%")];
    for (input, expected) in cases {
        assert_eq!(fuzzy_like_string(input), expected);
    }
}
679
680 #[tokio::test(flavor = "multi_thread")]
681 async fn test_fuzzy_search_users() {
682 let test_db = TestDb::postgres().await;
683 let db = test_db.db();
684 for github_login in [
685 "california",
686 "colorado",
687 "oregon",
688 "washington",
689 "florida",
690 "delaware",
691 "rhode-island",
692 ] {
693 db.create_user(github_login, false).await.unwrap();
694 }
695
696 assert_eq!(
697 fuzzy_search_user_names(db, "clr").await,
698 &["colorado", "california"]
699 );
700 assert_eq!(
701 fuzzy_search_user_names(db, "ro").await,
702 &["rhode-island", "colorado", "oregon"],
703 );
704
705 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
706 db.fuzzy_search_users(query, 10)
707 .await
708 .unwrap()
709 .into_iter()
710 .map(|user| user.github_login)
711 .collect::<Vec<_>>()
712 }
713 }
714
715 pub struct TestDb {
716 pub db: Option<Arc<dyn Db>>,
717 pub url: String,
718 }
719
720 impl TestDb {
721 pub async fn postgres() -> Self {
722 lazy_static! {
723 static ref LOCK: Mutex<()> = Mutex::new(());
724 }
725
726 let _guard = LOCK.lock();
727 let mut rng = StdRng::from_entropy();
728 let name = format!("zed-test-{}", rng.gen::<u128>());
729 let url = format!("postgres://postgres@localhost/{}", name);
730 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
731 Postgres::create_database(&url)
732 .await
733 .expect("failed to create test db");
734 let db = PostgresDb::new(&url, 5).await.unwrap();
735 let migrator = Migrator::new(migrations_path).await.unwrap();
736 migrator.run(&db.pool).await.unwrap();
737 Self {
738 db: Some(Arc::new(db)),
739 url,
740 }
741 }
742
743 pub fn fake(background: Arc<Background>) -> Self {
744 Self {
745 db: Some(Arc::new(FakeDb::new(background))),
746 url: Default::default(),
747 }
748 }
749
750 pub fn db(&self) -> &Arc<dyn Db> {
751 self.db.as_ref().unwrap()
752 }
753 }
754
755 impl Drop for TestDb {
756 fn drop(&mut self) {
757 if let Some(db) = self.db.take() {
758 futures::executor::block_on(db.teardown(&self.url));
759 }
760 }
761 }
762
763 pub struct FakeDb {
764 background: Arc<Background>,
765 users: Mutex<BTreeMap<UserId, User>>,
766 next_user_id: Mutex<i32>,
767 orgs: Mutex<BTreeMap<OrgId, Org>>,
768 next_org_id: Mutex<i32>,
769 org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
770 channels: Mutex<BTreeMap<ChannelId, Channel>>,
771 next_channel_id: Mutex<i32>,
772 channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
773 channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
774 next_channel_message_id: Mutex<i32>,
775 }
776
777 impl FakeDb {
778 pub fn new(background: Arc<Background>) -> Self {
779 Self {
780 background,
781 users: Default::default(),
782 next_user_id: Mutex::new(1),
783 orgs: Default::default(),
784 next_org_id: Mutex::new(1),
785 org_memberships: Default::default(),
786 channels: Default::default(),
787 next_channel_id: Mutex::new(1),
788 channel_memberships: Default::default(),
789 channel_messages: Default::default(),
790 next_channel_message_id: Mutex::new(1),
791 }
792 }
793 }
794
795 #[async_trait]
796 impl Db for FakeDb {
797 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
798 self.background.simulate_random_delay().await;
799
800 let mut users = self.users.lock();
801 if let Some(user) = users
802 .values()
803 .find(|user| user.github_login == github_login)
804 {
805 Ok(user.id)
806 } else {
807 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
808 users.insert(
809 user_id,
810 User {
811 id: user_id,
812 github_login: github_login.to_string(),
813 admin,
814 },
815 );
816 Ok(user_id)
817 }
818 }
819
820 async fn get_all_users(&self) -> Result<Vec<User>> {
821 unimplemented!()
822 }
823
824 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
825 unimplemented!()
826 }
827
828 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
829 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
830 }
831
832 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
833 self.background.simulate_random_delay().await;
834 let users = self.users.lock();
835 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
836 }
837
838 async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
839 unimplemented!()
840 }
841
842 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
843 unimplemented!()
844 }
845
846 async fn destroy_user(&self, _id: UserId) -> Result<()> {
847 unimplemented!()
848 }
849
850 async fn create_access_token_hash(
851 &self,
852 _user_id: UserId,
853 _access_token_hash: &str,
854 _max_access_token_count: usize,
855 ) -> Result<()> {
856 unimplemented!()
857 }
858
859 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
860 unimplemented!()
861 }
862
863 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
864 unimplemented!()
865 }
866
867 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
868 self.background.simulate_random_delay().await;
869 let mut orgs = self.orgs.lock();
870 if orgs.values().any(|org| org.slug == slug) {
871 Err(anyhow!("org already exists"))
872 } else {
873 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
874 orgs.insert(
875 org_id,
876 Org {
877 id: org_id,
878 name: name.to_string(),
879 slug: slug.to_string(),
880 },
881 );
882 Ok(org_id)
883 }
884 }
885
886 async fn add_org_member(
887 &self,
888 org_id: OrgId,
889 user_id: UserId,
890 is_admin: bool,
891 ) -> Result<()> {
892 self.background.simulate_random_delay().await;
893 if !self.orgs.lock().contains_key(&org_id) {
894 return Err(anyhow!("org does not exist"));
895 }
896 if !self.users.lock().contains_key(&user_id) {
897 return Err(anyhow!("user does not exist"));
898 }
899
900 self.org_memberships
901 .lock()
902 .entry((org_id, user_id))
903 .or_insert(is_admin);
904 Ok(())
905 }
906
907 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
908 self.background.simulate_random_delay().await;
909 if !self.orgs.lock().contains_key(&org_id) {
910 return Err(anyhow!("org does not exist"));
911 }
912
913 let mut channels = self.channels.lock();
914 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
915 channels.insert(
916 channel_id,
917 Channel {
918 id: channel_id,
919 name: name.to_string(),
920 owner_id: org_id.0,
921 owner_is_user: false,
922 },
923 );
924 Ok(channel_id)
925 }
926
927 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
928 self.background.simulate_random_delay().await;
929 Ok(self
930 .channels
931 .lock()
932 .values()
933 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
934 .cloned()
935 .collect())
936 }
937
938 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
939 self.background.simulate_random_delay().await;
940 let channels = self.channels.lock();
941 let memberships = self.channel_memberships.lock();
942 Ok(channels
943 .values()
944 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
945 .cloned()
946 .collect())
947 }
948
949 async fn can_user_access_channel(
950 &self,
951 user_id: UserId,
952 channel_id: ChannelId,
953 ) -> Result<bool> {
954 self.background.simulate_random_delay().await;
955 Ok(self
956 .channel_memberships
957 .lock()
958 .contains_key(&(channel_id, user_id)))
959 }
960
961 async fn add_channel_member(
962 &self,
963 channel_id: ChannelId,
964 user_id: UserId,
965 is_admin: bool,
966 ) -> Result<()> {
967 self.background.simulate_random_delay().await;
968 if !self.channels.lock().contains_key(&channel_id) {
969 return Err(anyhow!("channel does not exist"));
970 }
971 if !self.users.lock().contains_key(&user_id) {
972 return Err(anyhow!("user does not exist"));
973 }
974
975 self.channel_memberships
976 .lock()
977 .entry((channel_id, user_id))
978 .or_insert(is_admin);
979 Ok(())
980 }
981
982 async fn create_channel_message(
983 &self,
984 channel_id: ChannelId,
985 sender_id: UserId,
986 body: &str,
987 timestamp: OffsetDateTime,
988 nonce: u128,
989 ) -> Result<MessageId> {
990 self.background.simulate_random_delay().await;
991 if !self.channels.lock().contains_key(&channel_id) {
992 return Err(anyhow!("channel does not exist"));
993 }
994 if !self.users.lock().contains_key(&sender_id) {
995 return Err(anyhow!("user does not exist"));
996 }
997
998 let mut messages = self.channel_messages.lock();
999 if let Some(message) = messages
1000 .values()
1001 .find(|message| message.nonce.as_u128() == nonce)
1002 {
1003 Ok(message.id)
1004 } else {
1005 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
1006 messages.insert(
1007 message_id,
1008 ChannelMessage {
1009 id: message_id,
1010 channel_id,
1011 sender_id,
1012 body: body.to_string(),
1013 sent_at: timestamp,
1014 nonce: Uuid::from_u128(nonce),
1015 },
1016 );
1017 Ok(message_id)
1018 }
1019 }
1020
1021 async fn get_channel_messages(
1022 &self,
1023 channel_id: ChannelId,
1024 count: usize,
1025 before_id: Option<MessageId>,
1026 ) -> Result<Vec<ChannelMessage>> {
1027 let mut messages = self
1028 .channel_messages
1029 .lock()
1030 .values()
1031 .rev()
1032 .filter(|message| {
1033 message.channel_id == channel_id
1034 && message.id < before_id.unwrap_or(MessageId::MAX)
1035 })
1036 .take(count)
1037 .cloned()
1038 .collect::<Vec<_>>();
1039 messages.sort_unstable_by_key(|message| message.id);
1040 Ok(messages)
1041 }
1042
1043 async fn teardown(&self, _: &str) {}
1044 }
1045}