1use std::{ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use nanoid::nanoid;
10use serde::Serialize;
11pub use sqlx::postgres::PgPoolOptions as DbOptions;
12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
13use time::OffsetDateTime;
14
15#[async_trait]
16pub trait Db: Send + Sync {
17 async fn create_user(
18 &self,
19 github_login: &str,
20 email_address: Option<&str>,
21 admin: bool,
22 ) -> Result<UserId>;
23 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
24 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
25 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
26 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
27 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
28 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
29 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
30 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
31 async fn destroy_user(&self, id: UserId) -> Result<()>;
32
33 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
34 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
35 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
36 async fn redeem_invite_code(
37 &self,
38 code: &str,
39 login: &str,
40 email_address: Option<&str>,
41 ) -> Result<UserId>;
42
43 /// Registers a new project for the given user.
44 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
45
46 /// Unregisters a project for the given project id.
47 async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
48
49 /// Create a new project for the given user.
50 async fn update_worktree_extensions(
51 &self,
52 project_id: ProjectId,
53 worktree_id: u64,
54 extensions: HashMap<String, usize>,
55 ) -> Result<()>;
56
57 /// Record which users have been active in which projects during
58 /// a given period of time.
59 async fn record_project_activity(
60 &self,
61 time_period: Range<OffsetDateTime>,
62 active_projects: &[(UserId, ProjectId)],
63 ) -> Result<()>;
64
65 /// Get the users that have been most active during the given time period,
66 /// along with the amount of time they have been active in each project.
67 async fn summarize_project_activity(
68 &self,
69 time_period: Range<OffsetDateTime>,
70 max_user_count: usize,
71 ) -> Result<Vec<UserActivitySummary>>;
72
73 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
74 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
75 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
76 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
77 async fn dismiss_contact_notification(
78 &self,
79 responder_id: UserId,
80 requester_id: UserId,
81 ) -> Result<()>;
82 async fn respond_to_contact_request(
83 &self,
84 responder_id: UserId,
85 requester_id: UserId,
86 accept: bool,
87 ) -> Result<()>;
88
89 async fn create_access_token_hash(
90 &self,
91 user_id: UserId,
92 access_token_hash: &str,
93 max_access_token_count: usize,
94 ) -> Result<()>;
95 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
96 #[cfg(any(test, feature = "seed-support"))]
97
98 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
99 #[cfg(any(test, feature = "seed-support"))]
100 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
101 #[cfg(any(test, feature = "seed-support"))]
102 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
103 #[cfg(any(test, feature = "seed-support"))]
104 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
105 #[cfg(any(test, feature = "seed-support"))]
106
107 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
108 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
109 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
110 -> Result<bool>;
111 #[cfg(any(test, feature = "seed-support"))]
112 async fn add_channel_member(
113 &self,
114 channel_id: ChannelId,
115 user_id: UserId,
116 is_admin: bool,
117 ) -> Result<()>;
118 async fn create_channel_message(
119 &self,
120 channel_id: ChannelId,
121 sender_id: UserId,
122 body: &str,
123 timestamp: OffsetDateTime,
124 nonce: u128,
125 ) -> Result<MessageId>;
126 async fn get_channel_messages(
127 &self,
128 channel_id: ChannelId,
129 count: usize,
130 before_id: Option<MessageId>,
131 ) -> Result<Vec<ChannelMessage>>;
132 #[cfg(test)]
133 async fn teardown(&self, url: &str);
134 #[cfg(test)]
135 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
136}
137
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    /// Shared pool that every query in the `Db` impl runs against.
    pool: sqlx::PgPool,
}
141
142impl PostgresDb {
143 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
144 let pool = DbOptions::new()
145 .max_connections(max_connections)
146 .connect(&url)
147 .await
148 .context("failed to connect to postgres database")?;
149 Ok(Self { pool })
150 }
151}
152
153#[async_trait]
154impl Db for PostgresDb {
155 // users
156
157 async fn create_user(
158 &self,
159 github_login: &str,
160 email_address: Option<&str>,
161 admin: bool,
162 ) -> Result<UserId> {
163 let query = "
164 INSERT INTO users (github_login, email_address, admin)
165 VALUES ($1, $2, $3)
166 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
167 RETURNING id
168 ";
169 Ok(sqlx::query_scalar(query)
170 .bind(github_login)
171 .bind(email_address)
172 .bind(admin)
173 .fetch_one(&self.pool)
174 .await
175 .map(UserId)?)
176 }
177
178 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
179 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
180 Ok(sqlx::query_as(query)
181 .bind(limit)
182 .bind(page * limit)
183 .fetch_all(&self.pool)
184 .await?)
185 }
186
187 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
188 let mut query = QueryBuilder::new(
189 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
190 );
191 query.push_values(
192 users,
193 |mut query, (github_login, email_address, invite_count)| {
194 query
195 .push_bind(github_login)
196 .push_bind(email_address)
197 .push_bind(false)
198 .push_bind(nanoid!(16))
199 .push_bind(invite_count as u32);
200 },
201 );
202 query.push(
203 "
204 ON CONFLICT (github_login) DO UPDATE SET
205 github_login = excluded.github_login,
206 invite_count = excluded.invite_count,
207 invite_code = CASE WHEN users.invite_code IS NULL
208 THEN excluded.invite_code
209 ELSE users.invite_code
210 END
211 RETURNING id
212 ",
213 );
214
215 let rows = query.build().fetch_all(&self.pool).await?;
216 Ok(rows
217 .into_iter()
218 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
219 .collect())
220 }
221
222 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
223 let like_string = fuzzy_like_string(name_query);
224 let query = "
225 SELECT users.*
226 FROM users
227 WHERE github_login ILIKE $1
228 ORDER BY github_login <-> $2
229 LIMIT $3
230 ";
231 Ok(sqlx::query_as(query)
232 .bind(like_string)
233 .bind(name_query)
234 .bind(limit)
235 .fetch_all(&self.pool)
236 .await?)
237 }
238
239 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
240 let users = self.get_users_by_ids(vec![id]).await?;
241 Ok(users.into_iter().next())
242 }
243
244 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
245 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
246 let query = "
247 SELECT users.*
248 FROM users
249 WHERE users.id = ANY ($1)
250 ";
251 Ok(sqlx::query_as(query)
252 .bind(&ids)
253 .fetch_all(&self.pool)
254 .await?)
255 }
256
257 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
258 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
259 Ok(sqlx::query_as(query)
260 .bind(github_login)
261 .fetch_optional(&self.pool)
262 .await?)
263 }
264
265 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
266 let query = "UPDATE users SET admin = $1 WHERE id = $2";
267 Ok(sqlx::query(query)
268 .bind(is_admin)
269 .bind(id.0)
270 .execute(&self.pool)
271 .await
272 .map(drop)?)
273 }
274
275 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
276 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
277 Ok(sqlx::query(query)
278 .bind(connected_once)
279 .bind(id.0)
280 .execute(&self.pool)
281 .await
282 .map(drop)?)
283 }
284
285 async fn destroy_user(&self, id: UserId) -> Result<()> {
286 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
287 sqlx::query(query)
288 .bind(id.0)
289 .execute(&self.pool)
290 .await
291 .map(drop)?;
292 let query = "DELETE FROM users WHERE id = $1;";
293 Ok(sqlx::query(query)
294 .bind(id.0)
295 .execute(&self.pool)
296 .await
297 .map(drop)?)
298 }
299
300 // invite codes
301
302 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
303 let mut tx = self.pool.begin().await?;
304 if count > 0 {
305 sqlx::query(
306 "
307 UPDATE users
308 SET invite_code = $1
309 WHERE id = $2 AND invite_code IS NULL
310 ",
311 )
312 .bind(nanoid!(16))
313 .bind(id)
314 .execute(&mut tx)
315 .await?;
316 }
317
318 sqlx::query(
319 "
320 UPDATE users
321 SET invite_count = $1
322 WHERE id = $2
323 ",
324 )
325 .bind(count)
326 .bind(id)
327 .execute(&mut tx)
328 .await?;
329 tx.commit().await?;
330 Ok(())
331 }
332
333 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
334 let result: Option<(String, i32)> = sqlx::query_as(
335 "
336 SELECT invite_code, invite_count
337 FROM users
338 WHERE id = $1 AND invite_code IS NOT NULL
339 ",
340 )
341 .bind(id)
342 .fetch_optional(&self.pool)
343 .await?;
344 if let Some((code, count)) = result {
345 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
346 } else {
347 Ok(None)
348 }
349 }
350
351 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
352 sqlx::query_as(
353 "
354 SELECT *
355 FROM users
356 WHERE invite_code = $1
357 ",
358 )
359 .bind(code)
360 .fetch_optional(&self.pool)
361 .await?
362 .ok_or_else(|| {
363 Error::Http(
364 StatusCode::NOT_FOUND,
365 "that invite code does not exist".to_string(),
366 )
367 })
368 }
369
370 async fn redeem_invite_code(
371 &self,
372 code: &str,
373 login: &str,
374 email_address: Option<&str>,
375 ) -> Result<UserId> {
376 let mut tx = self.pool.begin().await?;
377
378 let inviter_id: Option<UserId> = sqlx::query_scalar(
379 "
380 UPDATE users
381 SET invite_count = invite_count - 1
382 WHERE
383 invite_code = $1 AND
384 invite_count > 0
385 RETURNING id
386 ",
387 )
388 .bind(code)
389 .fetch_optional(&mut tx)
390 .await?;
391
392 let inviter_id = match inviter_id {
393 Some(inviter_id) => inviter_id,
394 None => {
395 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
396 .bind(code)
397 .fetch_optional(&mut tx)
398 .await?
399 .is_some()
400 {
401 Err(Error::Http(
402 StatusCode::UNAUTHORIZED,
403 "no invites remaining".to_string(),
404 ))?
405 } else {
406 Err(Error::Http(
407 StatusCode::NOT_FOUND,
408 "invite code not found".to_string(),
409 ))?
410 }
411 }
412 };
413
414 let invitee_id = sqlx::query_scalar(
415 "
416 INSERT INTO users
417 (github_login, email_address, admin, inviter_id)
418 VALUES
419 ($1, $2, 'f', $3)
420 RETURNING id
421 ",
422 )
423 .bind(login)
424 .bind(email_address)
425 .bind(inviter_id)
426 .fetch_one(&mut tx)
427 .await
428 .map(UserId)?;
429
430 sqlx::query(
431 "
432 INSERT INTO contacts
433 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
434 VALUES
435 ($1, $2, 't', 't', 't')
436 ",
437 )
438 .bind(inviter_id)
439 .bind(invitee_id)
440 .execute(&mut tx)
441 .await?;
442
443 tx.commit().await?;
444 Ok(invitee_id)
445 }
446
447 // projects
448
449 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
450 Ok(sqlx::query_scalar(
451 "
452 INSERT INTO projects(host_user_id)
453 VALUES ($1)
454 RETURNING id
455 ",
456 )
457 .bind(host_user_id)
458 .fetch_one(&self.pool)
459 .await
460 .map(ProjectId)?)
461 }
462
463 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
464 sqlx::query(
465 "
466 UPDATE projects
467 SET unregistered = 't'
468 WHERE project_id = $1
469 ",
470 )
471 .bind(project_id)
472 .execute(&self.pool)
473 .await?;
474 Ok(())
475 }
476
477 async fn update_worktree_extensions(
478 &self,
479 project_id: ProjectId,
480 worktree_id: u64,
481 extensions: HashMap<String, usize>,
482 ) -> Result<()> {
483 let mut query = QueryBuilder::new(
484 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
485 );
486 query.push_values(extensions, |mut query, (extension, count)| {
487 query
488 .push_bind(project_id)
489 .push_bind(worktree_id as i32)
490 .push_bind(extension)
491 .push_bind(count as u32);
492 });
493 query.push(
494 "
495 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
496 count = excluded.count
497 ",
498 );
499 query.build().execute(&self.pool).await?;
500
501 Ok(())
502 }
503
504 async fn record_project_activity(
505 &self,
506 time_period: Range<OffsetDateTime>,
507 projects: &[(UserId, ProjectId)],
508 ) -> Result<()> {
509 let query = "
510 INSERT INTO project_activity_periods
511 (ended_at, duration_millis, user_id, project_id)
512 VALUES
513 ($1, $2, $3, $4);
514 ";
515
516 let mut tx = self.pool.begin().await?;
517 let duration_millis =
518 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
519 for (user_id, project_id) in projects {
520 sqlx::query(query)
521 .bind(time_period.end)
522 .bind(duration_millis)
523 .bind(user_id)
524 .bind(project_id)
525 .execute(&mut tx)
526 .await?;
527 }
528 tx.commit().await?;
529
530 Ok(())
531 }
532
533 async fn summarize_project_activity(
534 &self,
535 time_period: Range<OffsetDateTime>,
536 max_user_count: usize,
537 ) -> Result<Vec<UserActivitySummary>> {
538 let query = "
539 WITH
540 project_durations AS (
541 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
542 FROM project_activity_periods
543 WHERE $1 <= ended_at AND ended_at <= $2
544 GROUP BY user_id, project_id
545 ),
546 user_durations AS (
547 SELECT user_id, SUM(project_duration) as total_duration
548 FROM project_durations
549 GROUP BY user_id
550 ORDER BY total_duration DESC
551 LIMIT $3
552 )
553 SELECT user_durations.user_id, users.github_login, project_id, project_duration
554 FROM user_durations, project_durations, users
555 WHERE
556 user_durations.user_id = project_durations.user_id AND
557 user_durations.user_id = users.id
558 ORDER BY user_id ASC, project_duration DESC
559 ";
560
561 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64)>(query)
562 .bind(time_period.start)
563 .bind(time_period.end)
564 .bind(max_user_count as i32)
565 .fetch(&self.pool);
566
567 let mut result = Vec::<UserActivitySummary>::new();
568 while let Some(row) = rows.next().await {
569 let (user_id, github_login, project_id, duration_millis) = row?;
570 let project_id = project_id;
571 let duration = Duration::from_millis(duration_millis as u64);
572 if let Some(last_summary) = result.last_mut() {
573 if last_summary.id == user_id {
574 last_summary.project_activity.push((project_id, duration));
575 continue;
576 }
577 }
578 result.push(UserActivitySummary {
579 id: user_id,
580 project_activity: vec![(project_id, duration)],
581 github_login,
582 });
583 }
584
585 Ok(result)
586 }
587
588 // contacts
589
590 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
591 let query = "
592 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
593 FROM contacts
594 WHERE user_id_a = $1 OR user_id_b = $1;
595 ";
596
597 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
598 .bind(user_id)
599 .fetch(&self.pool);
600
601 let mut contacts = vec![Contact::Accepted {
602 user_id,
603 should_notify: false,
604 }];
605 while let Some(row) = rows.next().await {
606 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
607
608 if user_id_a == user_id {
609 if accepted {
610 contacts.push(Contact::Accepted {
611 user_id: user_id_b,
612 should_notify: should_notify && a_to_b,
613 });
614 } else if a_to_b {
615 contacts.push(Contact::Outgoing { user_id: user_id_b })
616 } else {
617 contacts.push(Contact::Incoming {
618 user_id: user_id_b,
619 should_notify,
620 });
621 }
622 } else {
623 if accepted {
624 contacts.push(Contact::Accepted {
625 user_id: user_id_a,
626 should_notify: should_notify && !a_to_b,
627 });
628 } else if a_to_b {
629 contacts.push(Contact::Incoming {
630 user_id: user_id_a,
631 should_notify,
632 });
633 } else {
634 contacts.push(Contact::Outgoing { user_id: user_id_a });
635 }
636 }
637 }
638
639 contacts.sort_unstable_by_key(|contact| contact.user_id());
640
641 Ok(contacts)
642 }
643
644 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
645 let (id_a, id_b) = if user_id_1 < user_id_2 {
646 (user_id_1, user_id_2)
647 } else {
648 (user_id_2, user_id_1)
649 };
650
651 let query = "
652 SELECT 1 FROM contacts
653 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
654 LIMIT 1
655 ";
656 Ok(sqlx::query_scalar::<_, i32>(query)
657 .bind(id_a.0)
658 .bind(id_b.0)
659 .fetch_optional(&self.pool)
660 .await?
661 .is_some())
662 }
663
664 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
665 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
666 (sender_id, receiver_id, true)
667 } else {
668 (receiver_id, sender_id, false)
669 };
670 let query = "
671 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
672 VALUES ($1, $2, $3, 'f', 't')
673 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
674 SET
675 accepted = 't',
676 should_notify = 'f'
677 WHERE
678 NOT contacts.accepted AND
679 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
680 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
681 ";
682 let result = sqlx::query(query)
683 .bind(id_a.0)
684 .bind(id_b.0)
685 .bind(a_to_b)
686 .execute(&self.pool)
687 .await?;
688
689 if result.rows_affected() == 1 {
690 Ok(())
691 } else {
692 Err(anyhow!("contact already requested"))?
693 }
694 }
695
696 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
697 let (id_a, id_b) = if responder_id < requester_id {
698 (responder_id, requester_id)
699 } else {
700 (requester_id, responder_id)
701 };
702 let query = "
703 DELETE FROM contacts
704 WHERE user_id_a = $1 AND user_id_b = $2;
705 ";
706 let result = sqlx::query(query)
707 .bind(id_a.0)
708 .bind(id_b.0)
709 .execute(&self.pool)
710 .await?;
711
712 if result.rows_affected() == 1 {
713 Ok(())
714 } else {
715 Err(anyhow!("no such contact"))?
716 }
717 }
718
719 async fn dismiss_contact_notification(
720 &self,
721 user_id: UserId,
722 contact_user_id: UserId,
723 ) -> Result<()> {
724 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
725 (user_id, contact_user_id, true)
726 } else {
727 (contact_user_id, user_id, false)
728 };
729
730 let query = "
731 UPDATE contacts
732 SET should_notify = 'f'
733 WHERE
734 user_id_a = $1 AND user_id_b = $2 AND
735 (
736 (a_to_b = $3 AND accepted) OR
737 (a_to_b != $3 AND NOT accepted)
738 );
739 ";
740
741 let result = sqlx::query(query)
742 .bind(id_a.0)
743 .bind(id_b.0)
744 .bind(a_to_b)
745 .execute(&self.pool)
746 .await?;
747
748 if result.rows_affected() == 0 {
749 Err(anyhow!("no such contact request"))?;
750 }
751
752 Ok(())
753 }
754
755 async fn respond_to_contact_request(
756 &self,
757 responder_id: UserId,
758 requester_id: UserId,
759 accept: bool,
760 ) -> Result<()> {
761 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
762 (responder_id, requester_id, false)
763 } else {
764 (requester_id, responder_id, true)
765 };
766 let result = if accept {
767 let query = "
768 UPDATE contacts
769 SET accepted = 't', should_notify = 't'
770 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
771 ";
772 sqlx::query(query)
773 .bind(id_a.0)
774 .bind(id_b.0)
775 .bind(a_to_b)
776 .execute(&self.pool)
777 .await?
778 } else {
779 let query = "
780 DELETE FROM contacts
781 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
782 ";
783 sqlx::query(query)
784 .bind(id_a.0)
785 .bind(id_b.0)
786 .bind(a_to_b)
787 .execute(&self.pool)
788 .await?
789 };
790 if result.rows_affected() == 1 {
791 Ok(())
792 } else {
793 Err(anyhow!("no such contact request"))?
794 }
795 }
796
797 // access tokens
798
799 async fn create_access_token_hash(
800 &self,
801 user_id: UserId,
802 access_token_hash: &str,
803 max_access_token_count: usize,
804 ) -> Result<()> {
805 let insert_query = "
806 INSERT INTO access_tokens (user_id, hash)
807 VALUES ($1, $2);
808 ";
809 let cleanup_query = "
810 DELETE FROM access_tokens
811 WHERE id IN (
812 SELECT id from access_tokens
813 WHERE user_id = $1
814 ORDER BY id DESC
815 OFFSET $3
816 )
817 ";
818
819 let mut tx = self.pool.begin().await?;
820 sqlx::query(insert_query)
821 .bind(user_id.0)
822 .bind(access_token_hash)
823 .execute(&mut tx)
824 .await?;
825 sqlx::query(cleanup_query)
826 .bind(user_id.0)
827 .bind(access_token_hash)
828 .bind(max_access_token_count as u32)
829 .execute(&mut tx)
830 .await?;
831 Ok(tx.commit().await?)
832 }
833
834 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
835 let query = "
836 SELECT hash
837 FROM access_tokens
838 WHERE user_id = $1
839 ORDER BY id DESC
840 ";
841 Ok(sqlx::query_scalar(query)
842 .bind(user_id.0)
843 .fetch_all(&self.pool)
844 .await?)
845 }
846
847 // orgs
848
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
    let org: Option<Org> = sqlx::query_as(
        "
        SELECT *
        FROM orgs
        WHERE slug = $1
        ",
    )
    .bind(slug)
    .fetch_optional(&self.pool)
    .await?;
    Ok(org)
}
862
#[cfg(any(test, feature = "seed-support"))]
async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
    let id: i32 = sqlx::query_scalar(
        "
        INSERT INTO orgs (name, slug)
        VALUES ($1, $2)
        RETURNING id
        ",
    )
    .bind(name)
    .bind(slug)
    .fetch_one(&self.pool)
    .await?;
    Ok(OrgId(id))
}
877
#[cfg(any(test, feature = "seed-support"))]
async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
    // Adding someone who is already a member is a silent no-op.
    sqlx::query(
        "
        INSERT INTO org_memberships (org_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
        ",
    )
    .bind(org_id.0)
    .bind(user_id.0)
    .bind(is_admin)
    .execute(&self.pool)
    .await?;
    Ok(())
}
893
894 // channels
895
#[cfg(any(test, feature = "seed-support"))]
async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
    // `owner_is_user = false` marks `owner_id` as an org id rather than a user id.
    let id: i32 = sqlx::query_scalar(
        "
        INSERT INTO channels (owner_id, owner_is_user, name)
        VALUES ($1, false, $2)
        RETURNING id
        ",
    )
    .bind(org_id.0)
    .bind(name)
    .fetch_one(&self.pool)
    .await?;
    Ok(ChannelId(id))
}
910
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
    let channels: Vec<Channel> = sqlx::query_as(
        "
        SELECT *
        FROM channels
        WHERE
            channels.owner_is_user = false AND
            channels.owner_id = $1
        ",
    )
    .bind(org_id.0)
    .fetch_all(&self.pool)
    .await?;
    Ok(channels)
}
926
927 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
928 let query = "
929 SELECT
930 channels.*
931 FROM
932 channel_memberships, channels
933 WHERE
934 channel_memberships.user_id = $1 AND
935 channel_memberships.channel_id = channels.id
936 ";
937 Ok(sqlx::query_as(query)
938 .bind(user_id.0)
939 .fetch_all(&self.pool)
940 .await?)
941 }
942
943 async fn can_user_access_channel(
944 &self,
945 user_id: UserId,
946 channel_id: ChannelId,
947 ) -> Result<bool> {
948 let query = "
949 SELECT id
950 FROM channel_memberships
951 WHERE user_id = $1 AND channel_id = $2
952 LIMIT 1
953 ";
954 Ok(sqlx::query_scalar::<_, i32>(query)
955 .bind(user_id.0)
956 .bind(channel_id.0)
957 .fetch_optional(&self.pool)
958 .await
959 .map(|e| e.is_some())?)
960 }
961
#[cfg(any(test, feature = "seed-support"))]
async fn add_channel_member(
    &self,
    channel_id: ChannelId,
    user_id: UserId,
    is_admin: bool,
) -> Result<()> {
    // Re-adding an existing member is a silent no-op.
    sqlx::query(
        "
        INSERT INTO channel_memberships (channel_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
        ",
    )
    .bind(channel_id.0)
    .bind(user_id.0)
    .bind(is_admin)
    .execute(&self.pool)
    .await?;
    Ok(())
}
982
983 // messages
984
985 async fn create_channel_message(
986 &self,
987 channel_id: ChannelId,
988 sender_id: UserId,
989 body: &str,
990 timestamp: OffsetDateTime,
991 nonce: u128,
992 ) -> Result<MessageId> {
993 let query = "
994 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
995 VALUES ($1, $2, $3, $4, $5)
996 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
997 RETURNING id
998 ";
999 Ok(sqlx::query_scalar(query)
1000 .bind(channel_id.0)
1001 .bind(sender_id.0)
1002 .bind(body)
1003 .bind(timestamp)
1004 .bind(Uuid::from_u128(nonce))
1005 .fetch_one(&self.pool)
1006 .await
1007 .map(MessageId)?)
1008 }
1009
1010 async fn get_channel_messages(
1011 &self,
1012 channel_id: ChannelId,
1013 count: usize,
1014 before_id: Option<MessageId>,
1015 ) -> Result<Vec<ChannelMessage>> {
1016 let query = r#"
1017 SELECT * FROM (
1018 SELECT
1019 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1020 FROM
1021 channel_messages
1022 WHERE
1023 channel_id = $1 AND
1024 id < $2
1025 ORDER BY id DESC
1026 LIMIT $3
1027 ) as recent_messages
1028 ORDER BY id ASC
1029 "#;
1030 Ok(sqlx::query_as(query)
1031 .bind(channel_id.0)
1032 .bind(before_id.unwrap_or(MessageId::MAX))
1033 .bind(count as i64)
1034 .fetch_all(&self.pool)
1035 .await?)
1036 }
1037
#[cfg(test)]
async fn teardown(&self, url: &str) {
    use util::ResultExt;

    // Terminate every other backend connected to this database so it can be dropped.
    // Each step is best-effort: failures are logged, not propagated.
    let terminate_other_sessions = "
        SELECT pg_terminate_backend(pg_stat_activity.pid)
        FROM pg_stat_activity
        WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
    ";
    sqlx::query(terminate_other_sessions)
        .execute(&self.pool)
        .await
        .log_err();

    self.pool.close().await;

    let dropped = <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url).await;
    dropped.log_err();
}
1053
/// The Postgres-backed database is never the test fake, so this always returns `None`.
#[cfg(test)]
fn as_fake(&self) -> Option<&tests::FakeDb> {
    None
}
1058}
1059
/// Declares a newtype id wrapping an `i32` database key.
///
/// The generated type is transparent to both sqlx and serde. It exposes `MAX`, converts
/// to and from the `u64` ids used by the protocol (`to_proto` / `from_proto`), and
/// displays as the bare integer.
macro_rules! id_type {
    ($id:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $id(pub i32);

        impl $id {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = $id(i32::MAX);

            /// Builds an id from a protocol value; `as` truncates values beyond `i32`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $id(value as i32)
            }

            /// Converts to a protocol value; `as` sign-extends negative ids.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $id {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1091
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    /// Whether the user has admin privileges (toggled by `set_user_is_admin`).
    pub admin: bool,
    /// Code the user can share to invite others; `None` until one has been assigned.
    pub invite_code: Option<String>,
    /// Invites the user can still hand out; decremented when their code is redeemed.
    pub invite_count: i32,
    /// Flag updated by `set_user_connected_once`.
    pub connected_once: bool,
}
1103
id_type!(ProjectId);
/// Presumably mirrors a row of the `projects` table (see `register_project`).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// User who registered the project.
    pub host_user_id: UserId,
    /// Set by `unregister_project`; the row is kept rather than deleted.
    pub unregistered: bool,
}
1111
/// One user's activity as produced by `summarize_project_activity`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// Time spent in each project, ordered from most to least time.
    pub project_activity: Vec<(ProjectId, Duration)>,
}
1118
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Lookup key used by `find_org_by_slug`.
    pub slug: String,
}
1126
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owning user or org, depending on `owner_is_user`.
    pub owner_id: i32,
    /// `true` if `owner_id` is a user id, `false` if it is an org id.
    pub owner_is_user: bool,
}
1135
id_type!(MessageId);
/// A channel message as returned by `get_channel_messages`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; the query converts it with `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Nonce given to `create_channel_message`; a repeated nonce returns the existing
    /// message instead of inserting a new one.
    pub nonce: Uuid,
}
1146
/// A contact relationship, seen from the perspective of the user whose contacts were
/// requested (see `get_contacts`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both users are contacts.
    Accepted {
        user_id: UserId,
        /// Whether the viewing user still has an unacknowledged notification about this contact.
        should_notify: bool,
    },
    /// The viewing user sent a request that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// The viewing user received a request that is still pending.
    Incoming {
        user_id: UserId,
        /// Whether the viewing user still has an unacknowledged notification about the request.
        should_notify: bool,
    },
}
1161
1162impl Contact {
1163 pub fn user_id(&self) -> UserId {
1164 match self {
1165 Contact::Accepted { user_id, .. } => *user_id,
1166 Contact::Outgoing { user_id } => *user_id,
1167 Contact::Incoming { user_id, .. } => *user_id,
1168 }
1169 }
1170}
1171
/// A contact request received by some user.
/// NOTE(review): not referenced by the visible code — confirm it is still used.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// User who sent the request.
    pub requester_id: UserId,
    /// Presumably whether the recipient should still be notified — confirm with callers.
    pub should_notify: bool,
}
1177
/// Builds a permissive `LIKE`/`ILIKE` pattern from `string`.
///
/// Only alphanumeric characters are kept, and each one is preceded by `%`, with a
/// trailing `%`. For example, `"a_b"` becomes `"%a%b%"`. Because `%` and `_` are never
/// copied from the input, user text cannot inject its own wildcards.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern: String = string
        .chars()
        .filter(|ch| ch.is_alphanumeric())
        .flat_map(|ch| ['%', ch])
        .collect();
    pattern.push('%');
    pattern
}
1189
1190#[cfg(test)]
1191pub mod tests {
1192 use super::*;
1193 use anyhow::anyhow;
1194 use collections::BTreeMap;
1195 use gpui::executor::Background;
1196 use lazy_static::lazy_static;
1197 use parking_lot::Mutex;
1198 use rand::prelude::*;
1199 use sqlx::{
1200 migrate::{MigrateDatabase, Migrator},
1201 Postgres,
1202 };
1203 use std::{path::Path, sync::Arc};
1204 use util::post_inc;
1205
1206 #[tokio::test(flavor = "multi_thread")]
1207 async fn test_get_users_by_ids() {
1208 for test_db in [
1209 TestDb::postgres().await,
1210 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1211 ] {
1212 let db = test_db.db();
1213
1214 let user = db.create_user("user", None, false).await.unwrap();
1215 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1216 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1217 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1218
1219 assert_eq!(
1220 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1221 .await
1222 .unwrap(),
1223 vec![
1224 User {
1225 id: user,
1226 github_login: "user".to_string(),
1227 admin: false,
1228 ..Default::default()
1229 },
1230 User {
1231 id: friend1,
1232 github_login: "friend-1".to_string(),
1233 admin: false,
1234 ..Default::default()
1235 },
1236 User {
1237 id: friend2,
1238 github_login: "friend-2".to_string(),
1239 admin: false,
1240 ..Default::default()
1241 },
1242 User {
1243 id: friend3,
1244 github_login: "friend-3".to_string(),
1245 admin: false,
1246 ..Default::default()
1247 }
1248 ]
1249 );
1250 }
1251 }
1252
    // Postgres-only: the fake backend's `create_users` is `unimplemented!()`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_users() {
        // The `TestDb` binding is shadowed, not dropped, so the database stays
        // alive (and is not torn down) until the end of the test.
        let db = TestDb::postgres().await;
        let db = db.db();

        // Create the first batch of users, ensuring invite counts are assigned
        // correctly and the respective invite codes are unique.
        let user_ids_batch_1 = db
            .create_users(vec![
                ("user1".to_string(), "hi@user1.com".to_string(), 5),
                ("user2".to_string(), "hi@user2.com".to_string(), 4),
                ("user3".to_string(), "hi@user3.com".to_string(), 3),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_1.len(), 3);

        // The lookup is asserted to return users in the order of the requested
        // ids, so `users[i]` corresponds to the i-th tuple above.
        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
        assert_eq!(users.len(), 3);
        assert_eq!(users[0].github_login, "user1");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
        assert_eq!(users[0].invite_count, 5);
        assert_eq!(users[1].github_login, "user2");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[1].invite_count, 4);
        assert_eq!(users[2].github_login, "user3");
        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
        assert_eq!(users[2].invite_count, 3);

        let invite_code_1 = users[0].invite_code.clone().unwrap();
        let invite_code_2 = users[1].invite_code.clone().unwrap();
        let invite_code_3 = users[2].invite_code.clone().unwrap();
        assert_ne!(invite_code_1, invite_code_2);
        assert_ne!(invite_code_1, invite_code_3);
        assert_ne!(invite_code_2, invite_code_3);

        // Create the second batch of users and include a user that is already in the database, ensuring
        // the invite count for the existing user is updated without changing their invite code.
        let user_ids_batch_2 = db
            .create_users(vec![
                ("user2".to_string(), "hi@user2.com".to_string(), 10),
                ("user4".to_string(), "hi@user4.com".to_string(), 2),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_2.len(), 2);
        // The re-created "user2" keeps the id it was given in the first batch.
        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);

        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0].github_login, "user2");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[0].invite_count, 10);
        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
        assert_eq!(users[1].github_login, "user4");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
        assert_eq!(users[1].invite_count, 2);

        // The newly created user gets a code distinct from all earlier ones.
        let invite_code_4 = users[1].invite_code.clone().unwrap();
        assert_ne!(invite_code_4, invite_code_1);
        assert_ne!(invite_code_4, invite_code_2);
        assert_ne!(invite_code_4, invite_code_3);
    }
