1use anyhow::Context;
2use anyhow::Result;
3use async_trait::async_trait;
4use serde::Serialize;
5pub use sqlx::postgres::PgPoolOptions as DbOptions;
6use sqlx::{types::Uuid, FromRow};
7use time::OffsetDateTime;
8
9#[async_trait]
10pub trait Db: Send + Sync {
11 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
12 async fn get_all_users(&self) -> Result<Vec<User>>;
13 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
14 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
15 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
16 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
17 async fn destroy_user(&self, id: UserId) -> Result<()>;
18 async fn create_access_token_hash(
19 &self,
20 user_id: UserId,
21 access_token_hash: &str,
22 max_access_token_count: usize,
23 ) -> Result<()>;
24 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
25 #[cfg(any(test, feature = "seed-support"))]
26 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
27 #[cfg(any(test, feature = "seed-support"))]
28 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
29 #[cfg(any(test, feature = "seed-support"))]
30 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
31 #[cfg(any(test, feature = "seed-support"))]
32 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
33 #[cfg(any(test, feature = "seed-support"))]
34 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
35 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
36 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
37 -> Result<bool>;
38 #[cfg(any(test, feature = "seed-support"))]
39 async fn add_channel_member(
40 &self,
41 channel_id: ChannelId,
42 user_id: UserId,
43 is_admin: bool,
44 ) -> Result<()>;
45 async fn create_channel_message(
46 &self,
47 channel_id: ChannelId,
48 sender_id: UserId,
49 body: &str,
50 timestamp: OffsetDateTime,
51 nonce: u128,
52 ) -> Result<MessageId>;
53 async fn get_channel_messages(
54 &self,
55 channel_id: ChannelId,
56 count: usize,
57 before_id: Option<MessageId>,
58 ) -> Result<Vec<ChannelMessage>>;
59 #[cfg(test)]
60 async fn teardown(&self, url: &str);
61}
62
63pub struct PostgresDb {
64 pool: sqlx::PgPool,
65}
66
67impl PostgresDb {
68 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
69 let pool = DbOptions::new()
70 .max_connections(max_connections)
71 .connect(&url)
72 .await
73 .context("failed to connect to postgres database")?;
74 Ok(Self { pool })
75 }
76}
77
78#[async_trait]
79impl Db for PostgresDb {
80 // users
81
82 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
83 let query = "
84 INSERT INTO users (github_login, admin)
85 VALUES ($1, $2)
86 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
87 RETURNING id
88 ";
89 Ok(sqlx::query_scalar(query)
90 .bind(github_login)
91 .bind(admin)
92 .fetch_one(&self.pool)
93 .await
94 .map(UserId)?)
95 }
96
97 async fn get_all_users(&self) -> Result<Vec<User>> {
98 let query = "SELECT * FROM users ORDER BY github_login ASC";
99 Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
100 }
101
102 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
103 let users = self.get_users_by_ids(vec![id]).await?;
104 Ok(users.into_iter().next())
105 }
106
107 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
108 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
109 let query = "
110 SELECT users.*
111 FROM users
112 WHERE users.id = ANY ($1)
113 ";
114 Ok(sqlx::query_as(query)
115 .bind(&ids)
116 .fetch_all(&self.pool)
117 .await?)
118 }
119
120 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
121 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
122 Ok(sqlx::query_as(query)
123 .bind(github_login)
124 .fetch_optional(&self.pool)
125 .await?)
126 }
127
128 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
129 let query = "UPDATE users SET admin = $1 WHERE id = $2";
130 Ok(sqlx::query(query)
131 .bind(is_admin)
132 .bind(id.0)
133 .execute(&self.pool)
134 .await
135 .map(drop)?)
136 }
137
138 async fn destroy_user(&self, id: UserId) -> Result<()> {
139 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
140 sqlx::query(query)
141 .bind(id.0)
142 .execute(&self.pool)
143 .await
144 .map(drop)?;
145 let query = "DELETE FROM users WHERE id = $1;";
146 Ok(sqlx::query(query)
147 .bind(id.0)
148 .execute(&self.pool)
149 .await
150 .map(drop)?)
151 }
152
153 // access tokens
154
155 async fn create_access_token_hash(
156 &self,
157 user_id: UserId,
158 access_token_hash: &str,
159 max_access_token_count: usize,
160 ) -> Result<()> {
161 let insert_query = "
162 INSERT INTO access_tokens (user_id, hash)
163 VALUES ($1, $2);
164 ";
165 let cleanup_query = "
166 DELETE FROM access_tokens
167 WHERE id IN (
168 SELECT id from access_tokens
169 WHERE user_id = $1
170 ORDER BY id DESC
171 OFFSET $3
172 )
173 ";
174
175 let mut tx = self.pool.begin().await?;
176 sqlx::query(insert_query)
177 .bind(user_id.0)
178 .bind(access_token_hash)
179 .execute(&mut tx)
180 .await?;
181 sqlx::query(cleanup_query)
182 .bind(user_id.0)
183 .bind(access_token_hash)
184 .bind(max_access_token_count as u32)
185 .execute(&mut tx)
186 .await?;
187 Ok(tx.commit().await?)
188 }
189
190 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
191 let query = "
192 SELECT hash
193 FROM access_tokens
194 WHERE user_id = $1
195 ORDER BY id DESC
196 ";
197 Ok(sqlx::query_scalar(query)
198 .bind(user_id.0)
199 .fetch_all(&self.pool)
200 .await?)
201 }
202
203 // orgs
204
/// Returns the org whose slug equals `slug`, or `None` when no row matches.
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
    const SELECT_ORG: &str = "
        SELECT *
        FROM orgs
        WHERE slug = $1
    ";
    let org: Option<Org> = sqlx::query_as(SELECT_ORG)
        .bind(slug)
        .fetch_optional(&self.pool)
        .await?;
    Ok(org)
}
218
/// Inserts an org with the given name and slug and returns its id.
#[cfg(any(test, feature = "seed-support"))]
async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
    const INSERT_ORG: &str = "
        INSERT INTO orgs (name, slug)
        VALUES ($1, $2)
        RETURNING id
    ";
    let id: i32 = sqlx::query_scalar(INSERT_ORG)
        .bind(name)
        .bind(slug)
        .fetch_one(&self.pool)
        .await?;
    Ok(OrgId(id))
}
233
/// Inserts an org membership for `user_id` in `org_id`.
///
/// A conflicting row is left as it is (`ON CONFLICT DO NOTHING`).
#[cfg(any(test, feature = "seed-support"))]
async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
    const INSERT_MEMBERSHIP: &str = "
        INSERT INTO org_memberships (org_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
    ";
    sqlx::query(INSERT_MEMBERSHIP)
        .bind(org_id.0)
        .bind(user_id.0)
        .bind(is_admin)
        .execute(&self.pool)
        .await?;
    Ok(())
}
249
250 // channels
251
/// Inserts a channel owned by the org `org_id` (`owner_is_user = false`) and
/// returns its id.
#[cfg(any(test, feature = "seed-support"))]
async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
    const INSERT_CHANNEL: &str = "
        INSERT INTO channels (owner_id, owner_is_user, name)
        VALUES ($1, false, $2)
        RETURNING id
    ";
    let id: i32 = sqlx::query_scalar(INSERT_CHANNEL)
        .bind(org_id.0)
        .bind(name)
        .fetch_one(&self.pool)
        .await?;
    Ok(ChannelId(id))
}
266
/// Returns the channels whose owner is the org `org_id` (rows with
/// `owner_is_user = false`).
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
    const SELECT_CHANNELS: &str = "
        SELECT *
        FROM channels
        WHERE
            channels.owner_is_user = false AND
            channels.owner_id = $1
    ";
    let channels: Vec<Channel> = sqlx::query_as(SELECT_CHANNELS)
        .bind(org_id.0)
        .fetch_all(&self.pool)
        .await?;
    Ok(channels)
}
282
283 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
284 let query = "
285 SELECT
286 channels.*
287 FROM
288 channel_memberships, channels
289 WHERE
290 channel_memberships.user_id = $1 AND
291 channel_memberships.channel_id = channels.id
292 ";
293 Ok(sqlx::query_as(query)
294 .bind(user_id.0)
295 .fetch_all(&self.pool)
296 .await?)
297 }
298
299 async fn can_user_access_channel(
300 &self,
301 user_id: UserId,
302 channel_id: ChannelId,
303 ) -> Result<bool> {
304 let query = "
305 SELECT id
306 FROM channel_memberships
307 WHERE user_id = $1 AND channel_id = $2
308 LIMIT 1
309 ";
310 Ok(sqlx::query_scalar::<_, i32>(query)
311 .bind(user_id.0)
312 .bind(channel_id.0)
313 .fetch_optional(&self.pool)
314 .await
315 .map(|e| e.is_some())?)
316 }
317
/// Inserts a channel membership for `user_id` in `channel_id`.
///
/// A conflicting row is left as it is (`ON CONFLICT DO NOTHING`).
#[cfg(any(test, feature = "seed-support"))]
async fn add_channel_member(
    &self,
    channel_id: ChannelId,
    user_id: UserId,
    is_admin: bool,
) -> Result<()> {
    const INSERT_MEMBERSHIP: &str = "
        INSERT INTO channel_memberships (channel_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
    ";
    sqlx::query(INSERT_MEMBERSHIP)
        .bind(channel_id.0)
        .bind(user_id.0)
        .bind(is_admin)
        .execute(&self.pool)
        .await?;
    Ok(())
}
338
339 // messages
340
341 async fn create_channel_message(
342 &self,
343 channel_id: ChannelId,
344 sender_id: UserId,
345 body: &str,
346 timestamp: OffsetDateTime,
347 nonce: u128,
348 ) -> Result<MessageId> {
349 let query = "
350 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
351 VALUES ($1, $2, $3, $4, $5)
352 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
353 RETURNING id
354 ";
355 Ok(sqlx::query_scalar(query)
356 .bind(channel_id.0)
357 .bind(sender_id.0)
358 .bind(body)
359 .bind(timestamp)
360 .bind(Uuid::from_u128(nonce))
361 .fetch_one(&self.pool)
362 .await
363 .map(MessageId)?)
364 }
365
366 async fn get_channel_messages(
367 &self,
368 channel_id: ChannelId,
369 count: usize,
370 before_id: Option<MessageId>,
371 ) -> Result<Vec<ChannelMessage>> {
372 let query = r#"
373 SELECT * FROM (
374 SELECT
375 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
376 FROM
377 channel_messages
378 WHERE
379 channel_id = $1 AND
380 id < $2
381 ORDER BY id DESC
382 LIMIT $3
383 ) as recent_messages
384 ORDER BY id ASC
385 "#;
386 Ok(sqlx::query_as(query)
387 .bind(channel_id.0)
388 .bind(before_id.unwrap_or(MessageId::MAX))
389 .bind(count as i64)
390 .fetch_all(&self.pool)
391 .await?)
392 }
393
/// Test-only cleanup: terminates the other backends connected to the current
/// database, closes the pool, then drops the database at `url`.
///
/// Failures of the terminate query and of the drop are passed to `log_err`
/// rather than returned.
#[cfg(test)]
async fn teardown(&self, url: &str) {
    use sqlx::migrate::MigrateDatabase;
    use util::ResultExt;

    const TERMINATE_OTHER_BACKENDS: &str = "
        SELECT pg_terminate_backend(pg_stat_activity.pid)
        FROM pg_stat_activity
        WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
    ";
    sqlx::query(TERMINATE_OTHER_BACKENDS)
        .execute(&self.pool)
        .await
        .log_err();
    self.pool.close().await;
    sqlx::Postgres::drop_database(url).await.log_err();
}
409}
410
/// Declares a transparent `i32` newtype named `$name` for use as a database id.
///
/// The generated type is encoded by sqlx and serde exactly like its inner `i32`
/// and gets a `MAX` constant plus conversions to and from the `u64` form.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest representable id.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from its `u64` form.
            ///
            /// Uses an `as` cast, so values that do not fit in an `i32` are truncated.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Converts the id to its `u64` form.
            ///
            /// Uses an `as` cast, so a negative id is sign-extended.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let $name(raw) = *self;
                raw as u64
            }
        }
    };
}
436
437id_type!(UserId);
438#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
439pub struct User {
440 pub id: UserId,
441 pub github_login: String,
442 pub admin: bool,
443}
444
445id_type!(OrgId);
446#[derive(FromRow)]
447pub struct Org {
448 pub id: OrgId,
449 pub name: String,
450 pub slug: String,
451}
452
453id_type!(ChannelId);
454#[derive(Clone, Debug, FromRow, Serialize)]
455pub struct Channel {
456 pub id: ChannelId,
457 pub name: String,
458 pub owner_id: i32,
459 pub owner_is_user: bool,
460}
461
462id_type!(MessageId);
463#[derive(Clone, Debug, FromRow)]
464pub struct ChannelMessage {
465 pub id: MessageId,
466 pub channel_id: ChannelId,
467 pub sender_id: UserId,
468 pub body: String,
469 pub sent_at: OffsetDateTime,
470 pub nonce: Uuid,
471}
472
473#[cfg(test)]
474pub mod tests {
475 use super::*;
476 use anyhow::anyhow;
477 use collections::BTreeMap;
478 use gpui::executor::Background;
479 use lazy_static::lazy_static;
480 use parking_lot::Mutex;
481 use rand::prelude::*;
482 use sqlx::{
483 migrate::{MigrateDatabase, Migrator},
484 Postgres,
485 };
486 use std::{path::Path, sync::Arc};
487 use util::post_inc;
488
489 #[tokio::test(flavor = "multi_thread")]
490 async fn test_get_users_by_ids() {
491 for test_db in [
492 TestDb::postgres().await,
493 TestDb::fake(Arc::new(gpui::executor::Background::new())),
494 ] {
495 let db = test_db.db();
496
497 let user = db.create_user("user", false).await.unwrap();
498 let friend1 = db.create_user("friend-1", false).await.unwrap();
499 let friend2 = db.create_user("friend-2", false).await.unwrap();
500 let friend3 = db.create_user("friend-3", false).await.unwrap();
501
502 assert_eq!(
503 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
504 .await
505 .unwrap(),
506 vec![
507 User {
508 id: user,
509 github_login: "user".to_string(),
510 admin: false,
511 },
512 User {
513 id: friend1,
514 github_login: "friend-1".to_string(),
515 admin: false,
516 },
517 User {
518 id: friend2,
519 github_login: "friend-2".to_string(),
520 admin: false,
521 },
522 User {
523 id: friend3,
524 github_login: "friend-3".to_string(),
525 admin: false,
526 }
527 ]
528 );
529 }
530 }
531
532 #[tokio::test(flavor = "multi_thread")]
533 async fn test_recent_channel_messages() {
534 for test_db in [
535 TestDb::postgres().await,
536 TestDb::fake(Arc::new(gpui::executor::Background::new())),
537 ] {
538 let db = test_db.db();
539 let user = db.create_user("user", false).await.unwrap();
540 let org = db.create_org("org", "org").await.unwrap();
541 let channel = db.create_org_channel(org, "channel").await.unwrap();
542 for i in 0..10 {
543 db.create_channel_message(
544 channel,
545 user,
546 &i.to_string(),
547 OffsetDateTime::now_utc(),
548 i,
549 )
550 .await
551 .unwrap();
552 }
553
554 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
555 assert_eq!(
556 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
557 ["5", "6", "7", "8", "9"]
558 );
559
560 let prev_messages = db
561 .get_channel_messages(channel, 4, Some(messages[0].id))
562 .await
563 .unwrap();
564 assert_eq!(
565 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
566 ["1", "2", "3", "4"]
567 );
568 }
569 }
570
571 #[tokio::test(flavor = "multi_thread")]
572 async fn test_channel_message_nonces() {
573 for test_db in [
574 TestDb::postgres().await,
575 TestDb::fake(Arc::new(gpui::executor::Background::new())),
576 ] {
577 let db = test_db.db();
578 let user = db.create_user("user", false).await.unwrap();
579 let org = db.create_org("org", "org").await.unwrap();
580 let channel = db.create_org_channel(org, "channel").await.unwrap();
581
582 let msg1_id = db
583 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
584 .await
585 .unwrap();
586 let msg2_id = db
587 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
588 .await
589 .unwrap();
590 let msg3_id = db
591 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
592 .await
593 .unwrap();
594 let msg4_id = db
595 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
596 .await
597 .unwrap();
598
599 assert_ne!(msg1_id, msg2_id);
600 assert_eq!(msg1_id, msg3_id);
601 assert_eq!(msg2_id, msg4_id);
602 }
603 }
604
605 #[tokio::test(flavor = "multi_thread")]
606 async fn test_create_access_tokens() {
607 let test_db = TestDb::postgres().await;
608 let db = test_db.db();
609 let user = db.create_user("the-user", false).await.unwrap();
610
611 db.create_access_token_hash(user, "h1", 3).await.unwrap();
612 db.create_access_token_hash(user, "h2", 3).await.unwrap();
613 assert_eq!(
614 db.get_access_token_hashes(user).await.unwrap(),
615 &["h2".to_string(), "h1".to_string()]
616 );
617
618 db.create_access_token_hash(user, "h3", 3).await.unwrap();
619 assert_eq!(
620 db.get_access_token_hashes(user).await.unwrap(),
621 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
622 );
623
624 db.create_access_token_hash(user, "h4", 3).await.unwrap();
625 assert_eq!(
626 db.get_access_token_hashes(user).await.unwrap(),
627 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
628 );
629
630 db.create_access_token_hash(user, "h5", 3).await.unwrap();
631 assert_eq!(
632 db.get_access_token_hashes(user).await.unwrap(),
633 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
634 );
635 }
636
637 pub struct TestDb {
638 pub db: Option<Arc<dyn Db>>,
639 pub url: String,
640 }
641
642 impl TestDb {
643 pub async fn postgres() -> Self {
644 lazy_static! {
645 static ref LOCK: Mutex<()> = Mutex::new(());
646 }
647
648 let _guard = LOCK.lock();
649 let mut rng = StdRng::from_entropy();
650 let name = format!("zed-test-{}", rng.gen::<u128>());
651 let url = format!("postgres://postgres@localhost/{}", name);
652 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
653 Postgres::create_database(&url)
654 .await
655 .expect("failed to create test db");
656 let db = PostgresDb::new(&url, 5).await.unwrap();
657 let migrator = Migrator::new(migrations_path).await.unwrap();
658 migrator.run(&db.pool).await.unwrap();
659 Self {
660 db: Some(Arc::new(db)),
661 url,
662 }
663 }
664
665 pub fn fake(background: Arc<Background>) -> Self {
666 Self {
667 db: Some(Arc::new(FakeDb::new(background))),
668 url: Default::default(),
669 }
670 }
671
672 pub fn db(&self) -> &Arc<dyn Db> {
673 self.db.as_ref().unwrap()
674 }
675 }
676
677 impl Drop for TestDb {
678 fn drop(&mut self) {
679 if let Some(db) = self.db.take() {
680 futures::executor::block_on(db.teardown(&self.url));
681 }
682 }
683 }
684
685 pub struct FakeDb {
686 background: Arc<Background>,
687 users: Mutex<BTreeMap<UserId, User>>,
688 next_user_id: Mutex<i32>,
689 orgs: Mutex<BTreeMap<OrgId, Org>>,
690 next_org_id: Mutex<i32>,
691 org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
692 channels: Mutex<BTreeMap<ChannelId, Channel>>,
693 next_channel_id: Mutex<i32>,
694 channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
695 channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
696 next_channel_message_id: Mutex<i32>,
697 }
698
699 impl FakeDb {
700 pub fn new(background: Arc<Background>) -> Self {
701 Self {
702 background,
703 users: Default::default(),
704 next_user_id: Mutex::new(1),
705 orgs: Default::default(),
706 next_org_id: Mutex::new(1),
707 org_memberships: Default::default(),
708 channels: Default::default(),
709 next_channel_id: Mutex::new(1),
710 channel_memberships: Default::default(),
711 channel_messages: Default::default(),
712 next_channel_message_id: Mutex::new(1),
713 }
714 }
715 }
716
717 #[async_trait]
718 impl Db for FakeDb {
719 async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
720 self.background.simulate_random_delay().await;
721
722 let mut users = self.users.lock();
723 if let Some(user) = users
724 .values()
725 .find(|user| user.github_login == github_login)
726 {
727 Ok(user.id)
728 } else {
729 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
730 users.insert(
731 user_id,
732 User {
733 id: user_id,
734 github_login: github_login.to_string(),
735 admin,
736 },
737 );
738 Ok(user_id)
739 }
740 }
741
742 async fn get_all_users(&self) -> Result<Vec<User>> {
743 unimplemented!()
744 }
745
746 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
747 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
748 }
749
750 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
751 self.background.simulate_random_delay().await;
752 let users = self.users.lock();
753 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
754 }
755
756 async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
757 unimplemented!()
758 }
759
760 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
761 unimplemented!()
762 }
763
764 async fn destroy_user(&self, _id: UserId) -> Result<()> {
765 unimplemented!()
766 }
767
768 async fn create_access_token_hash(
769 &self,
770 _user_id: UserId,
771 _access_token_hash: &str,
772 _max_access_token_count: usize,
773 ) -> Result<()> {
774 unimplemented!()
775 }
776
777 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
778 unimplemented!()
779 }
780
781 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
782 unimplemented!()
783 }
784
785 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
786 self.background.simulate_random_delay().await;
787 let mut orgs = self.orgs.lock();
788 if orgs.values().any(|org| org.slug == slug) {
789 Err(anyhow!("org already exists"))
790 } else {
791 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
792 orgs.insert(
793 org_id,
794 Org {
795 id: org_id,
796 name: name.to_string(),
797 slug: slug.to_string(),
798 },
799 );
800 Ok(org_id)
801 }
802 }
803
804 async fn add_org_member(
805 &self,
806 org_id: OrgId,
807 user_id: UserId,
808 is_admin: bool,
809 ) -> Result<()> {
810 self.background.simulate_random_delay().await;
811 if !self.orgs.lock().contains_key(&org_id) {
812 return Err(anyhow!("org does not exist"));
813 }
814 if !self.users.lock().contains_key(&user_id) {
815 return Err(anyhow!("user does not exist"));
816 }
817
818 self.org_memberships
819 .lock()
820 .entry((org_id, user_id))
821 .or_insert(is_admin);
822 Ok(())
823 }
824
825 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
826 self.background.simulate_random_delay().await;
827 if !self.orgs.lock().contains_key(&org_id) {
828 return Err(anyhow!("org does not exist"));
829 }
830
831 let mut channels = self.channels.lock();
832 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
833 channels.insert(
834 channel_id,
835 Channel {
836 id: channel_id,
837 name: name.to_string(),
838 owner_id: org_id.0,
839 owner_is_user: false,
840 },
841 );
842 Ok(channel_id)
843 }
844
845 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
846 self.background.simulate_random_delay().await;
847 Ok(self
848 .channels
849 .lock()
850 .values()
851 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
852 .cloned()
853 .collect())
854 }
855
856 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
857 self.background.simulate_random_delay().await;
858 let channels = self.channels.lock();
859 let memberships = self.channel_memberships.lock();
860 Ok(channels
861 .values()
862 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
863 .cloned()
864 .collect())
865 }
866
867 async fn can_user_access_channel(
868 &self,
869 user_id: UserId,
870 channel_id: ChannelId,
871 ) -> Result<bool> {
872 self.background.simulate_random_delay().await;
873 Ok(self
874 .channel_memberships
875 .lock()
876 .contains_key(&(channel_id, user_id)))
877 }
878
879 async fn add_channel_member(
880 &self,
881 channel_id: ChannelId,
882 user_id: UserId,
883 is_admin: bool,
884 ) -> Result<()> {
885 self.background.simulate_random_delay().await;
886 if !self.channels.lock().contains_key(&channel_id) {
887 return Err(anyhow!("channel does not exist"));
888 }
889 if !self.users.lock().contains_key(&user_id) {
890 return Err(anyhow!("user does not exist"));
891 }
892
893 self.channel_memberships
894 .lock()
895 .entry((channel_id, user_id))
896 .or_insert(is_admin);
897 Ok(())
898 }
899
900 async fn create_channel_message(
901 &self,
902 channel_id: ChannelId,
903 sender_id: UserId,
904 body: &str,
905 timestamp: OffsetDateTime,
906 nonce: u128,
907 ) -> Result<MessageId> {
908 self.background.simulate_random_delay().await;
909 if !self.channels.lock().contains_key(&channel_id) {
910 return Err(anyhow!("channel does not exist"));
911 }
912 if !self.users.lock().contains_key(&sender_id) {
913 return Err(anyhow!("user does not exist"));
914 }
915
916 let mut messages = self.channel_messages.lock();
917 if let Some(message) = messages
918 .values()
919 .find(|message| message.nonce.as_u128() == nonce)
920 {
921 Ok(message.id)
922 } else {
923 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
924 messages.insert(
925 message_id,
926 ChannelMessage {
927 id: message_id,
928 channel_id,
929 sender_id,
930 body: body.to_string(),
931 sent_at: timestamp,
932 nonce: Uuid::from_u128(nonce),
933 },
934 );
935 Ok(message_id)
936 }
937 }
938
939 async fn get_channel_messages(
940 &self,
941 channel_id: ChannelId,
942 count: usize,
943 before_id: Option<MessageId>,
944 ) -> Result<Vec<ChannelMessage>> {
945 let mut messages = self
946 .channel_messages
947 .lock()
948 .values()
949 .rev()
950 .filter(|message| {
951 message.channel_id == channel_id
952 && message.id < before_id.unwrap_or(MessageId::MAX)
953 })
954 .take(count)
955 .cloned()
956 .collect::<Vec<_>>();
957 messages.sort_unstable_by_key(|message| message.id);
958 Ok(messages)
959 }
960
961 async fn teardown(&self, _: &str) {}
962 }
963}