1use anyhow::Context;
2use anyhow::Result;
3use async_trait::async_trait;
4use serde::Serialize;
5pub use sqlx::postgres::PgPoolOptions as DbOptions;
6use sqlx::{types::Uuid, FromRow};
7use time::OffsetDateTime;
8
/// Asynchronous database interface.
///
/// Implemented by [`PostgresDb`] and, under `#[cfg(test)]`, by the in-memory `tests::FakeDb`.
/// Methods marked `#[cfg(any(test, feature = "seed-support"))]` are only compiled for tests
/// or when the `seed-support` feature is enabled.
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user with the given GitHub login and returns its id. If a user with that
    /// login already exists, the existing user's id is returned (its `admin` flag is not
    /// changed).
    async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
    /// Returns every user (the Postgres implementation orders them by `github_login`).
    async fn get_all_users(&self) -> Result<Vec<User>>;
    /// Looks up one user by id; `Ok(None)` if no such user exists.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose ids appear in `ids`; ids with no matching user are skipped.
    /// NOTE(review): the Postgres query has no ORDER BY, so result order is not guaranteed.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Looks up one user by GitHub login; `Ok(None)` if no such user exists.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets the `admin` flag of the user with the given id.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Deletes the user together with their stored access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;
    /// Stores a new access-token hash for `user_id`, then prunes that user's hashes so at
    /// most `max_access_token_count` of the newest (highest id) ones remain.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns all stored access-token hashes for `user_id`, newest first.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    /// Finds the org with the given slug, if any.
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds `user_id` to `org_id`; an already-existing membership is left unchanged.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Creates a channel owned by `org_id` and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Returns the channels owned by `org_id`.
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns every channel that `user_id` is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether `user_id` has a membership in `channel_id`.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds `user_id` to `channel_id`; an already-existing membership is left unchanged.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a message and returns its id. Messages are deduplicated by `nonce`: if a
    /// message with the same nonce already exists, its id is returned instead.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` of the most recent messages in `channel_id` whose id is below
    /// `before_id` (or the latest messages when `before_id` is `None`), oldest first.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only cleanup of the database identified by `name` / `url`.
    #[cfg(test)]
    async fn teardown(&self, name: &str, url: &str);
}
62
63pub struct PostgresDb {
64 pool: sqlx::PgPool,
65}
66
67impl PostgresDb {
68 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
69 let pool = DbOptions::new()
70 .max_connections(max_connections)
71 .connect(url)
72 .await
73 .context("failed to connect to postgres database")?;
74 Ok(Self { pool })
75 }
76}
77
78#[async_trait]
79impl Db for PostgresDb {
80 // users
81
82 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
83 let query = "
84 INSERT INTO users (github_login, admin)
85 VALUES ($1, $2)
86 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
87 RETURNING id
88 ";
89 Ok(sqlx::query_scalar(query)
90 .bind(github_login)
91 .bind(admin)
92 .fetch_one(&self.pool)
93 .await
94 .map(UserId)?)
95 }
96
97 async fn get_all_users(&self) -> Result<Vec<User>> {
98 let query = "SELECT * FROM users ORDER BY github_login ASC";
99 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
100 }
101
102 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
103 let users = self.get_users_by_ids(vec![id]).await?;
104 Ok(users.into_iter().next())
105 }
106
107 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
108 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
109 let query = "
110 SELECT users.*
111 FROM users
112 WHERE users.id = ANY ($1)
113 ";
114
115 Ok(sqlx::query_as(query)
116 .bind(&ids)
117 .fetch_all(&self.pool)
118 .await?)
119 }
120
121 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
122 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
123 Ok(sqlx::query_as(query)
124 .bind(github_login)
125 .fetch_optional(&self.pool)
126 .await?)
127 }
128
129 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
130 let query = "UPDATE users SET admin = $1 WHERE id = $2";
131 Ok(sqlx::query(query)
132 .bind(is_admin)
133 .bind(id.0)
134 .execute(&self.pool)
135 .await
136 .map(drop)?)
137 }
138
139 async fn destroy_user(&self, id: UserId) -> Result<()> {
140 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
141 sqlx::query(query)
142 .bind(id.0)
143 .execute(&self.pool)
144 .await
145 .map(drop)?;
146 let query = "DELETE FROM users WHERE id = $1;";
147 Ok(sqlx::query(query)
148 .bind(id.0)
149 .execute(&self.pool)
150 .await
151 .map(drop)?)
152 }
153
154 // access tokens
155
156 async fn create_access_token_hash(
157 &self,
158 user_id: UserId,
159 access_token_hash: &str,
160 max_access_token_count: usize,
161 ) -> Result<()> {
162 let insert_query = "
163 INSERT INTO access_tokens (user_id, hash)
164 VALUES ($1, $2);
165 ";
166 let cleanup_query = "
167 DELETE FROM access_tokens
168 WHERE id IN (
169 SELECT id from access_tokens
170 WHERE user_id = $1
171 ORDER BY id DESC
172 OFFSET $3
173 )
174 ";
175
176 let mut tx = self.pool.begin().await?;
177 sqlx::query(insert_query)
178 .bind(user_id.0)
179 .bind(access_token_hash)
180 .execute(&mut tx)
181 .await?;
182 sqlx::query(cleanup_query)
183 .bind(user_id.0)
184 .bind(access_token_hash)
185 .bind(max_access_token_count as u32)
186 .execute(&mut tx)
187 .await?;
188 Ok(tx.commit().await?)
189 }
190
191 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
192 let query = "
193 SELECT hash
194 FROM access_tokens
195 WHERE user_id = $1
196 ORDER BY id DESC
197 ";
198 Ok(sqlx::query_scalar(query)
199 .bind(user_id.0)
200 .fetch_all(&self.pool)
201 .await?)
202 }
203
204 // orgs
205
206 #[allow(unused)] // Help rust-analyzer
207 #[cfg(any(test, feature = "seed-support"))]
208 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
209 let query = "
210 SELECT *
211 FROM orgs
212 WHERE slug = $1
213 ";
214 Ok(sqlx::query_as(query)
215 .bind(slug)
216 .fetch_optional(&self.pool)
217 .await?)
218 }
219
220 #[cfg(any(test, feature = "seed-support"))]
221 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
222 let query = "
223 INSERT INTO orgs (name, slug)
224 VALUES ($1, $2)
225 RETURNING id
226 ";
227 Ok(sqlx::query_scalar(query)
228 .bind(name)
229 .bind(slug)
230 .fetch_one(&self.pool)
231 .await
232 .map(OrgId)?)
233 }
234
235 #[cfg(any(test, feature = "seed-support"))]
236 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
237 let query = "
238 INSERT INTO org_memberships (org_id, user_id, admin)
239 VALUES ($1, $2, $3)
240 ON CONFLICT DO NOTHING
241 ";
242 Ok(sqlx::query(query)
243 .bind(org_id.0)
244 .bind(user_id.0)
245 .bind(is_admin)
246 .execute(&self.pool)
247 .await
248 .map(drop)?)
249 }
250
251 // channels
252
253 #[cfg(any(test, feature = "seed-support"))]
254 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
255 let query = "
256 INSERT INTO channels (owner_id, owner_is_user, name)
257 VALUES ($1, false, $2)
258 RETURNING id
259 ";
260 Ok(sqlx::query_scalar(query)
261 .bind(org_id.0)
262 .bind(name)
263 .fetch_one(&self.pool)
264 .await
265 .map(ChannelId)?)
266 }
267
268 #[allow(unused)] // Help rust-analyzer
269 #[cfg(any(test, feature = "seed-support"))]
270 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
271 let query = "
272 SELECT *
273 FROM channels
274 WHERE
275 channels.owner_is_user = false AND
276 channels.owner_id = $1
277 ";
278 Ok(sqlx::query_as(query)
279 .bind(org_id.0)
280 .fetch_all(&self.pool)
281 .await?)
282 }
283
284 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
285 let query = "
286 SELECT
287 channels.*
288 FROM
289 channel_memberships, channels
290 WHERE
291 channel_memberships.user_id = $1 AND
292 channel_memberships.channel_id = channels.id
293 ";
294 Ok(sqlx::query_as(query)
295 .bind(user_id.0)
296 .fetch_all(&self.pool)
297 .await?)
298 }
299
300 async fn can_user_access_channel(
301 &self,
302 user_id: UserId,
303 channel_id: ChannelId,
304 ) -> Result<bool> {
305 let query = "
306 SELECT id
307 FROM channel_memberships
308 WHERE user_id = $1 AND channel_id = $2
309 LIMIT 1
310 ";
311 Ok(sqlx::query_scalar::<_, i32>(query)
312 .bind(user_id.0)
313 .bind(channel_id.0)
314 .fetch_optional(&self.pool)
315 .await
316 .map(|e| e.is_some())?)
317 }
318
319 #[cfg(any(test, feature = "seed-support"))]
320 async fn add_channel_member(
321 &self,
322 channel_id: ChannelId,
323 user_id: UserId,
324 is_admin: bool,
325 ) -> Result<()> {
326 let query = "
327 INSERT INTO channel_memberships (channel_id, user_id, admin)
328 VALUES ($1, $2, $3)
329 ON CONFLICT DO NOTHING
330 ";
331 Ok(sqlx::query(query)
332 .bind(channel_id.0)
333 .bind(user_id.0)
334 .bind(is_admin)
335 .execute(&self.pool)
336 .await
337 .map(drop)?)
338 }
339
340 // messages
341
342 async fn create_channel_message(
343 &self,
344 channel_id: ChannelId,
345 sender_id: UserId,
346 body: &str,
347 timestamp: OffsetDateTime,
348 nonce: u128,
349 ) -> Result<MessageId> {
350 let query = "
351 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
352 VALUES ($1, $2, $3, $4, $5)
353 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
354 RETURNING id
355 ";
356 Ok(sqlx::query_scalar(query)
357 .bind(channel_id.0)
358 .bind(sender_id.0)
359 .bind(body)
360 .bind(timestamp)
361 .bind(Uuid::from_u128(nonce))
362 .fetch_one(&self.pool)
363 .await
364 .map(MessageId)?)
365 }
366
367 async fn get_channel_messages(
368 &self,
369 channel_id: ChannelId,
370 count: usize,
371 before_id: Option<MessageId>,
372 ) -> Result<Vec<ChannelMessage>> {
373 let query = r#"
374 SELECT * FROM (
375 SELECT
376 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
377 FROM
378 channel_messages
379 WHERE
380 channel_id = $1 AND
381 id < $2
382 ORDER BY id DESC
383 LIMIT $3
384 ) as recent_messages
385 ORDER BY id ASC
386 "#;
387 Ok(sqlx::query_as(query)
388 .bind(channel_id.0)
389 .bind(before_id.unwrap_or(MessageId::MAX))
390 .bind(count as i64)
391 .fetch_all(&self.pool)
392 .await?)
393 }
394
395 #[cfg(test)]
396 async fn teardown(&self, name: &str, url: &str) {
397 use util::ResultExt;
398
399 let query = "
400 SELECT pg_terminate_backend(pg_stat_activity.pid)
401 FROM pg_stat_activity
402 WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
403 ";
404 sqlx::query(query)
405 .bind(name)
406 .execute(&self.pool)
407 .await
408 .log_err();
409 self.pool.close().await;
410 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
411 .await
412 .log_err();
413 }
414}
415
/// Declares `$name` as a `Copy` newtype around an `i32` database id.
///
/// The generated type is transparent for both sqlx and serde (it is stored and serialized as
/// the bare integer) and gets a `MAX` constant plus unchecked conversions to/from `u64`.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
        #[derive(sqlx::Type, Serialize)]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id (e.g. an exclusive "no upper bound" when paging).
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from a `u64`, keeping only its low 32 bits (no range check).
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let low_bits = value as u32;
                $name(low_bits as i32)
            }

            /// Widens the id to `u64`; a negative id is sign-extended.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                i64::from(self.0) as u64
            }
        }
    };
}
441
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub login; `create_user` uses it to detect an already-existing user.
    pub github_login: String,
    /// Value of the `admin` column (written by `create_user` / `set_user_is_admin`).
    pub admin: bool,
}
449
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Short handle used by `find_org_by_slug`.
    pub slug: String,
}
457
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owning org when `owner_is_user` is false.
    /// NOTE(review): presumably a user id when `owner_is_user` is true — confirm.
    pub owner_id: i32,
    /// Discriminates what kind of entity `owner_id` refers to.
    pub owner_is_user: bool,
}
466
id_type!(MessageId);
/// A row of the `channel_messages` table.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; `get_channel_messages` reads it with `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Deduplication key: creating a message with an existing nonce returns the existing id.
    pub nonce: Uuid,
}
477
#[cfg(test)]
pub mod tests {
    use super::*;
    use anyhow::anyhow;
    use collections::BTreeMap;
    use gpui::{executor::Background, TestAppContext};
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::{path::Path, sync::Arc};
    use util::post_inc;