1316
    // Postgres-only: the fake backend's project-activity methods are
    // `unimplemented!()`.
    //
    // Timeline (each `record_project_activity` call covers one interval):
    //   t0..t1, t1..t2: user 2 in project 2
    //   t2..t3:         users 2 and 1 in project 2
    //   t3..t4:         as above, plus user 1 in project 1
    //   t4..t5:         as above, plus user 3 in project 1
    //   t5..t6 (5s):    users 1 and 3 in project 1
    // All intervals are 10s except the last one, which is 5s.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_project_activity() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();

        let user_1 = db.create_user("user_1", None, false).await.unwrap();
        let user_2 = db.create_user("user_2", None, false).await.unwrap();
        let user_3 = db.create_user("user_3", None, false).await.unwrap();
        let project_1 = db.register_project(user_1).await.unwrap();
        let project_2 = db.register_project(user_2).await.unwrap();
        // Start an hour in the past so the whole timeline lies before "now".
        let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);

        // User 2 opens a project
        let t1 = t0 + Duration::from_secs(10);
        db.record_project_activity(t0..t1, &[(user_2, project_2)])
            .await
            .unwrap();

        let t2 = t1 + Duration::from_secs(10);
        db.record_project_activity(t1..t2, &[(user_2, project_2)])
            .await
            .unwrap();

        // User 1 joins the project
        let t3 = t2 + Duration::from_secs(10);
        db.record_project_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
            .await
            .unwrap();

        // User 1 opens another project
        let t4 = t3 + Duration::from_secs(10);
        db.record_project_activity(
            t3..t4,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
            ],
        )
        .await
        .unwrap();

        // User 3 joins that project
        let t5 = t4 + Duration::from_secs(10);
        db.record_project_activity(
            t4..t5,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
                (user_3, project_1),
            ],
        )
        .await
        .unwrap();

        // User 2 leaves
        let t6 = t5 + Duration::from_secs(5);
        db.record_project_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
            .await
            .unwrap();

        // Expected totals: user 1 spent t2..t5 (30s) in project 2 and t3..t6
        // (25s) in project 1; user 2 spent t0..t5 (50s) in project 2; user 3
        // spent t4..t6 (15s) in project 1.
        // NOTE(review): the expected order looks like users by id, then each
        // user's projects by descending duration — confirm against the query.
        let summary = db.summarize_project_activity(t0..t6, 10).await.unwrap();
        assert_eq!(
            summary,
            &[
                UserActivitySummary {
                    id: user_1,
                    github_login: "user_1".to_string(),
                    project_activity: vec![
                        (project_2, Duration::from_secs(30)),
                        (project_1, Duration::from_secs(25))
                    ]
                },
                UserActivitySummary {
                    id: user_2,
                    github_login: "user_2".to_string(),
                    project_activity: vec![(project_2, Duration::from_secs(50))]
                },
                UserActivitySummary {
                    id: user_3,
                    github_login: "user_3".to_string(),
                    project_activity: vec![(project_1, Duration::from_secs(15))]
                },
            ]
        );
    }
