1use anyhow::Context;
2use anyhow::Result;
3use async_trait::async_trait;
4use futures::executor::block_on;
5use serde::Serialize;
6pub use sqlx::postgres::PgPoolOptions as DbOptions;
7use sqlx::{types::Uuid, FromRow};
8use time::OffsetDateTime;
9use tokio::task::yield_now;
10
/// Evaluates `{ ... }` as an async block, choosing how to drive it from `$self.test_mode`.
///
/// * Test mode off: the block is simply awaited on the caller's executor.
/// * Test mode on: the caller yields once with `yield_now`, then runs the block to
///   completion synchronously with `block_on`.
///
/// NOTE(review): presumably the blocking path keeps sqlx I/O off the deterministic test
/// executor — confirm against the test harness.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let future = async { $($token)* };
        match $self.test_mode {
            true => {
                yield_now().await;
                block_on(future)
            }
            false => future.await,
        }
    }};
}
24
25#[async_trait]
26pub trait Db: Send + Sync {
27 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
28 async fn get_all_users(&self) -> Result<Vec<User>>;
29 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
30 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
31 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
32 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
33 async fn destroy_user(&self, id: UserId) -> Result<()>;
34 async fn create_access_token_hash(
35 &self,
36 user_id: UserId,
37 access_token_hash: &str,
38 max_access_token_count: usize,
39 ) -> Result<()>;
40 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
41 #[cfg(any(test, feature = "seed-support"))]
42 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
43 #[cfg(any(test, feature = "seed-support"))]
44 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
45 #[cfg(any(test, feature = "seed-support"))]
46 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
47 #[cfg(any(test, feature = "seed-support"))]
48 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
49 #[cfg(any(test, feature = "seed-support"))]
50 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
51 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
52 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
53 -> Result<bool>;
54 #[cfg(any(test, feature = "seed-support"))]
55 async fn add_channel_member(
56 &self,
57 channel_id: ChannelId,
58 user_id: UserId,
59 is_admin: bool,
60 ) -> Result<()>;
61 async fn create_channel_message(
62 &self,
63 channel_id: ChannelId,
64 sender_id: UserId,
65 body: &str,
66 timestamp: OffsetDateTime,
67 nonce: u128,
68 ) -> Result<MessageId>;
69 async fn get_channel_messages(
70 &self,
71 channel_id: ChannelId,
72 count: usize,
73 before_id: Option<MessageId>,
74 ) -> Result<Vec<ChannelMessage>>;
75 #[cfg(test)]
76 async fn teardown(&self, name: &str, url: &str);
77}
78
79pub struct PostgresDb {
80 pool: sqlx::PgPool,
81 test_mode: bool,
82}
83
84impl PostgresDb {
85 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
86 let pool = DbOptions::new()
87 .max_connections(max_connections)
88 .connect(url)
89 .await
90 .context("failed to connect to postgres database")?;
91 Ok(Self {
92 pool,
93 test_mode: false,
94 })
95 }
96}
97
98#[async_trait]
99impl Db for PostgresDb {
100 // users
101
102 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
103 test_support!(self, {
104 let query = "
105 INSERT INTO users (github_login, admin)
106 VALUES ($1, $2)
107 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
108 RETURNING id
109 ";
110 Ok(sqlx::query_scalar(query)
111 .bind(github_login)
112 .bind(admin)
113 .fetch_one(&self.pool)
114 .await
115 .map(UserId)?)
116 })
117 }
118
119 async fn get_all_users(&self) -> Result<Vec<User>> {
120 test_support!(self, {
121 let query = "SELECT * FROM users ORDER BY github_login ASC";
122 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
123 })
124 }
125
126 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
127 let users = self.get_users_by_ids(vec![id]).await?;
128 Ok(users.into_iter().next())
129 }
130
131 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
132 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
133 test_support!(self, {
134 let query = "
135 SELECT users.*
136 FROM users
137 WHERE users.id = ANY ($1)
138 ";
139
140 Ok(sqlx::query_as(query)
141 .bind(&ids)
142 .fetch_all(&self.pool)
143 .await?)
144 })
145 }
146
147 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
148 test_support!(self, {
149 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
150 Ok(sqlx::query_as(query)
151 .bind(github_login)
152 .fetch_optional(&self.pool)
153 .await?)
154 })
155 }
156
157 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
158 test_support!(self, {
159 let query = "UPDATE users SET admin = $1 WHERE id = $2";
160 Ok(sqlx::query(query)
161 .bind(is_admin)
162 .bind(id.0)
163 .execute(&self.pool)
164 .await
165 .map(drop)?)
166 })
167 }
168
169 async fn destroy_user(&self, id: UserId) -> Result<()> {
170 test_support!(self, {
171 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
172 sqlx::query(query)
173 .bind(id.0)
174 .execute(&self.pool)
175 .await
176 .map(drop)?;
177 let query = "DELETE FROM users WHERE id = $1;";
178 Ok(sqlx::query(query)
179 .bind(id.0)
180 .execute(&self.pool)
181 .await
182 .map(drop)?)
183 })
184 }
185
186 // access tokens
187
188 async fn create_access_token_hash(
189 &self,
190 user_id: UserId,
191 access_token_hash: &str,
192 max_access_token_count: usize,
193 ) -> Result<()> {
194 test_support!(self, {
195 let insert_query = "
196 INSERT INTO access_tokens (user_id, hash)
197 VALUES ($1, $2);
198 ";
199 let cleanup_query = "
200 DELETE FROM access_tokens
201 WHERE id IN (
202 SELECT id from access_tokens
203 WHERE user_id = $1
204 ORDER BY id DESC
205 OFFSET $3
206 )
207 ";
208
209 let mut tx = self.pool.begin().await?;
210 sqlx::query(insert_query)
211 .bind(user_id.0)
212 .bind(access_token_hash)
213 .execute(&mut tx)
214 .await?;
215 sqlx::query(cleanup_query)
216 .bind(user_id.0)
217 .bind(access_token_hash)
218 .bind(max_access_token_count as u32)
219 .execute(&mut tx)
220 .await?;
221 Ok(tx.commit().await?)
222 })
223 }
224
225 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
226 test_support!(self, {
227 let query = "
228 SELECT hash
229 FROM access_tokens
230 WHERE user_id = $1
231 ORDER BY id DESC
232 ";
233 Ok(sqlx::query_scalar(query)
234 .bind(user_id.0)
235 .fetch_all(&self.pool)
236 .await?)
237 })
238 }
239
240 // orgs
241
242 #[allow(unused)] // Help rust-analyzer
243 #[cfg(any(test, feature = "seed-support"))]
244 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
245 test_support!(self, {
246 let query = "
247 SELECT *
248 FROM orgs
249 WHERE slug = $1
250 ";
251 Ok(sqlx::query_as(query)
252 .bind(slug)
253 .fetch_optional(&self.pool)
254 .await?)
255 })
256 }
257
258 #[cfg(any(test, feature = "seed-support"))]
259 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
260 test_support!(self, {
261 let query = "
262 INSERT INTO orgs (name, slug)
263 VALUES ($1, $2)
264 RETURNING id
265 ";
266 Ok(sqlx::query_scalar(query)
267 .bind(name)
268 .bind(slug)
269 .fetch_one(&self.pool)
270 .await
271 .map(OrgId)?)
272 })
273 }
274
275 #[cfg(any(test, feature = "seed-support"))]
276 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
277 test_support!(self, {
278 let query = "
279 INSERT INTO org_memberships (org_id, user_id, admin)
280 VALUES ($1, $2, $3)
281 ON CONFLICT DO NOTHING
282 ";
283 Ok(sqlx::query(query)
284 .bind(org_id.0)
285 .bind(user_id.0)
286 .bind(is_admin)
287 .execute(&self.pool)
288 .await
289 .map(drop)?)
290 })
291 }
292
293 // channels
294
295 #[cfg(any(test, feature = "seed-support"))]
296 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
297 test_support!(self, {
298 let query = "
299 INSERT INTO channels (owner_id, owner_is_user, name)
300 VALUES ($1, false, $2)
301 RETURNING id
302 ";
303 Ok(sqlx::query_scalar(query)
304 .bind(org_id.0)
305 .bind(name)
306 .fetch_one(&self.pool)
307 .await
308 .map(ChannelId)?)
309 })
310 }
311
312 #[allow(unused)] // Help rust-analyzer
313 #[cfg(any(test, feature = "seed-support"))]
314 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
315 test_support!(self, {
316 let query = "
317 SELECT *
318 FROM channels
319 WHERE
320 channels.owner_is_user = false AND
321 channels.owner_id = $1
322 ";
323 Ok(sqlx::query_as(query)
324 .bind(org_id.0)
325 .fetch_all(&self.pool)
326 .await?)
327 })
328 }
329
330 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
331 test_support!(self, {
332 let query = "
333 SELECT
334 channels.*
335 FROM
336 channel_memberships, channels
337 WHERE
338 channel_memberships.user_id = $1 AND
339 channel_memberships.channel_id = channels.id
340 ";
341 Ok(sqlx::query_as(query)
342 .bind(user_id.0)
343 .fetch_all(&self.pool)
344 .await?)
345 })
346 }
347
348 async fn can_user_access_channel(
349 &self,
350 user_id: UserId,
351 channel_id: ChannelId,
352 ) -> Result<bool> {
353 test_support!(self, {
354 let query = "
355 SELECT id
356 FROM channel_memberships
357 WHERE user_id = $1 AND channel_id = $2
358 LIMIT 1
359 ";
360 Ok(sqlx::query_scalar::<_, i32>(query)
361 .bind(user_id.0)
362 .bind(channel_id.0)
363 .fetch_optional(&self.pool)
364 .await
365 .map(|e| e.is_some())?)
366 })
367 }
368
369 #[cfg(any(test, feature = "seed-support"))]
370 async fn add_channel_member(
371 &self,
372 channel_id: ChannelId,
373 user_id: UserId,
374 is_admin: bool,
375 ) -> Result<()> {
376 test_support!(self, {
377 let query = "
378 INSERT INTO channel_memberships (channel_id, user_id, admin)
379 VALUES ($1, $2, $3)
380 ON CONFLICT DO NOTHING
381 ";
382 Ok(sqlx::query(query)
383 .bind(channel_id.0)
384 .bind(user_id.0)
385 .bind(is_admin)
386 .execute(&self.pool)
387 .await
388 .map(drop)?)
389 })
390 }
391
392 // messages
393
394 async fn create_channel_message(
395 &self,
396 channel_id: ChannelId,
397 sender_id: UserId,
398 body: &str,
399 timestamp: OffsetDateTime,
400 nonce: u128,
401 ) -> Result<MessageId> {
402 test_support!(self, {
403 let query = "
404 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
405 VALUES ($1, $2, $3, $4, $5)
406 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
407 RETURNING id
408 ";
409 Ok(sqlx::query_scalar(query)
410 .bind(channel_id.0)
411 .bind(sender_id.0)
412 .bind(body)
413 .bind(timestamp)
414 .bind(Uuid::from_u128(nonce))
415 .fetch_one(&self.pool)
416 .await
417 .map(MessageId)?)
418 })
419 }
420
421 async fn get_channel_messages(
422 &self,
423 channel_id: ChannelId,
424 count: usize,
425 before_id: Option<MessageId>,
426 ) -> Result<Vec<ChannelMessage>> {
427 test_support!(self, {
428 let query = r#"
429 SELECT * FROM (
430 SELECT
431 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
432 FROM
433 channel_messages
434 WHERE
435 channel_id = $1 AND
436 id < $2
437 ORDER BY id DESC
438 LIMIT $3
439 ) as recent_messages
440 ORDER BY id ASC
441 "#;
442 Ok(sqlx::query_as(query)
443 .bind(channel_id.0)
444 .bind(before_id.unwrap_or(MessageId::MAX))
445 .bind(count as i64)
446 .fetch_all(&self.pool)
447 .await?)
448 })
449 }
450
451 #[cfg(test)]
452 async fn teardown(&self, name: &str, url: &str) {
453 use util::ResultExt;
454
455 test_support!(self, {
456 let query = "
457 SELECT pg_terminate_backend(pg_stat_activity.pid)
458 FROM pg_stat_activity
459 WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
460 ";
461 sqlx::query(query)
462 .bind(name)
463 .execute(&self.pool)
464 .await
465 .log_err();
466 self.pool.close().await;
467 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
468 .await
469 .log_err();
470 })
471 }
472}
473
/// Declares a `Copy` newtype around an `i32` database id.
///
/// The generated type is transparent to both sqlx and serde. It provides `MAX`
/// (`i32::MAX`) and converts to and from the `u64` representation used by `from_proto` /
/// `to_proto`.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest representable id.
            #[allow(unused)]
            pub const MAX: $name = $name(i32::MAX);

            /// Builds an id from its `u64` protocol form.
            ///
            /// NOTE(review): the `as` cast truncates, so values above `i32::MAX` wrap
            /// to unrelated ids instead of failing.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Returns the `u64` protocol form of this id.
            ///
            /// NOTE(review): a negative inner value sign-extends to a huge `u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let Self(raw) = *self;
                raw as u64
            }
        }
    };
}
499
500id_type!(UserId);
501#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
502pub struct User {
503 pub id: UserId,
504 pub github_login: String,
505 pub admin: bool,
506}
507
508id_type!(OrgId);
509#[derive(FromRow)]
510pub struct Org {
511 pub id: OrgId,
512 pub name: String,
513 pub slug: String,
514}
515
516id_type!(ChannelId);
517#[derive(Clone, Debug, FromRow, Serialize)]
518pub struct Channel {
519 pub id: ChannelId,
520 pub name: String,
521 pub owner_id: i32,
522 pub owner_is_user: bool,
523}
524
525id_type!(MessageId);
526#[derive(Clone, Debug, FromRow)]
527pub struct ChannelMessage {
528 pub id: MessageId,
529 pub channel_id: ChannelId,
530 pub sender_id: UserId,
531 pub body: String,
532 pub sent_at: OffsetDateTime,
533 pub nonce: Uuid,
534}
535
#[cfg(test)]
pub mod tests {
    use super::*;
    use anyhow::anyhow;
    use collections::BTreeMap;
    use gpui::{executor::Background, TestAppContext};
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::{path::Path, sync::Arc};
    use util::post_inc;