    #[gpui::test]
    async fn test_get_users_by_ids(cx: &mut TestAppContext) {
        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
            let db = test_db.db();

            let user = db.create_user("user", false).await.unwrap();
            let friend1 = db.create_user("friend-1", false).await.unwrap();
            let friend2 = db.create_user("friend-2", false).await.unwrap();
            let friend3 = db.create_user("friend-3", false).await.unwrap();

            assert_eq!(
                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
                    .await
                    .unwrap(),
                vec![
                    User {
                        id: user,
                        github_login: "user".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend1,
                        github_login: "friend-1".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend2,
                        github_login: "friend-2".to_string(),
                        admin: false,
                    },
                    User {
                        id: friend3,
                        github_login: "friend-3".to_string(),
                        admin: false,
                    }
                ]
            );
        }
    }

    #[gpui::test]
    async fn test_recent_channel_messages(cx: &mut TestAppContext) {
        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
            let db = test_db.db();
            let user = db.create_user("user", false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();
            for i in 0..10 {
                db.create_channel_message(
                    channel,
                    user,
                    &i.to_string(),
                    OffsetDateTime::now_utc(),
                    i,
                )
                .await
                .unwrap();
            }

            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
            assert_eq!(
                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["5", "6", "7", "8", "9"]
            );

            let prev_messages = db
                .get_channel_messages(channel, 4, Some(messages[0].id))
                .await
                .unwrap();
            assert_eq!(
                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["1", "2", "3", "4"]
            );
        }
    }

    #[gpui::test]
    async fn test_channel_message_nonces(cx: &mut TestAppContext) {
        for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
            let db = test_db.db();
            let user = db.create_user("user", false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();

            let msg1_id = db
                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg2_id = db
                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();
            let msg3_id = db
                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg4_id = db
                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();

            assert_ne!(msg1_id, msg2_id);
            assert_eq!(msg1_id, msg3_id);
            assert_eq!(msg2_id, msg4_id);
        }
    }

    #[gpui::test]
    async fn test_create_access_tokens() {
        let test_db = TestDb::postgres();
        let db = test_db.db();
        let user = db.create_user("the-user", false).await.unwrap();

        db.create_access_token_hash(user, "h1", 3).await.unwrap();
        db.create_access_token_hash(user, "h2", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h3", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
        );

        db.create_access_token_hash(user, "h4", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
        );

        db.create_access_token_hash(user, "h5", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
        );
    }

    /// A database handle for one test. On drop it calls [`Db::teardown`] with `name`/`url`.
    pub struct TestDb {
        pub db: Option<Arc<dyn Db>>,
        pub name: String,
        pub url: String,
    }

    impl TestDb {
        /// Creates a fresh, randomly named Postgres database on `localhost` and runs the
        /// crate's migrations against it.
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub fn postgres() -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Serializes database creation across tests running in parallel.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            let db = futures::executor::block_on(async {
                Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = PostgresDb::new(&url, 5).await.unwrap();
                let migrator = Migrator::new(migrations_path).await.unwrap();
                migrator.run(&db.pool).await.unwrap();
                db
            });
            Self {
                db: Some(Arc::new(db)),
                name,
                url,
            }
        }

        /// Wraps an in-memory [`FakeDb`] driven by the given test executor.
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                name: "fake".to_string(),
                url: "fake".to_string(),
            }
        }

        /// Returns the wrapped database (present until the `TestDb` is dropped).
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        fn drop(&mut self) {
            if let Some(db) = self.db.take() {
                futures::executor::block_on(db.teardown(&self.name, &self.url));
            }
        }
    }

    /// In-memory [`Db`] implementation for tests. Each table is a `BTreeMap` keyed by id, and
    /// ids are handed out from per-table counters starting at 1.
    pub struct FakeDb {
        background: Arc<Background>,
        users: Mutex<BTreeMap<UserId, User>>,
        next_user_id: Mutex<i32>,
        orgs: Mutex<BTreeMap<OrgId, Org>>,
        next_org_id: Mutex<i32>,
        /// (org, user) -> is_admin
        org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        channels: Mutex<BTreeMap<ChannelId, Channel>>,
        next_channel_id: Mutex<i32>,
        /// (channel, user) -> is_admin
        channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        next_channel_message_id: Mutex<i32>,
    }

    impl FakeDb {
        /// Creates an empty fake database that delays via `background`.
        pub fn new(background: Arc<Background>) -> Self {
            Self {
                background,
                users: Default::default(),
                next_user_id: Mutex::new(1),
                orgs: Default::default(),
                next_org_id: Mutex::new(1),
                org_memberships: Default::default(),
                channels: Default::default(),
                next_channel_id: Mutex::new(1),
                channel_memberships: Default::default(),
                channel_messages: Default::default(),
                next_channel_message_id: Mutex::new(1),
            }
        }
    }

    #[async_trait]
    impl Db for FakeDb {
        async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
            self.background.simulate_random_delay().await;

            let mut users = self.users.lock();
            if let Some(user) = users
                .values()
                .find(|user| user.github_login == github_login)
            {
                Ok(user.id)
            } else {
                let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
                users.insert(
                    user_id,
                    User {
                        id: user_id,
                        github_login: github_login.to_string(),
                        admin,
                    },
                );
                Ok(user_id)
            }
        }

        async fn get_all_users(&self) -> Result<Vec<User>> {
            unimplemented!()
        }

        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
        }

        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
            self.background.simulate_random_delay().await;
            let users = self.users.lock();
            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
        }

        async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
            unimplemented!()
        }

        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }

        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }

        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }

        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
            self.background.simulate_random_delay().await;
            let mut orgs = self.orgs.lock();
            if orgs.values().any(|org| org.slug == slug) {
                Err(anyhow!("org already exists"))
            } else {
                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
                orgs.insert(
                    org_id,
                    Org {
                        id: org_id,
                        name: name.to_string(),
                        slug: slug.to_string(),
                    },
                );
                Ok(org_id)
            }
        }

        async fn add_org_member(
            &self,
            org_id: OrgId,
            user_id: UserId,
            is_admin: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.orgs.lock().contains_key(&org_id) {
                return Err(anyhow!("org does not exist"));
            }
            if !self.users.lock().contains_key(&user_id) {
                return Err(anyhow!("user does not exist"));
            }

            // Mirrors `ON CONFLICT DO NOTHING`: an existing membership keeps its flag.
            self.org_memberships
                .lock()
                .entry((org_id, user_id))
                .or_insert(is_admin);
            Ok(())
        }

        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
            self.background.simulate_random_delay().await;
            if !self.orgs.lock().contains_key(&org_id) {
                return Err(anyhow!("org does not exist"));
            }

            let mut channels = self.channels.lock();
            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
            channels.insert(
                channel_id,
                Channel {
                    id: channel_id,
                    name: name.to_string(),
                    owner_id: org_id.0,
                    owner_is_user: false,
                },
            );
            Ok(channel_id)
        }

        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
            self.background.simulate_random_delay().await;
            Ok(self
                .channels
                .lock()
                .values()
                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
                .cloned()
                .collect())
        }

        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
            self.background.simulate_random_delay().await;
            let channels = self.channels.lock();
            let memberships = self.channel_memberships.lock();
            Ok(channels
                .values()
                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
                .cloned()
                .collect())
        }

        async fn can_user_access_channel(
            &self,
            user_id: UserId,
            channel_id: ChannelId,
        ) -> Result<bool> {
            self.background.simulate_random_delay().await;
            Ok(self
                .channel_memberships
                .lock()
                .contains_key(&(channel_id, user_id)))
        }

        async fn add_channel_member(
            &self,
            channel_id: ChannelId,
            user_id: UserId,
            is_admin: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.channels.lock().contains_key(&channel_id) {
                return Err(anyhow!("channel does not exist"));
            }
            if !self.users.lock().contains_key(&user_id) {
                return Err(anyhow!("user does not exist"));
            }

            // Mirrors `ON CONFLICT DO NOTHING`: an existing membership keeps its flag.
            self.channel_memberships
                .lock()
                .entry((channel_id, user_id))
                .or_insert(is_admin);
            Ok(())
        }

        async fn create_channel_message(
            &self,
            channel_id: ChannelId,
            sender_id: UserId,
            body: &str,
            timestamp: OffsetDateTime,
            nonce: u128,
        ) -> Result<MessageId> {
            self.background.simulate_random_delay().await;
            if !self.channels.lock().contains_key(&channel_id) {
                return Err(anyhow!("channel does not exist"));
            }
            if !self.users.lock().contains_key(&sender_id) {
                return Err(anyhow!("user does not exist"));
            }

            // Nonces are deduplicated across all channels, like the Postgres `ON CONFLICT
            // (nonce)` clause.
            let mut messages = self.channel_messages.lock();
            if let Some(message) = messages
                .values()
                .find(|message| message.nonce.as_u128() == nonce)
            {
                Ok(message.id)
            } else {
                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
                messages.insert(
                    message_id,
                    ChannelMessage {
                        id: message_id,
                        channel_id,
                        sender_id,
                        body: body.to_string(),
                        sent_at: timestamp,
                        nonce: Uuid::from_u128(nonce),
                    },
                );
                Ok(message_id)
            }
        }

        async fn get_channel_messages(
            &self,
            channel_id: ChannelId,
            count: usize,
            before_id: Option<MessageId>,
        ) -> Result<Vec<ChannelMessage>> {
            // Delay like every other FakeDb method (this one previously skipped it).
            self.background.simulate_random_delay().await;
            let upper_bound = before_id.unwrap_or(MessageId::MAX);
            // Walk the id-ordered map newest-first, keep the first `count` matches, then flip
            // them back to ascending id order.
            let mut messages = self
                .channel_messages
                .lock()
                .values()
                .rev()
                .filter(|message| message.channel_id == channel_id && message.id < upper_bound)
                .take(count)
                .cloned()
                .collect::<Vec<_>>();
            messages.reverse();
            Ok(messages)
        }

        async fn teardown(&self, _name: &str, _url: &str) {}
    }
}