1404
1405 #[tokio::test(flavor = "multi_thread")]
1406 async fn test_recent_channel_messages() {
1407 for test_db in [
1408 TestDb::postgres().await,
1409 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1410 ] {
1411 let db = test_db.db();
1412 let user = db.create_user("user", None, false).await.unwrap();
1413 let org = db.create_org("org", "org").await.unwrap();
1414 let channel = db.create_org_channel(org, "channel").await.unwrap();
1415 for i in 0..10 {
1416 db.create_channel_message(
1417 channel,
1418 user,
1419 &i.to_string(),
1420 OffsetDateTime::now_utc(),
1421 i,
1422 )
1423 .await
1424 .unwrap();
1425 }
1426
1427 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1428 assert_eq!(
1429 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1430 ["5", "6", "7", "8", "9"]
1431 );
1432
1433 let prev_messages = db
1434 .get_channel_messages(channel, 4, Some(messages[0].id))
1435 .await
1436 .unwrap();
1437 assert_eq!(
1438 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1439 ["1", "2", "3", "4"]
1440 );
1441 }
1442 }
1443
1444 #[tokio::test(flavor = "multi_thread")]
1445 async fn test_channel_message_nonces() {
1446 for test_db in [
1447 TestDb::postgres().await,
1448 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1449 ] {
1450 let db = test_db.db();
1451 let user = db.create_user("user", None, false).await.unwrap();
1452 let org = db.create_org("org", "org").await.unwrap();
1453 let channel = db.create_org_channel(org, "channel").await.unwrap();
1454
1455 let msg1_id = db
1456 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1457 .await
1458 .unwrap();
1459 let msg2_id = db
1460 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1461 .await
1462 .unwrap();
1463 let msg3_id = db
1464 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1465 .await
1466 .unwrap();
1467 let msg4_id = db
1468 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1469 .await
1470 .unwrap();
1471
1472 assert_ne!(msg1_id, msg2_id);
1473 assert_eq!(msg1_id, msg3_id);
1474 assert_eq!(msg2_id, msg4_id);
1475 }
1476 }
1477
1478 #[tokio::test(flavor = "multi_thread")]
1479 async fn test_create_access_tokens() {
1480 let test_db = TestDb::postgres().await;
1481 let db = test_db.db();
1482 let user = db.create_user("the-user", None, false).await.unwrap();
1483
1484 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1485 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1486 assert_eq!(
1487 db.get_access_token_hashes(user).await.unwrap(),
1488 &["h2".to_string(), "h1".to_string()]
1489 );
1490
1491 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1492 assert_eq!(
1493 db.get_access_token_hashes(user).await.unwrap(),
1494 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1495 );
1496
1497 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1498 assert_eq!(
1499 db.get_access_token_hashes(user).await.unwrap(),
1500 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1501 );
1502
1503 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1504 assert_eq!(
1505 db.get_access_token_hashes(user).await.unwrap(),
1506 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1507 );
1508 }
1509
1510 #[test]
1511 fn test_fuzzy_like_string() {
1512 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1513 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1514 assert_eq!(fuzzy_like_string(" z "), "%z%");
1515 }
1516
1517 #[tokio::test(flavor = "multi_thread")]
1518 async fn test_fuzzy_search_users() {
1519 let test_db = TestDb::postgres().await;
1520 let db = test_db.db();
1521 for github_login in [
1522 "California",
1523 "colorado",
1524 "oregon",
1525 "washington",
1526 "florida",
1527 "delaware",
1528 "rhode-island",
1529 ] {
1530 db.create_user(github_login, None, false).await.unwrap();
1531 }
1532
1533 assert_eq!(
1534 fuzzy_search_user_names(db, "clr").await,
1535 &["colorado", "California"]
1536 );
1537 assert_eq!(
1538 fuzzy_search_user_names(db, "ro").await,
1539 &["rhode-island", "colorado", "oregon"],
1540 );
1541
1542 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1543 db.fuzzy_search_users(query, 10)
1544 .await
1545 .unwrap()
1546 .into_iter()
1547 .map(|user| user.github_login)
1548 .collect::<Vec<_>>()
1549 }
1550 }
1551
    // Walks through the contact-request lifecycle (request, dismiss, accept,
    // concurrent requests, decline) against both the Postgres and fake backends.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // User starts with no contacts other than themselves: every user is
            // listed as an accepted contact of their own.
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again. (The requester cannot dismiss it.)
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact; only the
            // requester is flagged to be notified of the acceptance.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted, with no pending notification.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