    // Tests that loop over `[TestDb::postgres(), TestDb::fake(..)]` run identical
    // assertions against a real Postgres database and the in-memory `FakeDb`. This keeps
    // the two implementations in agreement.

    // Fetching several ids at once returns each corresponding user.
    // NOTE(review): the expected vector is in insertion order, but the Postgres query
    // behind `get_users_by_ids` has no `ORDER BY`; this relies on Postgres' incidental
    // row order.
    #[gpui::test]
    async fn test_get_users_by_ids(cx: &mut TestAppContext) {
        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
            let db = test_db.db();

            let user = db.create_user("user", false).await.unwrap();
            let friend1 = db.create_user("friend-1", false).await.unwrap();
            let friend2 = db.create_user("friend-2", false).await.unwrap();
            let friend3 = db.create_user("friend-3", false).await.unwrap();

            assert_eq!(
                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
                    .await
                    .unwrap(),
                vec![
                    User {
                        id: user,
                        github_login: "user".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend1,
                        github_login: "friend-1".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend2,
                        github_login: "friend-2".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend3,
                        github_login: "friend-3".to_string(),
                        admin: false,
                    }
                ]
            );
        }
    }

    // Paging backwards through a channel: first the newest page, then the page that
    // ends just before the first message of the previous page. Each page is oldest
    // first.
    #[gpui::test]
    async fn test_recent_channel_messages(cx: &mut TestAppContext) {
        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
            let db = test_db.db();
            let user = db.create_user("user", false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();
            // Messages "0".."9", each with its index as a distinct nonce.
            for i in 0..10 {
                db.create_channel_message(
                    channel,
                    user,
                    &i.to_string(),
                    OffsetDateTime::now_utc(),
                    i,
                )
                .await
                .unwrap();
            }

            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
            assert_eq!(
                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["5", "6", "7", "8", "9"]
            );

            let prev_messages = db
                .get_channel_messages(channel, 4, Some(messages[0].id))
                .await
                .unwrap();
            assert_eq!(
                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["1", "2", "3", "4"]
            );
        }
    }

    // Reusing a nonce yields the id of the message originally created with it, even
    // though the body differs.
    #[gpui::test]
    async fn test_channel_message_nonces(cx: &mut TestAppContext) {
        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
            let db = test_db.db();
            let user = db.create_user("user", false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();

            let msg1_id = db
                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg2_id = db
                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();
            let msg3_id = db
                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg4_id = db
                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();

            assert_ne!(msg1_id, msg2_id);
            assert_eq!(msg1_id, msg3_id);
            assert_eq!(msg2_id, msg4_id);
        }
    }

    // With a limit of 3, each new hash evicts the oldest once three are stored; hashes
    // come back newest first. Postgres only: `FakeDb` leaves access tokens unimplemented.
    #[gpui::test]
    async fn test_create_access_tokens() {
        let test_db = TestDb::postgres();
        let db = test_db.db();
        let user = db.create_user("the-user", false).await.unwrap();

        db.create_access_token_hash(user, "h1", 3).await.unwrap();
        db.create_access_token_hash(user, "h2", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h3", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
        );

        db.create_access_token_hash(user, "h4", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
        );

        db.create_access_token_hash(user, "h5", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
        );
    }

    // A database handle for one test. Dropping it calls `Db::teardown` with `name` and
    // `url`.
    pub struct TestDb {
        // Always `Some` until `drop` takes it for teardown.
        pub db: Option<Arc<dyn Db>>,
        pub name: String,
        pub url: String,
    }

    impl TestDb {
        // Creates a randomly named Postgres database on localhost, runs the crate's
        // migrations against it, and returns a `PostgresDb` with test mode enabled.
        pub fn postgres() -> Self {
            // Serializes database creation across tests.
            // NOTE(review): presumably guards concurrent create/migrate — confirm.
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = block_on(async {
                Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let mut db = PostgresDb::new(&url, 5).await.unwrap();
                db.test_mode = true;
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });
            Self {
                db: Some(Arc::new(db)),
                name,
                url,
            }
        }

        // Wraps an in-memory `FakeDb` whose delays are simulated on `background`.
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                name: "fake".to_string(),
                url: "fake".to_string(),
            }
        }

        // Borrows the database; panics if called after teardown has taken it.
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        fn drop(&mut self) {
            if let Some(db) = self.db.take() {
                block_on(db.teardown(&self.name, &self.url));
            }
        }
    }

    // In-memory `Db` for tests. Each table is a `BTreeMap` behind its own mutex, and
    // ids are handed out sequentially starting at 1.
    pub struct FakeDb {
        background: Arc<Background>,
        users: Mutex<BTreeMap<UserId, User>>,
        next_user_id: Mutex<i32>,
        orgs: Mutex<BTreeMap<OrgId, Org>>,
        next_org_id: Mutex<i32>,
        // Maps (org, user) to the member's admin flag.
        org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        channels: Mutex<BTreeMap<ChannelId, Channel>>,
        next_channel_id: Mutex<i32>,
        // Maps (channel, user) to the member's admin flag.
        channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        next_channel_message_id: Mutex<i32>,
    }

    impl FakeDb {
        pub fn new(background: Arc<Background>) -> Self {
            Self {
                background,
                users: Default::default(),
                next_user_id: Mutex::new(1),
                orgs: Default::default(),
                next_org_id: Mutex::new(1),
                org_memberships: Default::default(),
                channels: Default::default(),
                next_channel_id: Mutex::new(1),
                channel_memberships: Default::default(),
                channel_messages: Default::default(),
                next_channel_message_id: Mutex::new(1),
            }
        }
    }

    // Most methods first await `simulate_random_delay` on the background executor.
    // NOTE(review): presumably so tests exercise varied interleavings — confirm. Methods
    // not needed by the tests are `unimplemented!()`.
    #[async_trait]
    impl Db for FakeDb {
        // Mirrors the Postgres upsert: an existing login returns its current id.
        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
            self.background.simulate_random_delay().await;

            let mut users = self.users.lock();
            if let Some(user) = users
                .values()
                .find(|user| user.github_login == github_login)
            {
                Ok(user.id)
            } else {
                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
                users.insert(
                    user_id,
                    User {
                        id: user_id,
                        github_login: github_login.to_string(),
                        admin,
                    },
                );
                Ok(user_id)
            }
        }

        async fn get_all_users(&self) -> Result<Vec<User>> {
            unimplemented!()
        }

        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
        }

        // Returns users in the order of `ids`, skipping unknown ids.
        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
            self.background.simulate_random_delay().await;
            let users = self.users.lock();
            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
        }

        async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
            unimplemented!()
        }

        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }

        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }

        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }

        // Errors with "org already exists" when the slug is already taken.
        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
            self.background.simulate_random_delay().await;
            let mut orgs = self.orgs.lock();
            if orgs.values().any(|org| org.slug == slug) {
                Err(anyhow!("org already exists"))
            } else {
                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
                orgs.insert(
                    org_id,
                    Org {
                        id: org_id,
                        name: name.to_string(),
                        slug: slug.to_string(),
                    },
                );
                Ok(org_id)
            }
        }

        // Validates both ids; an existing membership keeps its admin flag
        // (`or_insert`), like `ON CONFLICT DO NOTHING`.
        async fn add_org_member(
            &self,
            org_id: OrgId,
            user_id: UserId,
            is_admin: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.orgs.lock().contains_key(&org_id) {
                return Err(anyhow!("org does not exist"));
            }
            if !self.users.lock().contains_key(&user_id) {
                return Err(anyhow!("user does not exist"));
            }

            self.org_memberships
                .lock()
                .entry((org_id, user_id))
                .or_insert(is_admin);
            Ok(())
        }

        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
            self.background.simulate_random_delay().await;
            if !self.orgs.lock().contains_key(&org_id) {
                return Err(anyhow!("org does not exist"));
            }

            let mut channels = self.channels.lock();
            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
            channels.insert(
                channel_id,
                Channel {
                    id: channel_id,
                    name: name.to_string(),
                    owner_id: org_id.0,
                    owner_is_user: false,
                },
            );
            Ok(channel_id)
        }

        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
            self.background.simulate_random_delay().await;
            Ok(self
                .channels
                .lock()
                .values()
                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
                .cloned()
                .collect())
        }

        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
            self.background.simulate_random_delay().await;
            let channels = self.channels.lock();
            let memberships = self.channel_memberships.lock();
            Ok(channels
                .values()
                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
                .cloned()
                .collect())
        }

        async fn can_user_access_channel(
            &self,
            user_id: UserId,
            channel_id: ChannelId,
        ) -> Result<bool> {
            self.background.simulate_random_delay().await;
            Ok(self
                .channel_memberships
                .lock()
                .contains_key(&(channel_id, user_id)))
        }

        // Validates both ids; an existing membership keeps its admin flag.
        async fn add_channel_member(
            &self,
            channel_id: ChannelId,
            user_id: UserId,
            is_admin: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.channels.lock().contains_key(&channel_id) {
                return Err(anyhow!("channel does not exist"));
            }
            if !self.users.lock().contains_key(&user_id) {
                return Err(anyhow!("user does not exist"));
            }

            self.channel_memberships
                .lock()
                .entry((channel_id, user_id))
                .or_insert(is_admin);
            Ok(())
        }

        // Nonces are unique across all channels (a linear scan here), matching the
        // Postgres `ON CONFLICT (nonce)` behavior.
        async fn create_channel_message(
            &self,
            channel_id: ChannelId,
            sender_id: UserId,
            body: &str,
            timestamp: OffsetDateTime,
            nonce: u128,
        ) -> Result<MessageId> {
            self.background.simulate_random_delay().await;
            if !self.channels.lock().contains_key(&channel_id) {
                return Err(anyhow!("channel does not exist"));
            }
            if !self.users.lock().contains_key(&sender_id) {
                return Err(anyhow!("user does not exist"));
            }

            let mut messages = self.channel_messages.lock();
            if let Some(message) = messages
                .values()
                .find(|message| message.nonce.as_u128() == nonce)
            {
                Ok(message.id)
            } else {
                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
                messages.insert(
                    message_id,
                    ChannelMessage {
                        id: message_id,
                        channel_id,
                        sender_id,
                        body: body.to_string(),
                        sent_at: timestamp,
                        nonce: Uuid::from_u128(nonce),
                    },
                );
                Ok(message_id)
            }
        }

        // Walks messages newest-first (the BTreeMap is keyed by id) and takes `count`
        // matches. It then sorts them back to ascending id order.
        // NOTE(review): unlike the other fake methods, this does not await
        // `simulate_random_delay`.
        async fn get_channel_messages(
            &self,
            channel_id: ChannelId,
            count: usize,
            before_id: Option<MessageId>,
        ) -> Result<Vec<ChannelMessage>> {
            let mut messages = self
                .channel_messages
                .lock()
                .values()
                .rev()
                .filter(|message| {
                    message.channel_id == channel_id
                        && message.id < before_id.unwrap_or(MessageId::MAX)
                })
                .take(count)
                .cloned()
                .collect::<Vec<_>>();
            messages.sort_unstable_by_key(|message| message.id);
            Ok(messages)
        }

        // Nothing to release for the in-memory fake.
        async fn teardown(&self, _name: &str, _url: &str) {}
    }
}