1771
    // Postgres-only: the fake backend's invite-code methods are stubs.
    // Covers code creation, redemption (which makes the redeemer an accepted
    // contact of the inviter), exhaustion, and topping the count back up.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes, and a failed redemption
        // does not consume an invite.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
1924
    /// A database handle for tests, backed either by a freshly created Postgres
    /// database (`TestDb::postgres`) or by an in-memory `FakeDb` (`TestDb::fake`).
    /// The database is torn down when the `TestDb` is dropped.
    pub struct TestDb {
        // Always `Some` until `Drop` takes it out for teardown.
        pub db: Option<Arc<dyn Db>>,
        // Connection URL of the Postgres database; empty for the fake backend.
        pub url: String,
    }
1929
1930 impl TestDb {
1931 pub async fn postgres() -> Self {
1932 lazy_static! {
1933 static ref LOCK: Mutex<()> = Mutex::new(());
1934 }
1935
1936 let _guard = LOCK.lock();
1937 let mut rng = StdRng::from_entropy();
1938 let name = format!("zed-test-{}", rng.gen::<u128>());
1939 let url = format!("postgres://postgres@localhost/{}", name);
1940 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
1941 Postgres::create_database(&url)
1942 .await
1943 .expect("failed to create test db");
1944 let db = PostgresDb::new(&url, 5).await.unwrap();
1945 let migrator = Migrator::new(migrations_path).await.unwrap();
1946 migrator.run(&db.pool).await.unwrap();
1947 Self {
1948 db: Some(Arc::new(db)),
1949 url,
1950 }
1951 }
1952
1953 pub fn fake(background: Arc<Background>) -> Self {
1954 Self {
1955 db: Some(Arc::new(FakeDb::new(background))),
1956 url: Default::default(),
1957 }
1958 }
1959
1960 pub fn db(&self) -> &Arc<dyn Db> {
1961 self.db.as_ref().unwrap()
1962 }
1963 }
1964
    impl Drop for TestDb {
        // Tears the database down when the test is done with it. `teardown` is
        // async, so it is driven to completion synchronously with `block_on`.
        // NOTE(review): for the fake backend `url` is empty — presumably its
        // `teardown` ignores the URL; confirm.
        fn drop(&mut self) {
            if let Some(db) = self.db.take() {
                futures::executor::block_on(db.teardown(&self.url));
            }
        }
    }
1972
    /// In-memory implementation of `Db` for tests. Every table is a
    /// mutex-protected map keyed by id, and ids come from per-table counters.
    pub struct FakeDb {
        // Used to inject random delays into async methods.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        // Keyed by (project, worktree id, extension); the value is the stored count.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), usize>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // Keyed by (org, user); the value is the `is_admin` flag.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        // Keyed by (channel, user); presumably the value is an `is_admin` flag — confirm.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        // One entry per contact relationship (pending or accepted).
        pub contacts: Mutex<Vec<FakeContact>>,
        // Next id to hand out for each table; all start at 1 in `FakeDb::new`.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
1990
    /// A contact relationship stored by `FakeDb`.
    #[derive(Debug)]
    pub struct FakeContact {
        /// The user who sent the contact request.
        pub requester_id: UserId,
        /// The user the request was sent to.
        pub responder_id: UserId,
        /// Whether the request has been accepted.
        pub accepted: bool,
        /// While pending, whether the responder should be notified of the
        /// request. Once accepted, whether the requester should be notified of
        /// the acceptance.
        pub should_notify: bool,
    }
1998
1999 impl FakeDb {
2000 pub fn new(background: Arc<Background>) -> Self {
2001 Self {
2002 background,
2003 users: Default::default(),
2004 next_user_id: Mutex::new(1),
2005 projects: Default::default(),
2006 worktree_extensions: Default::default(),
2007 next_project_id: Mutex::new(1),
2008 orgs: Default::default(),
2009 next_org_id: Mutex::new(1),
2010 org_memberships: Default::default(),
2011 channels: Default::default(),
2012 next_channel_id: Mutex::new(1),
2013 channel_memberships: Default::default(),
2014 channel_messages: Default::default(),
2015 next_channel_message_id: Mutex::new(1),
2016 contacts: Default::default(),
2017 }
2018 }
2019 }
2020
2021 #[async_trait]
2022 impl Db for FakeDb {
2023 async fn create_user(
2024 &self,
2025 github_login: &str,
2026 email_address: Option<&str>,
2027 admin: bool,
2028 ) -> Result<UserId> {
2029 self.background.simulate_random_delay().await;
2030
2031 let mut users = self.users.lock();
2032 if let Some(user) = users
2033 .values()
2034 .find(|user| user.github_login == github_login)
2035 {
2036 Ok(user.id)
2037 } else {
2038 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2039 users.insert(
2040 user_id,
2041 User {
2042 id: user_id,
2043 github_login: github_login.to_string(),
2044 email_address: email_address.map(str::to_string),
2045 admin,
2046 invite_code: None,
2047 invite_count: 0,
2048 connected_once: false,
2049 },
2050 );
2051 Ok(user_id)
2052 }
2053 }
2054
        // Not supported by the fake; calling it panics.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }

        // Not supported by the fake; calling it panics. `test_create_users`
        // exercises this method against Postgres only.
        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
            unimplemented!()
        }

        // Not supported by the fake; calling it panics. `test_fuzzy_search_users`
        // exercises this method against Postgres only.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2066
2067 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2068 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2069 }
2070
2071 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2072 self.background.simulate_random_delay().await;
2073 let users = self.users.lock();
2074 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2075 }
2076
2077 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2078 Ok(self
2079 .users
2080 .lock()
2081 .values()
2082 .find(|user| user.github_login == github_login)
2083 .cloned())
2084 }
2085
        // Not supported by the fake; calling it panics.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
2089
2090 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2091 self.background.simulate_random_delay().await;
2092 let mut users = self.users.lock();
2093 let mut user = users
2094 .get_mut(&id)
2095 .ok_or_else(|| anyhow!("user not found"))?;
2096 user.connected_once = connected_once;
2097 Ok(())
2098 }
2099
        // Not supported by the fake; calling it panics.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        // invite codes
        //
        // The fake does not model invite codes. Every method except
        // `get_invite_code_for_user` panics.

        // Not supported by the fake; calling it panics.
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }

        // Always reports that the user has no invite code.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            Ok(None)
        }

        // Not supported by the fake; calling it panics.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }

        // Not supported by the fake; calling it panics.
        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
2126
2127 // projects
2128
2129 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2130 self.background.simulate_random_delay().await;
2131 if !self.users.lock().contains_key(&host_user_id) {
2132 Err(anyhow!("no such user"))?;
2133 }
2134
2135 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2136 self.projects.lock().insert(
2137 project_id,
2138 Project {
2139 id: project_id,
2140 host_user_id,
2141 unregistered: false,
2142 },
2143 );
2144 Ok(project_id)
2145 }
2146
2147 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2148 self.projects
2149 .lock()
2150 .get_mut(&project_id)
2151 .ok_or_else(|| anyhow!("no such project"))?
2152 .unregistered = true;
2153 Ok(())
2154 }
2155
2156 async fn update_worktree_extensions(
2157 &self,
2158 project_id: ProjectId,
2159 worktree_id: u64,
2160 extensions: HashMap<String, usize>,
2161 ) -> Result<()> {
2162 self.background.simulate_random_delay().await;
2163 if !self.projects.lock().contains_key(&project_id) {
2164 Err(anyhow!("no such project"))?;
2165 }
2166
2167 for (extension, count) in extensions {
2168 self.worktree_extensions
2169 .lock()
2170 .insert((project_id, worktree_id, extension), count);
2171 }
2172
2173 Ok(())
2174 }
2175
        // Not supported by the fake; calling it panics. `test_project_activity`
        // exercises project activity against Postgres only.
        async fn record_project_activity(
            &self,
            _period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }

        // Not supported by the fake; calling it panics.
        async fn summarize_project_activity(
            &self,
            _period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }
2191
2192 // contacts
2193
2194 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2195 self.background.simulate_random_delay().await;
2196 let mut contacts = vec![Contact::Accepted {
2197 user_id: id,
2198 should_notify: false,
2199 }];
2200
2201 for contact in self.contacts.lock().iter() {
2202 if contact.requester_id == id {
2203 if contact.accepted {
2204 contacts.push(Contact::Accepted {
2205 user_id: contact.responder_id,
2206 should_notify: contact.should_notify,
2207 });
2208 } else {
2209 contacts.push(Contact::Outgoing {
2210 user_id: contact.responder_id,
2211 });
2212 }
2213 } else if contact.responder_id == id {
2214 if contact.accepted {
2215 contacts.push(Contact::Accepted {
2216 user_id: contact.requester_id,
2217 should_notify: false,
2218 });
2219 } else {
2220 contacts.push(Contact::Incoming {
2221 user_id: contact.requester_id,
2222 should_notify: contact.should_notify,
2223 });
2224 }
2225 }
2226 }
2227
2228 contacts.sort_unstable_by_key(|contact| contact.user_id());
2229 Ok(contacts)
2230 }
2231
2232 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2233 self.background.simulate_random_delay().await;
2234 Ok(self.contacts.lock().iter().any(|contact| {
2235 contact.accepted
2236 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2237 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2238 }))
2239 }
2240
2241 async fn send_contact_request(
2242 &self,
2243 requester_id: UserId,
2244 responder_id: UserId,
2245 ) -> Result<()> {
2246 let mut contacts = self.contacts.lock();
2247 for contact in contacts.iter_mut() {
2248 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2249 if contact.accepted {
2250 Err(anyhow!("contact already exists"))?;
2251 } else {
2252 Err(anyhow!("contact already requested"))?;
2253 }
2254 }
2255 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2256 if contact.accepted {
2257 Err(anyhow!("contact already exists"))?;
2258 } else {
2259 contact.accepted = true;
2260 contact.should_notify = false;
2261 return Ok(());
2262 }
2263 }
2264 }
2265 contacts.push(FakeContact {
2266 requester_id,
2267 responder_id,
2268 accepted: false,
2269 should_notify: true,
2270 });
2271 Ok(())
2272 }
2273
2274 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2275 self.contacts.lock().retain(|contact| {
2276 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2277 });
2278 Ok(())
2279 }
2280
2281 async fn dismiss_contact_notification(
2282 &self,
2283 user_id: UserId,
2284 contact_user_id: UserId,
2285 ) -> Result<()> {
2286 let mut contacts = self.contacts.lock();
2287 for contact in contacts.iter_mut() {
2288 if contact.requester_id == contact_user_id
2289 && contact.responder_id == user_id
2290 && !contact.accepted
2291 {
2292 contact.should_notify = false;
2293 return Ok(());
2294 }
2295 if contact.requester_id == user_id
2296 && contact.responder_id == contact_user_id
2297 && contact.accepted
2298 {
2299 contact.should_notify = false;
2300 return Ok(());
2301 }
2302 }
2303 Err(anyhow!("no such notification"))?
2304 }
2305
2306 async fn respond_to_contact_request(
2307 &self,
2308 responder_id: UserId,
2309 requester_id: UserId,
2310 accept: bool,
2311 ) -> Result<()> {
2312 let mut contacts = self.contacts.lock();
2313 for (ix, contact) in contacts.iter_mut().enumerate() {
2314 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2315 if contact.accepted {
2316 Err(anyhow!("contact already confirmed"))?;
2317 }
2318 if accept {
2319 contact.accepted = true;
2320 contact.should_notify = true;
2321 } else {
2322 contacts.remove(ix);
2323 }
2324 return Ok(());
2325 }
2326 }
2327 Err(anyhow!("no such contact request"))?
2328 }
2329
        // Not supported by the fake; calling it panics. `test_create_access_tokens`
        // exercises access tokens against Postgres only.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }

        // Not supported by the fake; calling it panics.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }

        // Not supported by the fake; calling it panics.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
2346
2347 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2348 self.background.simulate_random_delay().await;
2349 let mut orgs = self.orgs.lock();
2350 if orgs.values().any(|org| org.slug == slug) {
2351 Err(anyhow!("org already exists"))?
2352 } else {
2353 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2354 orgs.insert(
2355 org_id,
2356 Org {
2357 id: org_id,
2358 name: name.to_string(),
2359 slug: slug.to_string(),
2360 },
2361 );
2362 Ok(org_id)
2363 }
2364 }
2365
2366 async fn add_org_member(
2367 &self,
2368 org_id: OrgId,
2369 user_id: UserId,
2370 is_admin: bool,
2371 ) -> Result<()> {
2372 self.background.simulate_random_delay().await;
2373 if !self.orgs.lock().contains_key(&org_id) {
2374 Err(anyhow!("org does not exist"))?;
2375 }
2376 if !self.users.lock().contains_key(&user_id) {
2377 Err(anyhow!("user does not exist"))?;
2378 }
2379
2380 self.org_memberships
2381 .lock()
2382 .entry((org_id, user_id))
2383 .or_insert(is_admin);
2384 Ok(())
2385 }
2386
2387 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2388 self.background.simulate_random_delay().await;
2389 if !self.orgs.lock().contains_key(&org_id) {
2390 Err(anyhow!("org does not exist"))?;
2391 }
2392
2393 let mut channels = self.channels.lock();
2394 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2395 channels.insert(
2396 channel_id,
2397 Channel {
2398 id: channel_id,
2399 name: name.to_string(),
2400 owner_id: org_id.0,
2401 owner_is_user: false,
2402 },
2403 );
2404 Ok(channel_id)
2405 }
2406
2407 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2408 self.background.simulate_random_delay().await;
2409 Ok(self
2410 .channels
2411 .lock()
2412 .values()
2413 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2414 .cloned()
2415 .collect())
2416 }
2417
2418 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2419 self.background.simulate_random_delay().await;
2420 let channels = self.channels.lock();
2421 let memberships = self.channel_memberships.lock();
2422 Ok(channels
2423 .values()
2424 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2425 .cloned()
2426 .collect())
2427 }
2428
2429 async fn can_user_access_channel(
2430 &self,
2431 user_id: UserId,
2432 channel_id: ChannelId,
2433 ) -> Result<bool> {
2434 self.background.simulate_random_delay().await;
2435 Ok(self
2436 .channel_memberships
2437 .lock()
2438 .contains_key(&(channel_id, user_id)))
2439 }
2440
2441 async fn add_channel_member(
2442 &self,
2443 channel_id: ChannelId,
2444 user_id: UserId,
2445 is_admin: bool,
2446 ) -> Result<()> {
2447 self.background.simulate_random_delay().await;
2448 if !self.channels.lock().contains_key(&channel_id) {
2449 Err(anyhow!("channel does not exist"))?;
2450 }
2451 if !self.users.lock().contains_key(&user_id) {
2452 Err(anyhow!("user does not exist"))?;
2453 }
2454
2455 self.channel_memberships
2456 .lock()
2457 .entry((channel_id, user_id))
2458 .or_insert(is_admin);
2459 Ok(())
2460 }
2461
2462 async fn create_channel_message(
2463 &self,
2464 channel_id: ChannelId,
2465 sender_id: UserId,
2466 body: &str,
2467 timestamp: OffsetDateTime,
2468 nonce: u128,
2469 ) -> Result<MessageId> {
2470 self.background.simulate_random_delay().await;
2471 if !self.channels.lock().contains_key(&channel_id) {
2472 Err(anyhow!("channel does not exist"))?;
2473 }
2474 if !self.users.lock().contains_key(&sender_id) {
2475 Err(anyhow!("user does not exist"))?;
2476 }
2477
2478 let mut messages = self.channel_messages.lock();
2479 if let Some(message) = messages
2480 .values()
2481 .find(|message| message.nonce.as_u128() == nonce)
2482 {
2483 Ok(message.id)
2484 } else {
2485 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2486 messages.insert(
2487 message_id,
2488 ChannelMessage {
2489 id: message_id,
2490 channel_id,
2491 sender_id,
2492 body: body.to_string(),
2493 sent_at: timestamp,
2494 nonce: Uuid::from_u128(nonce),
2495 },
2496 );
2497 Ok(message_id)
2498 }
2499 }
2500
2501 async fn get_channel_messages(
2502 &self,
2503 channel_id: ChannelId,
2504 count: usize,
2505 before_id: Option<MessageId>,
2506 ) -> Result<Vec<ChannelMessage>> {
2507 let mut messages = self
2508 .channel_messages
2509 .lock()
2510 .values()
2511 .rev()
2512 .filter(|message| {
2513 message.channel_id == channel_id
2514 && message.id < before_id.unwrap_or(MessageId::MAX)
2515 })
2516 .take(count)
2517 .cloned()
2518 .collect::<Vec<_>>();
2519 messages.sort_unstable_by_key(|message| message.id);
2520 Ok(messages)
2521 }
2522
2523 async fn teardown(&self, _: &str) {}
2524
/// Exposes this database as a `FakeDb`; always `Some`, since `self` is one.
#[cfg(test)]
fn as_fake(&self) -> Option<&FakeDb> {
    let fake: &FakeDb = self;
    Some(fake)
}
2529 }
2530}