1use std::{cmp, ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use serde::{Deserialize, Serialize};
10pub use sqlx::postgres::PgPoolOptions as DbOptions;
11use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
12use time::{OffsetDateTime, PrimitiveDateTime};
13
/// Storage operations used by the collab server.
///
/// [`PostgresDb`] implements this trait on top of a Postgres pool. Under
/// `cfg(test)`, `as_fake` exposes an optional `tests::FakeDb`.
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user with the given GitHub login and returns its id.
    async fn create_user(
        &self,
        github_login: &str,
        email_address: Option<&str>,
        admin: bool,
    ) -> Result<UserId>;
    /// Returns page `page` of users, `limit` users per page.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Creates users from `(github_login, email_address, invite_count)` tuples.
    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
    /// Searches users whose GitHub login loosely matches `query`, returning at most `limit`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Looks up a single user by id.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Looks up all users whose ids are in `ids`.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns users with no invites left, filtered by whether another user invited them.
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    /// Looks up a user by exact GitHub login.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets or clears a user's admin flag.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets or clears a user's `connected_once` flag.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes a user (and, in `PostgresDb`, their access tokens).
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Stores a new signup.
    async fn create_signup(&self, signup: Signup) -> Result<()>;
    /// Returns up to `count` signups that still need a confirmation email.
    async fn get_signup_invites(&self, count: usize) -> Result<Vec<SignupInvite>>;
    /// Marks the given signups' confirmation emails as sent.
    async fn record_signup_invites_sent(&self, signups: &[SignupInvite]) -> Result<()>;
    /// Converts a confirmed signup into a user and returns the new user's id.
    async fn redeem_signup(&self, redemption: SignupRedemption) -> Result<UserId>;

    /// Sets how many invites a user has left.
    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, if they have a code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user who owns the given invite code.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Spends one invite from `code` to create a new user and returns that user's id.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, u32>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_user_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the number of users who have been active in the given
    /// time period for at least the given time duration.
    async fn get_active_user_count(
        &self,
        time_period: Range<OffsetDateTime>,
        min_duration: Duration,
        only_collaborative: bool,
    ) -> Result<usize>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn get_top_users_activity_summary(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Get the project activity for the given user and time period.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>>;

    /// Returns all contacts and pending contact requests involving `id`.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users are accepted contacts.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Sends a contact request from `requester_id` to `responder_id`.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Removes the contact relationship between the two users.
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Clears the pending notification between the two users.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    /// Accepts (`accept == true`) or declines a pending contact request.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    /// Stores an access token hash, keeping at most `max_access_token_count` per user.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the user's stored access token hashes.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    #[cfg(any(test, feature = "seed-support"))]
    /// Looks up an organization by slug.
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    #[cfg(any(test, feature = "seed-support"))]
    // NOTE(review): the attribute above gates `get_org_channels`. A blank line
    // does not detach an attribute from the next item. Confirm that this
    // method is meant to exist only in test/seed-support builds.
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns the channels `user_id` may access.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether `user_id` may access `channel_id`.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a chat message and returns its id.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` messages from a channel, optionally only those before `before_id`.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    #[cfg(test)]
    async fn teardown(&self, url: &str);
    #[cfg(test)]
    fn as_fake(&self) -> Option<&tests::FakeDb>;
}
164
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    /// sqlx connection pool used by every query in this implementation.
    pool: sqlx::PgPool,
}
168
169impl PostgresDb {
170 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
171 let pool = DbOptions::new()
172 .max_connections(max_connections)
173 .connect(url)
174 .await
175 .context("failed to connect to postgres database")?;
176 Ok(Self { pool })
177 }
178}
179
180#[async_trait]
181impl Db for PostgresDb {
182 // users
183
184 async fn create_user(
185 &self,
186 github_login: &str,
187 email_address: Option<&str>,
188 admin: bool,
189 ) -> Result<UserId> {
190 let query = "
191 INSERT INTO users (github_login, email_address, admin)
192 VALUES ($1, $2, $3)
193 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
194 RETURNING id
195 ";
196 Ok(sqlx::query_scalar(query)
197 .bind(github_login)
198 .bind(email_address)
199 .bind(admin)
200 .fetch_one(&self.pool)
201 .await
202 .map(UserId)?)
203 }
204
205 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
206 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
207 Ok(sqlx::query_as(query)
208 .bind(limit as i32)
209 .bind((page * limit) as i32)
210 .fetch_all(&self.pool)
211 .await?)
212 }
213
214 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
215 let mut query = QueryBuilder::new(
216 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
217 );
218 query.push_values(
219 users,
220 |mut query, (github_login, email_address, invite_count)| {
221 query
222 .push_bind(github_login)
223 .push_bind(email_address)
224 .push_bind(false)
225 .push_bind(random_invite_code())
226 .push_bind(invite_count as i32);
227 },
228 );
229 query.push(
230 "
231 ON CONFLICT (github_login) DO UPDATE SET
232 github_login = excluded.github_login,
233 invite_count = excluded.invite_count,
234 invite_code = CASE WHEN users.invite_code IS NULL
235 THEN excluded.invite_code
236 ELSE users.invite_code
237 END
238 RETURNING id
239 ",
240 );
241
242 let rows = query.build().fetch_all(&self.pool).await?;
243 Ok(rows
244 .into_iter()
245 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
246 .collect())
247 }
248
249 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
250 let like_string = fuzzy_like_string(name_query);
251 let query = "
252 SELECT users.*
253 FROM users
254 WHERE github_login ILIKE $1
255 ORDER BY github_login <-> $2
256 LIMIT $3
257 ";
258 Ok(sqlx::query_as(query)
259 .bind(like_string)
260 .bind(name_query)
261 .bind(limit as i32)
262 .fetch_all(&self.pool)
263 .await?)
264 }
265
266 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
267 let users = self.get_users_by_ids(vec![id]).await?;
268 Ok(users.into_iter().next())
269 }
270
271 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
272 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
273 let query = "
274 SELECT users.*
275 FROM users
276 WHERE users.id = ANY ($1)
277 ";
278 Ok(sqlx::query_as(query)
279 .bind(&ids)
280 .fetch_all(&self.pool)
281 .await?)
282 }
283
284 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
285 let query = format!(
286 "
287 SELECT users.*
288 FROM users
289 WHERE invite_count = 0
290 AND inviter_id IS{} NULL
291 ",
292 if invited_by_another_user { " NOT" } else { "" }
293 );
294
295 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
296 }
297
298 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
299 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
300 Ok(sqlx::query_as(query)
301 .bind(github_login)
302 .fetch_optional(&self.pool)
303 .await?)
304 }
305
306 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
307 let query = "UPDATE users SET admin = $1 WHERE id = $2";
308 Ok(sqlx::query(query)
309 .bind(is_admin)
310 .bind(id.0)
311 .execute(&self.pool)
312 .await
313 .map(drop)?)
314 }
315
316 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
317 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
318 Ok(sqlx::query(query)
319 .bind(connected_once)
320 .bind(id.0)
321 .execute(&self.pool)
322 .await
323 .map(drop)?)
324 }
325
326 async fn destroy_user(&self, id: UserId) -> Result<()> {
327 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
328 sqlx::query(query)
329 .bind(id.0)
330 .execute(&self.pool)
331 .await
332 .map(drop)?;
333 let query = "DELETE FROM users WHERE id = $1;";
334 Ok(sqlx::query(query)
335 .bind(id.0)
336 .execute(&self.pool)
337 .await
338 .map(drop)?)
339 }
340
341 // signups
342
343 async fn create_signup(&self, signup: Signup) -> Result<()> {
344 sqlx::query(
345 "
346 INSERT INTO signups
347 (
348 email_address,
349 email_confirmation_code,
350 email_confirmation_sent,
351 platform_linux,
352 platform_mac,
353 platform_windows,
354 platform_unknown,
355 editor_features,
356 programming_languages
357 )
358 VALUES
359 ($1, $2, 'f', $3, $4, $5, 'f', $6, $7)
360 ",
361 )
362 .bind(&signup.email_address)
363 .bind(&random_email_confirmation_code())
364 .bind(&signup.platform_linux)
365 .bind(&signup.platform_mac)
366 .bind(&signup.platform_windows)
367 .bind(&signup.editor_features)
368 .bind(&signup.programming_languages)
369 .execute(&self.pool)
370 .await?;
371 Ok(())
372 }
373
374 async fn get_signup_invites(&self, count: usize) -> Result<Vec<SignupInvite>> {
375 Ok(sqlx::query_as(
376 "
377 SELECT
378 email_address, email_confirmation_code
379 FROM signups
380 WHERE
381 NOT email_confirmation_sent AND
382 platform_mac
383 LIMIT $1
384 ",
385 )
386 .bind(count as i32)
387 .fetch_all(&self.pool)
388 .await?)
389 }
390
391 async fn record_signup_invites_sent(&self, signups: &[SignupInvite]) -> Result<()> {
392 sqlx::query(
393 "
394 UPDATE signups
395 SET email_confirmation_sent = 't'
396 WHERE email_address = ANY ($1)
397 ",
398 )
399 .bind(
400 &signups
401 .iter()
402 .map(|s| s.email_address.as_str())
403 .collect::<Vec<_>>(),
404 )
405 .execute(&self.pool)
406 .await?;
407 Ok(())
408 }
409
410 async fn redeem_signup(&self, redemption: SignupRedemption) -> Result<UserId> {
411 let mut tx = self.pool.begin().await?;
412 let signup_id: i32 = sqlx::query_scalar(
413 "
414 SELECT id
415 FROM signups
416 WHERE
417 email_address = $1 AND
418 email_confirmation_code = $2 AND
419 email_confirmation_sent AND
420 user_id is NULL
421 ",
422 )
423 .bind(&redemption.email_address)
424 .bind(&redemption.email_confirmation_code)
425 .fetch_one(&mut tx)
426 .await?;
427
428 let user_id: i32 = sqlx::query_scalar(
429 "
430 INSERT INTO users
431 (email_address, github_login, admin, invite_count, invite_code)
432 VALUES
433 ($1, $2, 'f', $3, $4)
434 RETURNING id
435 ",
436 )
437 .bind(&redemption.email_address)
438 .bind(&redemption.github_login)
439 .bind(&redemption.invite_count)
440 .bind(random_invite_code())
441 .fetch_one(&mut tx)
442 .await?;
443
444 sqlx::query(
445 "
446 UPDATE signups
447 SET user_id = $1
448 WHERE id = $2
449 ",
450 )
451 .bind(&user_id)
452 .bind(&signup_id)
453 .execute(&mut tx)
454 .await?;
455
456 tx.commit().await?;
457 Ok(UserId(user_id))
458 }
459
460 // invite codes
461
462 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
463 let mut tx = self.pool.begin().await?;
464 if count > 0 {
465 sqlx::query(
466 "
467 UPDATE users
468 SET invite_code = $1
469 WHERE id = $2 AND invite_code IS NULL
470 ",
471 )
472 .bind(random_invite_code())
473 .bind(id)
474 .execute(&mut tx)
475 .await?;
476 }
477
478 sqlx::query(
479 "
480 UPDATE users
481 SET invite_count = $1
482 WHERE id = $2
483 ",
484 )
485 .bind(count as i32)
486 .bind(id)
487 .execute(&mut tx)
488 .await?;
489 tx.commit().await?;
490 Ok(())
491 }
492
493 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
494 let result: Option<(String, i32)> = sqlx::query_as(
495 "
496 SELECT invite_code, invite_count
497 FROM users
498 WHERE id = $1 AND invite_code IS NOT NULL
499 ",
500 )
501 .bind(id)
502 .fetch_optional(&self.pool)
503 .await?;
504 if let Some((code, count)) = result {
505 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
506 } else {
507 Ok(None)
508 }
509 }
510
511 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
512 sqlx::query_as(
513 "
514 SELECT *
515 FROM users
516 WHERE invite_code = $1
517 ",
518 )
519 .bind(code)
520 .fetch_optional(&self.pool)
521 .await?
522 .ok_or_else(|| {
523 Error::Http(
524 StatusCode::NOT_FOUND,
525 "that invite code does not exist".to_string(),
526 )
527 })
528 }
529
530 async fn redeem_invite_code(
531 &self,
532 code: &str,
533 login: &str,
534 email_address: Option<&str>,
535 ) -> Result<UserId> {
536 let mut tx = self.pool.begin().await?;
537
538 let inviter_id: Option<UserId> = sqlx::query_scalar(
539 "
540 UPDATE users
541 SET invite_count = invite_count - 1
542 WHERE
543 invite_code = $1 AND
544 invite_count > 0
545 RETURNING id
546 ",
547 )
548 .bind(code)
549 .fetch_optional(&mut tx)
550 .await?;
551
552 let inviter_id = match inviter_id {
553 Some(inviter_id) => inviter_id,
554 None => {
555 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
556 .bind(code)
557 .fetch_optional(&mut tx)
558 .await?
559 .is_some()
560 {
561 Err(Error::Http(
562 StatusCode::UNAUTHORIZED,
563 "no invites remaining".to_string(),
564 ))?
565 } else {
566 Err(Error::Http(
567 StatusCode::NOT_FOUND,
568 "invite code not found".to_string(),
569 ))?
570 }
571 }
572 };
573
574 let invitee_id = sqlx::query_scalar(
575 "
576 INSERT INTO users
577 (github_login, email_address, admin, inviter_id, invite_code, invite_count)
578 VALUES
579 ($1, $2, 'f', $3, $4, $5)
580 RETURNING id
581 ",
582 )
583 .bind(login)
584 .bind(email_address)
585 .bind(inviter_id)
586 .bind(random_invite_code())
587 .bind(5)
588 .fetch_one(&mut tx)
589 .await
590 .map(UserId)?;
591
592 sqlx::query(
593 "
594 INSERT INTO contacts
595 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
596 VALUES
597 ($1, $2, 't', 't', 't')
598 ",
599 )
600 .bind(inviter_id)
601 .bind(invitee_id)
602 .execute(&mut tx)
603 .await?;
604
605 tx.commit().await?;
606 Ok(invitee_id)
607 }
608
609 // projects
610
611 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
612 Ok(sqlx::query_scalar(
613 "
614 INSERT INTO projects(host_user_id)
615 VALUES ($1)
616 RETURNING id
617 ",
618 )
619 .bind(host_user_id)
620 .fetch_one(&self.pool)
621 .await
622 .map(ProjectId)?)
623 }
624
625 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
626 sqlx::query(
627 "
628 UPDATE projects
629 SET unregistered = 't'
630 WHERE id = $1
631 ",
632 )
633 .bind(project_id)
634 .execute(&self.pool)
635 .await?;
636 Ok(())
637 }
638
639 async fn update_worktree_extensions(
640 &self,
641 project_id: ProjectId,
642 worktree_id: u64,
643 extensions: HashMap<String, u32>,
644 ) -> Result<()> {
645 if extensions.is_empty() {
646 return Ok(());
647 }
648
649 let mut query = QueryBuilder::new(
650 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
651 );
652 query.push_values(extensions, |mut query, (extension, count)| {
653 query
654 .push_bind(project_id)
655 .push_bind(worktree_id as i32)
656 .push_bind(extension)
657 .push_bind(count as i32);
658 });
659 query.push(
660 "
661 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
662 count = excluded.count
663 ",
664 );
665 query.build().execute(&self.pool).await?;
666
667 Ok(())
668 }
669
670 async fn get_project_extensions(
671 &self,
672 project_id: ProjectId,
673 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
674 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
675 struct WorktreeExtension {
676 worktree_id: i32,
677 extension: String,
678 count: i32,
679 }
680
681 let query = "
682 SELECT worktree_id, extension, count
683 FROM worktree_extensions
684 WHERE project_id = $1
685 ";
686 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
687 .bind(&project_id)
688 .fetch_all(&self.pool)
689 .await?;
690
691 let mut extension_counts = HashMap::default();
692 for count in counts {
693 extension_counts
694 .entry(count.worktree_id as u64)
695 .or_insert_with(HashMap::default)
696 .insert(count.extension, count.count as usize);
697 }
698 Ok(extension_counts)
699 }
700
701 async fn record_user_activity(
702 &self,
703 time_period: Range<OffsetDateTime>,
704 projects: &[(UserId, ProjectId)],
705 ) -> Result<()> {
706 let query = "
707 INSERT INTO project_activity_periods
708 (ended_at, duration_millis, user_id, project_id)
709 VALUES
710 ($1, $2, $3, $4);
711 ";
712
713 let mut tx = self.pool.begin().await?;
714 let duration_millis =
715 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
716 for (user_id, project_id) in projects {
717 sqlx::query(query)
718 .bind(time_period.end)
719 .bind(duration_millis)
720 .bind(user_id)
721 .bind(project_id)
722 .execute(&mut tx)
723 .await?;
724 }
725 tx.commit().await?;
726
727 Ok(())
728 }
729
730 async fn get_active_user_count(
731 &self,
732 time_period: Range<OffsetDateTime>,
733 min_duration: Duration,
734 only_collaborative: bool,
735 ) -> Result<usize> {
736 let mut with_clause = String::new();
737 with_clause.push_str("WITH\n");
738 with_clause.push_str(
739 "
740 project_durations AS (
741 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
742 FROM project_activity_periods
743 WHERE $1 < ended_at AND ended_at <= $2
744 GROUP BY user_id, project_id
745 ),
746 ",
747 );
748 with_clause.push_str(
749 "
750 project_collaborators as (
751 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
752 FROM project_durations
753 GROUP BY project_id
754 ),
755 ",
756 );
757
758 if only_collaborative {
759 with_clause.push_str(
760 "
761 user_durations AS (
762 SELECT user_id, SUM(project_duration) as total_duration
763 FROM project_durations, project_collaborators
764 WHERE
765 project_durations.project_id = project_collaborators.project_id AND
766 max_collaborators > 1
767 GROUP BY user_id
768 ORDER BY total_duration DESC
769 LIMIT $3
770 )
771 ",
772 );
773 } else {
774 with_clause.push_str(
775 "
776 user_durations AS (
777 SELECT user_id, SUM(project_duration) as total_duration
778 FROM project_durations
779 GROUP BY user_id
780 ORDER BY total_duration DESC
781 LIMIT $3
782 )
783 ",
784 );
785 }
786
787 let query = format!(
788 "
789 {with_clause}
790 SELECT count(user_durations.user_id)
791 FROM user_durations
792 WHERE user_durations.total_duration >= $3
793 "
794 );
795
796 let count: i64 = sqlx::query_scalar(&query)
797 .bind(time_period.start)
798 .bind(time_period.end)
799 .bind(min_duration.as_millis() as i64)
800 .fetch_one(&self.pool)
801 .await?;
802 Ok(count as usize)
803 }
804
805 async fn get_top_users_activity_summary(
806 &self,
807 time_period: Range<OffsetDateTime>,
808 max_user_count: usize,
809 ) -> Result<Vec<UserActivitySummary>> {
810 let query = "
811 WITH
812 project_durations AS (
813 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
814 FROM project_activity_periods
815 WHERE $1 < ended_at AND ended_at <= $2
816 GROUP BY user_id, project_id
817 ),
818 user_durations AS (
819 SELECT user_id, SUM(project_duration) as total_duration
820 FROM project_durations
821 GROUP BY user_id
822 ORDER BY total_duration DESC
823 LIMIT $3
824 ),
825 project_collaborators as (
826 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
827 FROM project_durations
828 GROUP BY project_id
829 )
830 SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
831 FROM user_durations, project_durations, project_collaborators, users
832 WHERE
833 user_durations.user_id = project_durations.user_id AND
834 user_durations.user_id = users.id AND
835 project_durations.project_id = project_collaborators.project_id
836 ORDER BY total_duration DESC, user_id ASC, project_id ASC
837 ";
838
839 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
840 .bind(time_period.start)
841 .bind(time_period.end)
842 .bind(max_user_count as i32)
843 .fetch(&self.pool);
844
845 let mut result = Vec::<UserActivitySummary>::new();
846 while let Some(row) = rows.next().await {
847 let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
848 let project_id = project_id;
849 let duration = Duration::from_millis(duration_millis as u64);
850 let project_activity = ProjectActivitySummary {
851 id: project_id,
852 duration,
853 max_collaborators: project_collaborators as usize,
854 };
855 if let Some(last_summary) = result.last_mut() {
856 if last_summary.id == user_id {
857 last_summary.project_activity.push(project_activity);
858 continue;
859 }
860 }
861 result.push(UserActivitySummary {
862 id: user_id,
863 project_activity: vec![project_activity],
864 github_login,
865 });
866 }
867
868 Ok(result)
869 }
870
    /// Returns `user_id`'s activity within `time_period` as periods sorted by
    /// start time.
    ///
    /// Activity rows for the same project are merged into one period when a
    /// row starts no later than `COALESCE_THRESHOLD` after the previous
    /// period's end and does not end before that period's start. Each period
    /// also carries the project's per-extension file counts from
    /// `worktree_extensions`.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>> {
        // Largest gap between adjacent activity rows that still merges them.
        const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);

        // The LEFT OUTER JOIN yields one row per matching `worktree_extensions`
        // row for each activity period. If the project has no extensions, it
        // yields a single row with NULL extension columns. Rows arrive in
        // `project_activity_periods.id` order.
        let query = "
            SELECT
                project_activity_periods.ended_at,
                project_activity_periods.duration_millis,
                project_activity_periods.project_id,
                worktree_extensions.extension,
                worktree_extensions.count
            FROM project_activity_periods
            LEFT OUTER JOIN
                worktree_extensions
            ON
                project_activity_periods.project_id = worktree_extensions.project_id
            WHERE
                project_activity_periods.user_id = $1 AND
                $2 < project_activity_periods.ended_at AND
                project_activity_periods.ended_at <= $3
            ORDER BY project_activity_periods.id ASC
        ";

        let mut rows = sqlx::query_as::<
            _,
            (
                PrimitiveDateTime,
                i32,
                ProjectId,
                Option<String>,
                Option<i32>,
            ),
        >(query)
        .bind(user_id)
        .bind(time_period.start)
        .bind(time_period.end)
        .fetch(&self.pool);

        // Merged periods, grouped by project until the final flatten.
        let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
        while let Some(row) = rows.next().await {
            let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
            // `ended_at` is decoded without an offset and interpreted as UTC.
            let ended_at = ended_at.assume_utc();
            let duration = Duration::from_millis(duration_millis as u64);
            let started_at = ended_at - duration;
            let project_time_periods = time_periods.entry(project_id).or_default();

            if let Some(prev_duration) = project_time_periods.last_mut() {
                // Extend the previous period if this row overlaps it or starts
                // within COALESCE_THRESHOLD of its end. Repeated join rows for
                // the same activity period always take this branch.
                if started_at <= prev_duration.end + COALESCE_THRESHOLD
                    && ended_at >= prev_duration.start
                {
                    prev_duration.end = cmp::max(prev_duration.end, ended_at);
                } else {
                    project_time_periods.push(UserActivityPeriod {
                        project_id,
                        start: started_at,
                        end: ended_at,
                        extensions: Default::default(),
                    });
                }
            } else {
                project_time_periods.push(UserActivityPeriod {
                    project_id,
                    start: started_at,
                    end: ended_at,
                    extensions: Default::default(),
                });
            }

            // NOTE(review): counts are keyed only by extension. If the project
            // has several worktrees sharing an extension, the last joined row's
            // count overwrites the others instead of being summed. Confirm this
            // is intended.
            if let Some((extension, extension_count)) = extension.zip(extension_count) {
                project_time_periods
                    .last_mut()
                    .unwrap()
                    .extensions
                    .insert(extension, extension_count as usize);
            }
        }

        let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
        durations.sort_unstable_by_key(|duration| duration.start);
        Ok(durations)
    }
955
956 // contacts
957
958 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
959 let query = "
960 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
961 FROM contacts
962 WHERE user_id_a = $1 OR user_id_b = $1;
963 ";
964
965 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
966 .bind(user_id)
967 .fetch(&self.pool);
968
969 let mut contacts = vec![Contact::Accepted {
970 user_id,
971 should_notify: false,
972 }];
973 while let Some(row) = rows.next().await {
974 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
975
976 if user_id_a == user_id {
977 if accepted {
978 contacts.push(Contact::Accepted {
979 user_id: user_id_b,
980 should_notify: should_notify && a_to_b,
981 });
982 } else if a_to_b {
983 contacts.push(Contact::Outgoing { user_id: user_id_b })
984 } else {
985 contacts.push(Contact::Incoming {
986 user_id: user_id_b,
987 should_notify,
988 });
989 }
990 } else if accepted {
991 contacts.push(Contact::Accepted {
992 user_id: user_id_a,
993 should_notify: should_notify && !a_to_b,
994 });
995 } else if a_to_b {
996 contacts.push(Contact::Incoming {
997 user_id: user_id_a,
998 should_notify,
999 });
1000 } else {
1001 contacts.push(Contact::Outgoing { user_id: user_id_a });
1002 }
1003 }
1004
1005 contacts.sort_unstable_by_key(|contact| contact.user_id());
1006
1007 Ok(contacts)
1008 }
1009
1010 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1011 let (id_a, id_b) = if user_id_1 < user_id_2 {
1012 (user_id_1, user_id_2)
1013 } else {
1014 (user_id_2, user_id_1)
1015 };
1016
1017 let query = "
1018 SELECT 1 FROM contacts
1019 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
1020 LIMIT 1
1021 ";
1022 Ok(sqlx::query_scalar::<_, i32>(query)
1023 .bind(id_a.0)
1024 .bind(id_b.0)
1025 .fetch_optional(&self.pool)
1026 .await?
1027 .is_some())
1028 }
1029
1030 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1031 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1032 (sender_id, receiver_id, true)
1033 } else {
1034 (receiver_id, sender_id, false)
1035 };
1036 let query = "
1037 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1038 VALUES ($1, $2, $3, 'f', 't')
1039 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1040 SET
1041 accepted = 't',
1042 should_notify = 'f'
1043 WHERE
1044 NOT contacts.accepted AND
1045 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1046 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1047 ";
1048 let result = sqlx::query(query)
1049 .bind(id_a.0)
1050 .bind(id_b.0)
1051 .bind(a_to_b)
1052 .execute(&self.pool)
1053 .await?;
1054
1055 if result.rows_affected() == 1 {
1056 Ok(())
1057 } else {
1058 Err(anyhow!("contact already requested"))?
1059 }
1060 }
1061
1062 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1063 let (id_a, id_b) = if responder_id < requester_id {
1064 (responder_id, requester_id)
1065 } else {
1066 (requester_id, responder_id)
1067 };
1068 let query = "
1069 DELETE FROM contacts
1070 WHERE user_id_a = $1 AND user_id_b = $2;
1071 ";
1072 let result = sqlx::query(query)
1073 .bind(id_a.0)
1074 .bind(id_b.0)
1075 .execute(&self.pool)
1076 .await?;
1077
1078 if result.rows_affected() == 1 {
1079 Ok(())
1080 } else {
1081 Err(anyhow!("no such contact"))?
1082 }
1083 }
1084
1085 async fn dismiss_contact_notification(
1086 &self,
1087 user_id: UserId,
1088 contact_user_id: UserId,
1089 ) -> Result<()> {
1090 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1091 (user_id, contact_user_id, true)
1092 } else {
1093 (contact_user_id, user_id, false)
1094 };
1095
1096 let query = "
1097 UPDATE contacts
1098 SET should_notify = 'f'
1099 WHERE
1100 user_id_a = $1 AND user_id_b = $2 AND
1101 (
1102 (a_to_b = $3 AND accepted) OR
1103 (a_to_b != $3 AND NOT accepted)
1104 );
1105 ";
1106
1107 let result = sqlx::query(query)
1108 .bind(id_a.0)
1109 .bind(id_b.0)
1110 .bind(a_to_b)
1111 .execute(&self.pool)
1112 .await?;
1113
1114 if result.rows_affected() == 0 {
1115 Err(anyhow!("no such contact request"))?;
1116 }
1117
1118 Ok(())
1119 }
1120
1121 async fn respond_to_contact_request(
1122 &self,
1123 responder_id: UserId,
1124 requester_id: UserId,
1125 accept: bool,
1126 ) -> Result<()> {
1127 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1128 (responder_id, requester_id, false)
1129 } else {
1130 (requester_id, responder_id, true)
1131 };
1132 let result = if accept {
1133 let query = "
1134 UPDATE contacts
1135 SET accepted = 't', should_notify = 't'
1136 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1137 ";
1138 sqlx::query(query)
1139 .bind(id_a.0)
1140 .bind(id_b.0)
1141 .bind(a_to_b)
1142 .execute(&self.pool)
1143 .await?
1144 } else {
1145 let query = "
1146 DELETE FROM contacts
1147 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1148 ";
1149 sqlx::query(query)
1150 .bind(id_a.0)
1151 .bind(id_b.0)
1152 .bind(a_to_b)
1153 .execute(&self.pool)
1154 .await?
1155 };
1156 if result.rows_affected() == 1 {
1157 Ok(())
1158 } else {
1159 Err(anyhow!("no such contact request"))?
1160 }
1161 }
1162
1163 // access tokens
1164
1165 async fn create_access_token_hash(
1166 &self,
1167 user_id: UserId,
1168 access_token_hash: &str,
1169 max_access_token_count: usize,
1170 ) -> Result<()> {
1171 let insert_query = "
1172 INSERT INTO access_tokens (user_id, hash)
1173 VALUES ($1, $2);
1174 ";
1175 let cleanup_query = "
1176 DELETE FROM access_tokens
1177 WHERE id IN (
1178 SELECT id from access_tokens
1179 WHERE user_id = $1
1180 ORDER BY id DESC
1181 OFFSET $3
1182 )
1183 ";
1184
1185 let mut tx = self.pool.begin().await?;
1186 sqlx::query(insert_query)
1187 .bind(user_id.0)
1188 .bind(access_token_hash)
1189 .execute(&mut tx)
1190 .await?;
1191 sqlx::query(cleanup_query)
1192 .bind(user_id.0)
1193 .bind(access_token_hash)
1194 .bind(max_access_token_count as i32)
1195 .execute(&mut tx)
1196 .await?;
1197 Ok(tx.commit().await?)
1198 }
1199
1200 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1201 let query = "
1202 SELECT hash
1203 FROM access_tokens
1204 WHERE user_id = $1
1205 ORDER BY id DESC
1206 ";
1207 Ok(sqlx::query_scalar(query)
1208 .bind(user_id.0)
1209 .fetch_all(&self.pool)
1210 .await?)
1211 }
1212
1213 // orgs
1214
1215 #[allow(unused)] // Help rust-analyzer
1216 #[cfg(any(test, feature = "seed-support"))]
1217 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1218 let query = "
1219 SELECT *
1220 FROM orgs
1221 WHERE slug = $1
1222 ";
1223 Ok(sqlx::query_as(query)
1224 .bind(slug)
1225 .fetch_optional(&self.pool)
1226 .await?)
1227 }
1228
1229 #[cfg(any(test, feature = "seed-support"))]
1230 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1231 let query = "
1232 INSERT INTO orgs (name, slug)
1233 VALUES ($1, $2)
1234 RETURNING id
1235 ";
1236 Ok(sqlx::query_scalar(query)
1237 .bind(name)
1238 .bind(slug)
1239 .fetch_one(&self.pool)
1240 .await
1241 .map(OrgId)?)
1242 }
1243
1244 #[cfg(any(test, feature = "seed-support"))]
1245 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1246 let query = "
1247 INSERT INTO org_memberships (org_id, user_id, admin)
1248 VALUES ($1, $2, $3)
1249 ON CONFLICT DO NOTHING
1250 ";
1251 Ok(sqlx::query(query)
1252 .bind(org_id.0)
1253 .bind(user_id.0)
1254 .bind(is_admin)
1255 .execute(&self.pool)
1256 .await
1257 .map(drop)?)
1258 }
1259
1260 // channels
1261
1262 #[cfg(any(test, feature = "seed-support"))]
1263 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1264 let query = "
1265 INSERT INTO channels (owner_id, owner_is_user, name)
1266 VALUES ($1, false, $2)
1267 RETURNING id
1268 ";
1269 Ok(sqlx::query_scalar(query)
1270 .bind(org_id.0)
1271 .bind(name)
1272 .fetch_one(&self.pool)
1273 .await
1274 .map(ChannelId)?)
1275 }
1276
1277 #[allow(unused)] // Help rust-analyzer
1278 #[cfg(any(test, feature = "seed-support"))]
1279 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1280 let query = "
1281 SELECT *
1282 FROM channels
1283 WHERE
1284 channels.owner_is_user = false AND
1285 channels.owner_id = $1
1286 ";
1287 Ok(sqlx::query_as(query)
1288 .bind(org_id.0)
1289 .fetch_all(&self.pool)
1290 .await?)
1291 }
1292
1293 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1294 let query = "
1295 SELECT
1296 channels.*
1297 FROM
1298 channel_memberships, channels
1299 WHERE
1300 channel_memberships.user_id = $1 AND
1301 channel_memberships.channel_id = channels.id
1302 ";
1303 Ok(sqlx::query_as(query)
1304 .bind(user_id.0)
1305 .fetch_all(&self.pool)
1306 .await?)
1307 }
1308
1309 async fn can_user_access_channel(
1310 &self,
1311 user_id: UserId,
1312 channel_id: ChannelId,
1313 ) -> Result<bool> {
1314 let query = "
1315 SELECT id
1316 FROM channel_memberships
1317 WHERE user_id = $1 AND channel_id = $2
1318 LIMIT 1
1319 ";
1320 Ok(sqlx::query_scalar::<_, i32>(query)
1321 .bind(user_id.0)
1322 .bind(channel_id.0)
1323 .fetch_optional(&self.pool)
1324 .await
1325 .map(|e| e.is_some())?)
1326 }
1327
1328 #[cfg(any(test, feature = "seed-support"))]
1329 async fn add_channel_member(
1330 &self,
1331 channel_id: ChannelId,
1332 user_id: UserId,
1333 is_admin: bool,
1334 ) -> Result<()> {
1335 let query = "
1336 INSERT INTO channel_memberships (channel_id, user_id, admin)
1337 VALUES ($1, $2, $3)
1338 ON CONFLICT DO NOTHING
1339 ";
1340 Ok(sqlx::query(query)
1341 .bind(channel_id.0)
1342 .bind(user_id.0)
1343 .bind(is_admin)
1344 .execute(&self.pool)
1345 .await
1346 .map(drop)?)
1347 }
1348
1349 // messages
1350
1351 async fn create_channel_message(
1352 &self,
1353 channel_id: ChannelId,
1354 sender_id: UserId,
1355 body: &str,
1356 timestamp: OffsetDateTime,
1357 nonce: u128,
1358 ) -> Result<MessageId> {
1359 let query = "
1360 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1361 VALUES ($1, $2, $3, $4, $5)
1362 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1363 RETURNING id
1364 ";
1365 Ok(sqlx::query_scalar(query)
1366 .bind(channel_id.0)
1367 .bind(sender_id.0)
1368 .bind(body)
1369 .bind(timestamp)
1370 .bind(Uuid::from_u128(nonce))
1371 .fetch_one(&self.pool)
1372 .await
1373 .map(MessageId)?)
1374 }
1375
1376 async fn get_channel_messages(
1377 &self,
1378 channel_id: ChannelId,
1379 count: usize,
1380 before_id: Option<MessageId>,
1381 ) -> Result<Vec<ChannelMessage>> {
1382 let query = r#"
1383 SELECT * FROM (
1384 SELECT
1385 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1386 FROM
1387 channel_messages
1388 WHERE
1389 channel_id = $1 AND
1390 id < $2
1391 ORDER BY id DESC
1392 LIMIT $3
1393 ) as recent_messages
1394 ORDER BY id ASC
1395 "#;
1396 Ok(sqlx::query_as(query)
1397 .bind(channel_id.0)
1398 .bind(before_id.unwrap_or(MessageId::MAX))
1399 .bind(count as i64)
1400 .fetch_all(&self.pool)
1401 .await?)
1402 }
1403
1404 #[cfg(test)]
1405 async fn teardown(&self, url: &str) {
1406 use util::ResultExt;
1407
1408 let query = "
1409 SELECT pg_terminate_backend(pg_stat_activity.pid)
1410 FROM pg_stat_activity
1411 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1412 ";
1413 sqlx::query(query).execute(&self.pool).await.log_err();
1414 self.pool.close().await;
1415 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1416 .await
1417 .log_err();
1418 }
1419
1420 #[cfg(test)]
1421 fn as_fake(&self) -> Option<&tests::FakeDb> {
1422 None
1423 }
1424}
1425
/// Declares `$name` as a transparent newtype over an `i32` database id.
///
/// The generated type is `Copy`, ordered, hashable, (de)serializes as a bare
/// integer, and maps to the underlying SQL integer via `sqlx(transparent)`.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash,
            sqlx::Type, Serialize, Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from its protocol (`u64`) form.
            ///
            /// NOTE(review): this is an `as` cast, so values above
            /// `i32::MAX` wrap rather than being rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts the id to its protocol (`u64`) form via an `as` cast.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1468
id_type!(UserId);
/// A user account record, loadable from a query row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    /// Admin flag (toggled via `Db::set_user_is_admin`).
    pub admin: bool,
    /// The user's invite code, if one has been assigned.
    pub invite_code: Option<String>,
    /// How many more times the user's invite code may be redeemed.
    pub invite_count: i32,
    /// Flag toggled via `Db::set_user_connected_once`.
    pub connected_once: bool,
}
1480
id_type!(ProjectId);
/// A project hosted by a user, loadable from a query row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user who registered/hosts the project.
    pub host_user_id: UserId,
    /// NOTE(review): presumably set once the host unregisters the project — confirm.
    pub unregistered: bool,
}
1488
/// A user's per-project activity over a time range, as returned by
/// `get_top_users_activity_summary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// One entry per project the user was active in.
    pub project_activity: Vec<ProjectActivitySummary>,
}
1495
/// Activity of one user in one project within a summary's time range.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    id: ProjectId,
    /// Total time the user was recorded as active in the project.
    duration: Duration,
    /// NOTE(review): presumably the peak number of users active in the project
    /// at the same time, including this user — confirm.
    max_collaborators: usize,
}
1502
/// A span of a user's activity in one project, as returned by
/// `get_user_activity_timeline`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    project_id: ProjectId,
    /// Serialized in ISO 8601 format.
    #[serde(with = "time::serde::iso8601")]
    start: OffsetDateTime,
    /// Serialized in ISO 8601 format.
    #[serde(with = "time::serde::iso8601")]
    end: OffsetDateTime,
    /// Extension -> count values recorded for the project via
    /// `update_worktree_extensions`.
    extensions: HashMap<String, usize>,
}
1512
id_type!(OrgId);
/// An organization, loadable from a row of the `orgs` table via `FromRow`.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Unique-looking identifier used by `find_org_by_slug`.
    pub slug: String,
}
1520
id_type!(ChannelId);
/// A chat channel, loadable from a row of the `channels` table via `FromRow`.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Raw id of the owner. `create_org_channel` stores an `OrgId` here with
    /// `owner_is_user = false`; presumably a `UserId` when `owner_is_user` is true.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1529
id_type!(MessageId);
/// A message posted to a channel, loadable via `FromRow`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// `get_channel_messages` selects this column `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce. `create_channel_message` returns the existing
    /// message's id when the same nonce is inserted again.
    pub nonce: Uuid,
}
1540
/// One entry in a user's contact list. A user's own entry is reported as
/// `Accepted`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// A mutual contact.
    Accepted {
        user_id: UserId,
        /// Whether this user still has an undismissed notification about the
        /// contact (e.g. the other user accepting their request).
        should_notify: bool,
    },
    /// A request this user sent that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// A request this user received that is still pending.
    Incoming {
        user_id: UserId,
        /// Whether the request notification has not been dismissed yet.
        should_notify: bool,
    },
}
1555
1556impl Contact {
1557 pub fn user_id(&self) -> UserId {
1558 match self {
1559 Contact::Accepted { user_id, .. } => *user_id,
1560 Contact::Outgoing { user_id } => *user_id,
1561 Contact::Incoming { user_id, .. } => *user_id,
1562 }
1563 }
1564}
1565
/// A pending contact request received from `requester_id`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    /// NOTE(review): presumably whether the recipient should still be notified — confirm.
    pub should_notify: bool,
}
1571
/// Signup details submitted by a prospective user, deserialized and passed to
/// `Db::create_signup`.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
}
1581
/// An invite for a signup, returned by `Db::get_signup_invites` and passed to
/// `Db::record_signup_invites_sent`.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct SignupInvite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1587
/// Input to `Db::redeem_signup`: the signup's email and confirmation code,
/// the GitHub login to create the user with, and that user's invite count.
#[derive(Debug, Serialize, Deserialize)]
pub struct SignupRedemption {
    pub email_address: String,
    pub email_confirmation_code: String,
    pub github_login: String,
    pub invite_count: i32,
}
1595
/// Builds a SQL `LIKE` pattern that matches any text containing the
/// alphanumeric characters of `string` in order, with anything in between.
///
/// Non-alphanumeric characters, including the `LIKE` wildcards `%` and `_`,
/// are dropped: `"x y"` becomes `"%x%y%"`, and `""` becomes `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    string
        .chars()
        .filter(|c| c.is_alphanumeric())
        .for_each(|c| {
            pattern.push('%');
            pattern.push(c);
        });
    pattern.push('%');
    pattern
}
1607
1608fn random_invite_code() -> String {
1609 nanoid::nanoid!(16)
1610}
1611
1612fn random_email_confirmation_code() -> String {
1613 nanoid::nanoid!(64)
1614}
1615
1616#[cfg(test)]
1617pub mod tests {
1618 use super::*;
1619 use anyhow::anyhow;
1620 use collections::BTreeMap;
1621 use gpui::executor::{Background, Deterministic};
1622 use lazy_static::lazy_static;
1623 use parking_lot::Mutex;
1624 use rand::prelude::*;
1625 use sqlx::{
1626 migrate::{MigrateDatabase, Migrator},
1627 Postgres,
1628 };
1629 use std::{path::Path, sync::Arc};
1630 use util::post_inc;
1631
1632 #[tokio::test(flavor = "multi_thread")]
1633 async fn test_get_users_by_ids() {
1634 for test_db in [
1635 TestDb::postgres().await,
1636 TestDb::fake(build_background_executor()),
1637 ] {
1638 let db = test_db.db();
1639
1640 let user = db.create_user("user", None, false).await.unwrap();
1641 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1642 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1643 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1644
1645 assert_eq!(
1646 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1647 .await
1648 .unwrap(),
1649 vec![
1650 User {
1651 id: user,
1652 github_login: "user".to_string(),
1653 admin: false,
1654 ..Default::default()
1655 },
1656 User {
1657 id: friend1,
1658 github_login: "friend-1".to_string(),
1659 admin: false,
1660 ..Default::default()
1661 },
1662 User {
1663 id: friend2,
1664 github_login: "friend-2".to_string(),
1665 admin: false,
1666 ..Default::default()
1667 },
1668 User {
1669 id: friend3,
1670 github_login: "friend-3".to_string(),
1671 admin: false,
1672 ..Default::default()
1673 }
1674 ]
1675 );
1676 }
1677 }
1678
1679 #[tokio::test(flavor = "multi_thread")]
1680 async fn test_create_users() {
1681 let db = TestDb::postgres().await;
1682 let db = db.db();
1683
1684 // Create the first batch of users, ensuring invite counts are assigned
1685 // correctly and the respective invite codes are unique.
1686 let user_ids_batch_1 = db
1687 .create_users(vec![
1688 ("user1".to_string(), "hi@user1.com".to_string(), 5),
1689 ("user2".to_string(), "hi@user2.com".to_string(), 4),
1690 ("user3".to_string(), "hi@user3.com".to_string(), 3),
1691 ])
1692 .await
1693 .unwrap();
1694 assert_eq!(user_ids_batch_1.len(), 3);
1695
1696 let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1697 assert_eq!(users.len(), 3);
1698 assert_eq!(users[0].github_login, "user1");
1699 assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1700 assert_eq!(users[0].invite_count, 5);
1701 assert_eq!(users[1].github_login, "user2");
1702 assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1703 assert_eq!(users[1].invite_count, 4);
1704 assert_eq!(users[2].github_login, "user3");
1705 assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1706 assert_eq!(users[2].invite_count, 3);
1707
1708 let invite_code_1 = users[0].invite_code.clone().unwrap();
1709 let invite_code_2 = users[1].invite_code.clone().unwrap();
1710 let invite_code_3 = users[2].invite_code.clone().unwrap();
1711 assert_ne!(invite_code_1, invite_code_2);
1712 assert_ne!(invite_code_1, invite_code_3);
1713 assert_ne!(invite_code_2, invite_code_3);
1714
1715 // Create the second batch of users and include a user that is already in the database, ensuring
1716 // the invite count for the existing user is updated without changing their invite code.
1717 let user_ids_batch_2 = db
1718 .create_users(vec![
1719 ("user2".to_string(), "hi@user2.com".to_string(), 10),
1720 ("user4".to_string(), "hi@user4.com".to_string(), 2),
1721 ])
1722 .await
1723 .unwrap();
1724 assert_eq!(user_ids_batch_2.len(), 2);
1725 assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1726
1727 let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1728 assert_eq!(users.len(), 2);
1729 assert_eq!(users[0].github_login, "user2");
1730 assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1731 assert_eq!(users[0].invite_count, 10);
1732 assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1733 assert_eq!(users[1].github_login, "user4");
1734 assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1735 assert_eq!(users[1].invite_count, 2);
1736
1737 let invite_code_4 = users[1].invite_code.clone().unwrap();
1738 assert_ne!(invite_code_4, invite_code_1);
1739 assert_ne!(invite_code_4, invite_code_2);
1740 assert_ne!(invite_code_4, invite_code_3);
1741 }
1742
1743 #[tokio::test(flavor = "multi_thread")]
1744 async fn test_worktree_extensions() {
1745 let test_db = TestDb::postgres().await;
1746 let db = test_db.db();
1747
1748 let user = db.create_user("user_1", None, false).await.unwrap();
1749 let project = db.register_project(user).await.unwrap();
1750
1751 db.update_worktree_extensions(project, 100, Default::default())
1752 .await
1753 .unwrap();
1754 db.update_worktree_extensions(
1755 project,
1756 100,
1757 [("rs".to_string(), 5), ("md".to_string(), 3)]
1758 .into_iter()
1759 .collect(),
1760 )
1761 .await
1762 .unwrap();
1763 db.update_worktree_extensions(
1764 project,
1765 100,
1766 [("rs".to_string(), 6), ("md".to_string(), 5)]
1767 .into_iter()
1768 .collect(),
1769 )
1770 .await
1771 .unwrap();
1772 db.update_worktree_extensions(
1773 project,
1774 101,
1775 [("ts".to_string(), 2), ("md".to_string(), 1)]
1776 .into_iter()
1777 .collect(),
1778 )
1779 .await
1780 .unwrap();
1781
1782 assert_eq!(
1783 db.get_project_extensions(project).await.unwrap(),
1784 [
1785 (
1786 100,
1787 [("rs".into(), 6), ("md".into(), 5),]
1788 .into_iter()
1789 .collect::<HashMap<_, _>>()
1790 ),
1791 (
1792 101,
1793 [("ts".into(), 2), ("md".into(), 1),]
1794 .into_iter()
1795 .collect::<HashMap<_, _>>()
1796 )
1797 ]
1798 .into_iter()
1799 .collect()
1800 );
1801 }
1802
1803 #[tokio::test(flavor = "multi_thread")]
1804 async fn test_user_activity() {
1805 let test_db = TestDb::postgres().await;
1806 let db = test_db.db();
1807
1808 let user_1 = db.create_user("user_1", None, false).await.unwrap();
1809 let user_2 = db.create_user("user_2", None, false).await.unwrap();
1810 let user_3 = db.create_user("user_3", None, false).await.unwrap();
1811 let project_1 = db.register_project(user_1).await.unwrap();
1812 db.update_worktree_extensions(
1813 project_1,
1814 1,
1815 HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
1816 )
1817 .await
1818 .unwrap();
1819 let project_2 = db.register_project(user_2).await.unwrap();
1820 let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
1821
1822 // User 2 opens a project
1823 let t1 = t0 + Duration::from_secs(10);
1824 db.record_user_activity(t0..t1, &[(user_2, project_2)])
1825 .await
1826 .unwrap();
1827
1828 let t2 = t1 + Duration::from_secs(10);
1829 db.record_user_activity(t1..t2, &[(user_2, project_2)])
1830 .await
1831 .unwrap();
1832
1833 // User 1 joins the project
1834 let t3 = t2 + Duration::from_secs(10);
1835 db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
1836 .await
1837 .unwrap();
1838
1839 // User 1 opens another project
1840 let t4 = t3 + Duration::from_secs(10);
1841 db.record_user_activity(
1842 t3..t4,
1843 &[
1844 (user_2, project_2),
1845 (user_1, project_2),
1846 (user_1, project_1),
1847 ],
1848 )
1849 .await
1850 .unwrap();
1851
1852 // User 3 joins that project
1853 let t5 = t4 + Duration::from_secs(10);
1854 db.record_user_activity(
1855 t4..t5,
1856 &[
1857 (user_2, project_2),
1858 (user_1, project_2),
1859 (user_1, project_1),
1860 (user_3, project_1),
1861 ],
1862 )
1863 .await
1864 .unwrap();
1865
1866 // User 2 leaves
1867 let t6 = t5 + Duration::from_secs(5);
1868 db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
1869 .await
1870 .unwrap();
1871
1872 let t7 = t6 + Duration::from_secs(60);
1873 let t8 = t7 + Duration::from_secs(10);
1874 db.record_user_activity(t7..t8, &[(user_1, project_1)])
1875 .await
1876 .unwrap();
1877
1878 assert_eq!(
1879 db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
1880 &[
1881 UserActivitySummary {
1882 id: user_1,
1883 github_login: "user_1".to_string(),
1884 project_activity: vec![
1885 ProjectActivitySummary {
1886 id: project_1,
1887 duration: Duration::from_secs(25),
1888 max_collaborators: 2
1889 },
1890 ProjectActivitySummary {
1891 id: project_2,
1892 duration: Duration::from_secs(30),
1893 max_collaborators: 2
1894 }
1895 ]
1896 },
1897 UserActivitySummary {
1898 id: user_2,
1899 github_login: "user_2".to_string(),
1900 project_activity: vec![ProjectActivitySummary {
1901 id: project_2,
1902 duration: Duration::from_secs(50),
1903 max_collaborators: 2
1904 }]
1905 },
1906 UserActivitySummary {
1907 id: user_3,
1908 github_login: "user_3".to_string(),
1909 project_activity: vec![ProjectActivitySummary {
1910 id: project_1,
1911 duration: Duration::from_secs(15),
1912 max_collaborators: 2
1913 }]
1914 },
1915 ]
1916 );
1917
1918 assert_eq!(
1919 db.get_active_user_count(t0..t6, Duration::from_secs(56), false)
1920 .await
1921 .unwrap(),
1922 0
1923 );
1924 assert_eq!(
1925 db.get_active_user_count(t0..t6, Duration::from_secs(56), true)
1926 .await
1927 .unwrap(),
1928 0
1929 );
1930 assert_eq!(
1931 db.get_active_user_count(t0..t6, Duration::from_secs(54), false)
1932 .await
1933 .unwrap(),
1934 1
1935 );
1936 assert_eq!(
1937 db.get_active_user_count(t0..t6, Duration::from_secs(54), true)
1938 .await
1939 .unwrap(),
1940 1
1941 );
1942 assert_eq!(
1943 db.get_active_user_count(t0..t6, Duration::from_secs(30), false)
1944 .await
1945 .unwrap(),
1946 2
1947 );
1948 assert_eq!(
1949 db.get_active_user_count(t0..t6, Duration::from_secs(30), true)
1950 .await
1951 .unwrap(),
1952 2
1953 );
1954 assert_eq!(
1955 db.get_active_user_count(t0..t6, Duration::from_secs(10), false)
1956 .await
1957 .unwrap(),
1958 3
1959 );
1960 assert_eq!(
1961 db.get_active_user_count(t0..t6, Duration::from_secs(10), true)
1962 .await
1963 .unwrap(),
1964 3
1965 );
1966 assert_eq!(
1967 db.get_active_user_count(t0..t1, Duration::from_secs(5), false)
1968 .await
1969 .unwrap(),
1970 1
1971 );
1972 assert_eq!(
1973 db.get_active_user_count(t0..t1, Duration::from_secs(5), true)
1974 .await
1975 .unwrap(),
1976 0
1977 );
1978
1979 assert_eq!(
1980 db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
1981 &[
1982 UserActivityPeriod {
1983 project_id: project_1,
1984 start: t3,
1985 end: t6,
1986 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1987 },
1988 UserActivityPeriod {
1989 project_id: project_2,
1990 start: t3,
1991 end: t5,
1992 extensions: Default::default(),
1993 },
1994 ]
1995 );
1996 assert_eq!(
1997 db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
1998 &[
1999 UserActivityPeriod {
2000 project_id: project_2,
2001 start: t2,
2002 end: t5,
2003 extensions: Default::default(),
2004 },
2005 UserActivityPeriod {
2006 project_id: project_1,
2007 start: t3,
2008 end: t6,
2009 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
2010 },
2011 UserActivityPeriod {
2012 project_id: project_1,
2013 start: t7,
2014 end: t8,
2015 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
2016 },
2017 ]
2018 );
2019 }
2020
2021 #[tokio::test(flavor = "multi_thread")]
2022 async fn test_recent_channel_messages() {
2023 for test_db in [
2024 TestDb::postgres().await,
2025 TestDb::fake(build_background_executor()),
2026 ] {
2027 let db = test_db.db();
2028 let user = db.create_user("user", None, false).await.unwrap();
2029 let org = db.create_org("org", "org").await.unwrap();
2030 let channel = db.create_org_channel(org, "channel").await.unwrap();
2031 for i in 0..10 {
2032 db.create_channel_message(
2033 channel,
2034 user,
2035 &i.to_string(),
2036 OffsetDateTime::now_utc(),
2037 i,
2038 )
2039 .await
2040 .unwrap();
2041 }
2042
2043 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
2044 assert_eq!(
2045 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
2046 ["5", "6", "7", "8", "9"]
2047 );
2048
2049 let prev_messages = db
2050 .get_channel_messages(channel, 4, Some(messages[0].id))
2051 .await
2052 .unwrap();
2053 assert_eq!(
2054 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
2055 ["1", "2", "3", "4"]
2056 );
2057 }
2058 }
2059
2060 #[tokio::test(flavor = "multi_thread")]
2061 async fn test_channel_message_nonces() {
2062 for test_db in [
2063 TestDb::postgres().await,
2064 TestDb::fake(build_background_executor()),
2065 ] {
2066 let db = test_db.db();
2067 let user = db.create_user("user", None, false).await.unwrap();
2068 let org = db.create_org("org", "org").await.unwrap();
2069 let channel = db.create_org_channel(org, "channel").await.unwrap();
2070
2071 let msg1_id = db
2072 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
2073 .await
2074 .unwrap();
2075 let msg2_id = db
2076 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
2077 .await
2078 .unwrap();
2079 let msg3_id = db
2080 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
2081 .await
2082 .unwrap();
2083 let msg4_id = db
2084 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
2085 .await
2086 .unwrap();
2087
2088 assert_ne!(msg1_id, msg2_id);
2089 assert_eq!(msg1_id, msg3_id);
2090 assert_eq!(msg2_id, msg4_id);
2091 }
2092 }
2093
2094 #[tokio::test(flavor = "multi_thread")]
2095 async fn test_create_access_tokens() {
2096 let test_db = TestDb::postgres().await;
2097 let db = test_db.db();
2098 let user = db.create_user("the-user", None, false).await.unwrap();
2099
2100 db.create_access_token_hash(user, "h1", 3).await.unwrap();
2101 db.create_access_token_hash(user, "h2", 3).await.unwrap();
2102 assert_eq!(
2103 db.get_access_token_hashes(user).await.unwrap(),
2104 &["h2".to_string(), "h1".to_string()]
2105 );
2106
2107 db.create_access_token_hash(user, "h3", 3).await.unwrap();
2108 assert_eq!(
2109 db.get_access_token_hashes(user).await.unwrap(),
2110 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
2111 );
2112
2113 db.create_access_token_hash(user, "h4", 3).await.unwrap();
2114 assert_eq!(
2115 db.get_access_token_hashes(user).await.unwrap(),
2116 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
2117 );
2118
2119 db.create_access_token_hash(user, "h5", 3).await.unwrap();
2120 assert_eq!(
2121 db.get_access_token_hashes(user).await.unwrap(),
2122 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
2123 );
2124 }
2125
2126 #[test]
2127 fn test_fuzzy_like_string() {
2128 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
2129 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
2130 assert_eq!(fuzzy_like_string(" z "), "%z%");
2131 }
2132
2133 #[tokio::test(flavor = "multi_thread")]
2134 async fn test_fuzzy_search_users() {
2135 let test_db = TestDb::postgres().await;
2136 let db = test_db.db();
2137 for github_login in [
2138 "California",
2139 "colorado",
2140 "oregon",
2141 "washington",
2142 "florida",
2143 "delaware",
2144 "rhode-island",
2145 ] {
2146 db.create_user(github_login, None, false).await.unwrap();
2147 }
2148
2149 assert_eq!(
2150 fuzzy_search_user_names(db, "clr").await,
2151 &["colorado", "California"]
2152 );
2153 assert_eq!(
2154 fuzzy_search_user_names(db, "ro").await,
2155 &["rhode-island", "colorado", "oregon"],
2156 );
2157
2158 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
2159 db.fuzzy_search_users(query, 10)
2160 .await
2161 .unwrap()
2162 .into_iter()
2163 .map(|user| user.github_login)
2164 .collect::<Vec<_>>()
2165 }
2166 }
2167
2168 #[tokio::test(flavor = "multi_thread")]
2169 async fn test_add_contacts() {
2170 for test_db in [
2171 TestDb::postgres().await,
2172 TestDb::fake(build_background_executor()),
2173 ] {
2174 let db = test_db.db();
2175
2176 let user_1 = db.create_user("user1", None, false).await.unwrap();
2177 let user_2 = db.create_user("user2", None, false).await.unwrap();
2178 let user_3 = db.create_user("user3", None, false).await.unwrap();
2179
2180 // User starts with no contacts
2181 assert_eq!(
2182 db.get_contacts(user_1).await.unwrap(),
2183 vec![Contact::Accepted {
2184 user_id: user_1,
2185 should_notify: false
2186 }],
2187 );
2188
2189 // User requests a contact. Both users see the pending request.
2190 db.send_contact_request(user_1, user_2).await.unwrap();
2191 assert!(!db.has_contact(user_1, user_2).await.unwrap());
2192 assert!(!db.has_contact(user_2, user_1).await.unwrap());
2193 assert_eq!(
2194 db.get_contacts(user_1).await.unwrap(),
2195 &[
2196 Contact::Accepted {
2197 user_id: user_1,
2198 should_notify: false
2199 },
2200 Contact::Outgoing { user_id: user_2 }
2201 ],
2202 );
2203 assert_eq!(
2204 db.get_contacts(user_2).await.unwrap(),
2205 &[
2206 Contact::Incoming {
2207 user_id: user_1,
2208 should_notify: true
2209 },
2210 Contact::Accepted {
2211 user_id: user_2,
2212 should_notify: false
2213 },
2214 ]
2215 );
2216
2217 // User 2 dismisses the contact request notification without accepting or rejecting.
2218 // We shouldn't notify them again.
2219 db.dismiss_contact_notification(user_1, user_2)
2220 .await
2221 .unwrap_err();
2222 db.dismiss_contact_notification(user_2, user_1)
2223 .await
2224 .unwrap();
2225 assert_eq!(
2226 db.get_contacts(user_2).await.unwrap(),
2227 &[
2228 Contact::Incoming {
2229 user_id: user_1,
2230 should_notify: false
2231 },
2232 Contact::Accepted {
2233 user_id: user_2,
2234 should_notify: false
2235 },
2236 ]
2237 );
2238
2239 // User can't accept their own contact request
2240 db.respond_to_contact_request(user_1, user_2, true)
2241 .await
2242 .unwrap_err();
2243
2244 // User accepts a contact request. Both users see the contact.
2245 db.respond_to_contact_request(user_2, user_1, true)
2246 .await
2247 .unwrap();
2248 assert_eq!(
2249 db.get_contacts(user_1).await.unwrap(),
2250 &[
2251 Contact::Accepted {
2252 user_id: user_1,
2253 should_notify: false
2254 },
2255 Contact::Accepted {
2256 user_id: user_2,
2257 should_notify: true
2258 }
2259 ],
2260 );
2261 assert!(db.has_contact(user_1, user_2).await.unwrap());
2262 assert!(db.has_contact(user_2, user_1).await.unwrap());
2263 assert_eq!(
2264 db.get_contacts(user_2).await.unwrap(),
2265 &[
2266 Contact::Accepted {
2267 user_id: user_1,
2268 should_notify: false,
2269 },
2270 Contact::Accepted {
2271 user_id: user_2,
2272 should_notify: false,
2273 },
2274 ]
2275 );
2276
2277 // Users cannot re-request existing contacts.
2278 db.send_contact_request(user_1, user_2).await.unwrap_err();
2279 db.send_contact_request(user_2, user_1).await.unwrap_err();
2280
2281 // Users can't dismiss notifications of them accepting other users' requests.
2282 db.dismiss_contact_notification(user_2, user_1)
2283 .await
2284 .unwrap_err();
2285 assert_eq!(
2286 db.get_contacts(user_1).await.unwrap(),
2287 &[
2288 Contact::Accepted {
2289 user_id: user_1,
2290 should_notify: false
2291 },
2292 Contact::Accepted {
2293 user_id: user_2,
2294 should_notify: true,
2295 },
2296 ]
2297 );
2298
2299 // Users can dismiss notifications of other users accepting their requests.
2300 db.dismiss_contact_notification(user_1, user_2)
2301 .await
2302 .unwrap();
2303 assert_eq!(
2304 db.get_contacts(user_1).await.unwrap(),
2305 &[
2306 Contact::Accepted {
2307 user_id: user_1,
2308 should_notify: false
2309 },
2310 Contact::Accepted {
2311 user_id: user_2,
2312 should_notify: false,
2313 },
2314 ]
2315 );
2316
2317 // Users send each other concurrent contact requests and
2318 // see that they are immediately accepted.
2319 db.send_contact_request(user_1, user_3).await.unwrap();
2320 db.send_contact_request(user_3, user_1).await.unwrap();
2321 assert_eq!(
2322 db.get_contacts(user_1).await.unwrap(),
2323 &[
2324 Contact::Accepted {
2325 user_id: user_1,
2326 should_notify: false
2327 },
2328 Contact::Accepted {
2329 user_id: user_2,
2330 should_notify: false,
2331 },
2332 Contact::Accepted {
2333 user_id: user_3,
2334 should_notify: false
2335 },
2336 ]
2337 );
2338 assert_eq!(
2339 db.get_contacts(user_3).await.unwrap(),
2340 &[
2341 Contact::Accepted {
2342 user_id: user_1,
2343 should_notify: false
2344 },
2345 Contact::Accepted {
2346 user_id: user_3,
2347 should_notify: false
2348 }
2349 ],
2350 );
2351
2352 // User declines a contact request. Both users see that it is gone.
2353 db.send_contact_request(user_2, user_3).await.unwrap();
2354 db.respond_to_contact_request(user_3, user_2, false)
2355 .await
2356 .unwrap();
2357 assert!(!db.has_contact(user_2, user_3).await.unwrap());
2358 assert!(!db.has_contact(user_3, user_2).await.unwrap());
2359 assert_eq!(
2360 db.get_contacts(user_2).await.unwrap(),
2361 &[
2362 Contact::Accepted {
2363 user_id: user_1,
2364 should_notify: false
2365 },
2366 Contact::Accepted {
2367 user_id: user_2,
2368 should_notify: false
2369 }
2370 ]
2371 );
2372 assert_eq!(
2373 db.get_contacts(user_3).await.unwrap(),
2374 &[
2375 Contact::Accepted {
2376 user_id: user_1,
2377 should_notify: false
2378 },
2379 Contact::Accepted {
2380 user_id: user_3,
2381 should_notify: false
2382 }
2383 ],
2384 );
2385 }
2386 }
2387
    /// Exercises the invite-code lifecycle against a real Postgres database:
    /// assigning a code via `set_invite_count`, redeeming it (which yields a
    /// new user who becomes an accepted contact of the inviter), exhausting
    /// and replenishing the remaining count, rejecting existing users, and
    /// the invite count that invited users receive.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        // The inviter (user 1) gets a notification about the new contact;
        // the invitee (user 2) does not.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error,
        // because its remaining invite count has reached 0.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes, and the failed attempt
        // does not consume one of the code's remaining invites.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);

        // Ensure invited users get invite codes too.
        assert_eq!(
            db.get_invite_code_for_user(user2).await.unwrap().unwrap().1,
            5
        );
        assert_eq!(
            db.get_invite_code_for_user(user3).await.unwrap().unwrap().1,
            5
        );
        assert_eq!(
            db.get_invite_code_for_user(user4).await.unwrap().unwrap().1,
            5
        );
    }
2554
    /// Exercises the waitlist flow against a real Postgres database: creating
    /// signups, fetching batches of pending invites, marking a batch as sent
    /// so the next batch is returned, and redeeming a signup into a user
    /// account. Invalid redemptions are expected to fail.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_signups() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();

        // people sign up on the waitlist
        for i in 0..8 {
            db.create_signup(Signup {
                email_address: format!("person-{i}@example.com"),
                platform_mac: true,
                platform_linux: true,
                platform_windows: false,
                editor_features: vec!["speed".into()],
                programming_languages: vec!["rust".into(), "c".into()],
            })
            .await
            .unwrap();
        }

        // retrieve the next batch of signup emails to send
        let signups_batch1 = db.get_signup_invites(3).await.unwrap();
        let addresses = signups_batch1
            .iter()
            .map(|s| &s.email_address)
            .collect::<Vec<_>>();
        assert_eq!(
            addresses,
            &[
                "person-0@example.com",
                "person-1@example.com",
                "person-2@example.com"
            ]
        );
        // Each signup gets its own confirmation code.
        assert_ne!(
            signups_batch1[0].email_confirmation_code,
            signups_batch1[1].email_confirmation_code
        );

        // the waitlist isn't updated until we record that the emails
        // were successfully sent.
        let signups_batch = db.get_signup_invites(3).await.unwrap();
        assert_eq!(signups_batch, signups_batch1);

        // once the emails go out, we can retrieve the next batch
        // of signups.
        db.record_signup_invites_sent(&signups_batch1)
            .await
            .unwrap();
        let signups_batch2 = db.get_signup_invites(3).await.unwrap();
        let addresses = signups_batch2
            .iter()
            .map(|s| &s.email_address)
            .collect::<Vec<_>>();
        assert_eq!(
            addresses,
            &[
                "person-3@example.com",
                "person-4@example.com",
                "person-5@example.com"
            ]
        );

        // user completes the signup process by providing their
        // github account.
        let user_id = db
            .redeem_signup(SignupRedemption {
                email_address: signups_batch1[0].email_address.clone(),
                email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(),
                github_login: "person-0".into(),
                invite_count: 5,
            })
            .await
            .unwrap();
        let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
        assert_eq!(user.github_login, "person-0");
        assert_eq!(user.email_address.as_deref(), Some("person-0@example.com"));
        assert_eq!(user.invite_count, 5);

        // cannot redeem the same signup again.
        db.redeem_signup(SignupRedemption {
            email_address: signups_batch1[0].email_address.clone(),
            email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(),
            github_login: "some-other-github_account".into(),
            invite_count: 5,
        })
        .await
        .unwrap_err();

        // cannot redeem a signup with the wrong confirmation code.
        db.redeem_signup(SignupRedemption {
            email_address: signups_batch1[1].email_address.clone(),
            email_confirmation_code: "the-wrong-code".to_string(),
            github_login: "person-1".into(),
            invite_count: 5,
        })
        .await
        .unwrap_err();
    }
2653
    /// A database handle scoped to a single test; its `Drop` impl calls
    /// `teardown` on the wrapped database.
    pub struct TestDb {
        /// The database under test. Held in an `Option` so `Drop` can take
        /// ownership of it before running `teardown`.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the per-test Postgres database (passed to
        /// `teardown`); left empty for the in-memory fake.
        pub url: String,
    }
2658
2659 impl TestDb {
2660 #[allow(clippy::await_holding_lock)]
2661 pub async fn postgres() -> Self {
2662 lazy_static! {
2663 static ref LOCK: Mutex<()> = Mutex::new(());
2664 }
2665
2666 let _guard = LOCK.lock();
2667 let mut rng = StdRng::from_entropy();
2668 let name = format!("zed-test-{}", rng.gen::<u128>());
2669 let url = format!("postgres://postgres@localhost/{}", name);
2670 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2671 Postgres::create_database(&url)
2672 .await
2673 .expect("failed to create test db");
2674 let db = PostgresDb::new(&url, 5).await.unwrap();
2675 let migrator = Migrator::new(migrations_path).await.unwrap();
2676 migrator.run(&db.pool).await.unwrap();
2677 Self {
2678 db: Some(Arc::new(db)),
2679 url,
2680 }
2681 }
2682
2683 pub fn fake(background: Arc<Background>) -> Self {
2684 Self {
2685 db: Some(Arc::new(FakeDb::new(background))),
2686 url: Default::default(),
2687 }
2688 }
2689
2690 pub fn db(&self) -> &Arc<dyn Db> {
2691 self.db.as_ref().unwrap()
2692 }
2693 }
2694
2695 impl Drop for TestDb {
2696 fn drop(&mut self) {
2697 if let Some(db) = self.db.take() {
2698 futures::executor::block_on(db.teardown(&self.url));
2699 }
2700 }
2701 }
2702
    /// In-memory stand-in for the Postgres-backed `Db`. Each "table" lives in
    /// its own mutex-protected collection.
    pub struct FakeDb {
        /// Executor used to inject simulated delays into async calls.
        background: Arc<Background>,
        /// Users keyed by id.
        pub users: Mutex<BTreeMap<UserId, User>>,
        /// Registered projects keyed by id.
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        /// `(project id, worktree id, file extension)` -> count.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        /// Organizations keyed by id.
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// `(org, user)` -> whether the member is an admin.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        /// Channels keyed by id.
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// `(channel, user)` -> whether the member is an admin.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        /// Messages keyed by id (the `BTreeMap` keeps them in id order).
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Contact relationships and pending contact requests.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Next ids to hand out; each is consumed via `post_inc`.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2720
    /// A contact relationship (or pending request) between two users in `FakeDb`.
    #[derive(Debug)]
    pub struct FakeContact {
        /// User who sent the contact request.
        pub requester_id: UserId,
        /// User who received the contact request.
        pub responder_id: UserId,
        /// Whether the request has been accepted (either explicitly by the
        /// responder, or because both users requested each other).
        pub accepted: bool,
        /// Pending notification flag. While unaccepted it is surfaced to the
        /// responder (as an incoming request); once accepted it is surfaced
        /// to the requester.
        pub should_notify: bool,
    }
2728
2729 impl FakeDb {
2730 pub fn new(background: Arc<Background>) -> Self {
2731 Self {
2732 background,
2733 users: Default::default(),
2734 next_user_id: Mutex::new(0),
2735 projects: Default::default(),
2736 worktree_extensions: Default::default(),
2737 next_project_id: Mutex::new(1),
2738 orgs: Default::default(),
2739 next_org_id: Mutex::new(1),
2740 org_memberships: Default::default(),
2741 channels: Default::default(),
2742 next_channel_id: Mutex::new(1),
2743 channel_memberships: Default::default(),
2744 channel_messages: Default::default(),
2745 next_channel_message_id: Mutex::new(1),
2746 contacts: Default::default(),
2747 }
2748 }
2749 }
2750
2751 #[async_trait]
2752 impl Db for FakeDb {
2753 async fn create_user(
2754 &self,
2755 github_login: &str,
2756 email_address: Option<&str>,
2757 admin: bool,
2758 ) -> Result<UserId> {
2759 self.background.simulate_random_delay().await;
2760
2761 let mut users = self.users.lock();
2762 if let Some(user) = users
2763 .values()
2764 .find(|user| user.github_login == github_login)
2765 {
2766 Ok(user.id)
2767 } else {
2768 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2769 users.insert(
2770 user_id,
2771 User {
2772 id: user_id,
2773 github_login: github_login.to_string(),
2774 email_address: email_address.map(str::to_string),
2775 admin,
2776 invite_code: None,
2777 invite_count: 0,
2778 connected_once: false,
2779 },
2780 );
2781 Ok(user_id)
2782 }
2783 }
2784
2785 async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
2786 unimplemented!()
2787 }
2788
2789 async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
2790 unimplemented!()
2791 }
2792
2793 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
2794 unimplemented!()
2795 }
2796
2797 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2798 self.background.simulate_random_delay().await;
2799 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2800 }
2801
2802 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2803 self.background.simulate_random_delay().await;
2804 let users = self.users.lock();
2805 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2806 }
2807
2808 async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
2809 unimplemented!()
2810 }
2811
2812 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2813 self.background.simulate_random_delay().await;
2814 Ok(self
2815 .users
2816 .lock()
2817 .values()
2818 .find(|user| user.github_login == github_login)
2819 .cloned())
2820 }
2821
2822 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
2823 unimplemented!()
2824 }
2825
2826 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2827 self.background.simulate_random_delay().await;
2828 let mut users = self.users.lock();
2829 let mut user = users
2830 .get_mut(&id)
2831 .ok_or_else(|| anyhow!("user not found"))?;
2832 user.connected_once = connected_once;
2833 Ok(())
2834 }
2835
2836 async fn destroy_user(&self, _id: UserId) -> Result<()> {
2837 unimplemented!()
2838 }
2839
2840 // signups
2841
2842 async fn create_signup(&self, _signup: Signup) -> Result<()> {
2843 unimplemented!()
2844 }
2845
2846 async fn get_signup_invites(&self, _count: usize) -> Result<Vec<SignupInvite>> {
2847 unimplemented!()
2848 }
2849
2850 async fn record_signup_invites_sent(&self, _signups: &[SignupInvite]) -> Result<()> {
2851 unimplemented!()
2852 }
2853
2854 async fn redeem_signup(
2855 &self,
2856 _redemption: SignupRedemption,
2857 ) -> Result<UserId> {
2858 unimplemented!()
2859 }
2860
2861 // invite codes
2862
2863 async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
2864 unimplemented!()
2865 }
2866
2867 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
2868 self.background.simulate_random_delay().await;
2869 Ok(None)
2870 }
2871
2872 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
2873 unimplemented!()
2874 }
2875
2876 async fn redeem_invite_code(
2877 &self,
2878 _code: &str,
2879 _login: &str,
2880 _email_address: Option<&str>,
2881 ) -> Result<UserId> {
2882 unimplemented!()
2883 }
2884
2885 // projects
2886
2887 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2888 self.background.simulate_random_delay().await;
2889 if !self.users.lock().contains_key(&host_user_id) {
2890 Err(anyhow!("no such user"))?;
2891 }
2892
2893 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2894 self.projects.lock().insert(
2895 project_id,
2896 Project {
2897 id: project_id,
2898 host_user_id,
2899 unregistered: false,
2900 },
2901 );
2902 Ok(project_id)
2903 }
2904
2905 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2906 self.background.simulate_random_delay().await;
2907 self.projects
2908 .lock()
2909 .get_mut(&project_id)
2910 .ok_or_else(|| anyhow!("no such project"))?
2911 .unregistered = true;
2912 Ok(())
2913 }
2914
2915 async fn update_worktree_extensions(
2916 &self,
2917 project_id: ProjectId,
2918 worktree_id: u64,
2919 extensions: HashMap<String, u32>,
2920 ) -> Result<()> {
2921 self.background.simulate_random_delay().await;
2922 if !self.projects.lock().contains_key(&project_id) {
2923 Err(anyhow!("no such project"))?;
2924 }
2925
2926 for (extension, count) in extensions {
2927 self.worktree_extensions
2928 .lock()
2929 .insert((project_id, worktree_id, extension), count);
2930 }
2931
2932 Ok(())
2933 }
2934
2935 async fn get_project_extensions(
2936 &self,
2937 _project_id: ProjectId,
2938 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
2939 unimplemented!()
2940 }
2941
2942 async fn record_user_activity(
2943 &self,
2944 _time_period: Range<OffsetDateTime>,
2945 _active_projects: &[(UserId, ProjectId)],
2946 ) -> Result<()> {
2947 unimplemented!()
2948 }
2949
2950 async fn get_active_user_count(
2951 &self,
2952 _time_period: Range<OffsetDateTime>,
2953 _min_duration: Duration,
2954 _only_collaborative: bool,
2955 ) -> Result<usize> {
2956 unimplemented!()
2957 }
2958
2959 async fn get_top_users_activity_summary(
2960 &self,
2961 _time_period: Range<OffsetDateTime>,
2962 _limit: usize,
2963 ) -> Result<Vec<UserActivitySummary>> {
2964 unimplemented!()
2965 }
2966
2967 async fn get_user_activity_timeline(
2968 &self,
2969 _time_period: Range<OffsetDateTime>,
2970 _user_id: UserId,
2971 ) -> Result<Vec<UserActivityPeriod>> {
2972 unimplemented!()
2973 }
2974
2975 // contacts
2976
2977 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2978 self.background.simulate_random_delay().await;
2979 let mut contacts = vec![Contact::Accepted {
2980 user_id: id,
2981 should_notify: false,
2982 }];
2983
2984 for contact in self.contacts.lock().iter() {
2985 if contact.requester_id == id {
2986 if contact.accepted {
2987 contacts.push(Contact::Accepted {
2988 user_id: contact.responder_id,
2989 should_notify: contact.should_notify,
2990 });
2991 } else {
2992 contacts.push(Contact::Outgoing {
2993 user_id: contact.responder_id,
2994 });
2995 }
2996 } else if contact.responder_id == id {
2997 if contact.accepted {
2998 contacts.push(Contact::Accepted {
2999 user_id: contact.requester_id,
3000 should_notify: false,
3001 });
3002 } else {
3003 contacts.push(Contact::Incoming {
3004 user_id: contact.requester_id,
3005 should_notify: contact.should_notify,
3006 });
3007 }
3008 }
3009 }
3010
3011 contacts.sort_unstable_by_key(|contact| contact.user_id());
3012 Ok(contacts)
3013 }
3014
3015 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
3016 self.background.simulate_random_delay().await;
3017 Ok(self.contacts.lock().iter().any(|contact| {
3018 contact.accepted
3019 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
3020 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
3021 }))
3022 }
3023
3024 async fn send_contact_request(
3025 &self,
3026 requester_id: UserId,
3027 responder_id: UserId,
3028 ) -> Result<()> {
3029 self.background.simulate_random_delay().await;
3030 let mut contacts = self.contacts.lock();
3031 for contact in contacts.iter_mut() {
3032 if contact.requester_id == requester_id && contact.responder_id == responder_id {
3033 if contact.accepted {
3034 Err(anyhow!("contact already exists"))?;
3035 } else {
3036 Err(anyhow!("contact already requested"))?;
3037 }
3038 }
3039 if contact.responder_id == requester_id && contact.requester_id == responder_id {
3040 if contact.accepted {
3041 Err(anyhow!("contact already exists"))?;
3042 } else {
3043 contact.accepted = true;
3044 contact.should_notify = false;
3045 return Ok(());
3046 }
3047 }
3048 }
3049 contacts.push(FakeContact {
3050 requester_id,
3051 responder_id,
3052 accepted: false,
3053 should_notify: true,
3054 });
3055 Ok(())
3056 }
3057
3058 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
3059 self.background.simulate_random_delay().await;
3060 self.contacts.lock().retain(|contact| {
3061 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
3062 });
3063 Ok(())
3064 }
3065
3066 async fn dismiss_contact_notification(
3067 &self,
3068 user_id: UserId,
3069 contact_user_id: UserId,
3070 ) -> Result<()> {
3071 self.background.simulate_random_delay().await;
3072 let mut contacts = self.contacts.lock();
3073 for contact in contacts.iter_mut() {
3074 if contact.requester_id == contact_user_id
3075 && contact.responder_id == user_id
3076 && !contact.accepted
3077 {
3078 contact.should_notify = false;
3079 return Ok(());
3080 }
3081 if contact.requester_id == user_id
3082 && contact.responder_id == contact_user_id
3083 && contact.accepted
3084 {
3085 contact.should_notify = false;
3086 return Ok(());
3087 }
3088 }
3089 Err(anyhow!("no such notification"))?
3090 }
3091
3092 async fn respond_to_contact_request(
3093 &self,
3094 responder_id: UserId,
3095 requester_id: UserId,
3096 accept: bool,
3097 ) -> Result<()> {
3098 self.background.simulate_random_delay().await;
3099 let mut contacts = self.contacts.lock();
3100 for (ix, contact) in contacts.iter_mut().enumerate() {
3101 if contact.requester_id == requester_id && contact.responder_id == responder_id {
3102 if contact.accepted {
3103 Err(anyhow!("contact already confirmed"))?;
3104 }
3105 if accept {
3106 contact.accepted = true;
3107 contact.should_notify = true;
3108 } else {
3109 contacts.remove(ix);
3110 }
3111 return Ok(());
3112 }
3113 }
3114 Err(anyhow!("no such contact request"))?
3115 }
3116
3117 async fn create_access_token_hash(
3118 &self,
3119 _user_id: UserId,
3120 _access_token_hash: &str,
3121 _max_access_token_count: usize,
3122 ) -> Result<()> {
3123 unimplemented!()
3124 }
3125
3126 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
3127 unimplemented!()
3128 }
3129
3130 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
3131 unimplemented!()
3132 }
3133
3134 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
3135 self.background.simulate_random_delay().await;
3136 let mut orgs = self.orgs.lock();
3137 if orgs.values().any(|org| org.slug == slug) {
3138 Err(anyhow!("org already exists"))?
3139 } else {
3140 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
3141 orgs.insert(
3142 org_id,
3143 Org {
3144 id: org_id,
3145 name: name.to_string(),
3146 slug: slug.to_string(),
3147 },
3148 );
3149 Ok(org_id)
3150 }
3151 }
3152
3153 async fn add_org_member(
3154 &self,
3155 org_id: OrgId,
3156 user_id: UserId,
3157 is_admin: bool,
3158 ) -> Result<()> {
3159 self.background.simulate_random_delay().await;
3160 if !self.orgs.lock().contains_key(&org_id) {
3161 Err(anyhow!("org does not exist"))?;
3162 }
3163 if !self.users.lock().contains_key(&user_id) {
3164 Err(anyhow!("user does not exist"))?;
3165 }
3166
3167 self.org_memberships
3168 .lock()
3169 .entry((org_id, user_id))
3170 .or_insert(is_admin);
3171 Ok(())
3172 }
3173
3174 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
3175 self.background.simulate_random_delay().await;
3176 if !self.orgs.lock().contains_key(&org_id) {
3177 Err(anyhow!("org does not exist"))?;
3178 }
3179
3180 let mut channels = self.channels.lock();
3181 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
3182 channels.insert(
3183 channel_id,
3184 Channel {
3185 id: channel_id,
3186 name: name.to_string(),
3187 owner_id: org_id.0,
3188 owner_is_user: false,
3189 },
3190 );
3191 Ok(channel_id)
3192 }
3193
3194 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
3195 self.background.simulate_random_delay().await;
3196 Ok(self
3197 .channels
3198 .lock()
3199 .values()
3200 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
3201 .cloned()
3202 .collect())
3203 }
3204
3205 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
3206 self.background.simulate_random_delay().await;
3207 let channels = self.channels.lock();
3208 let memberships = self.channel_memberships.lock();
3209 Ok(channels
3210 .values()
3211 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
3212 .cloned()
3213 .collect())
3214 }
3215
3216 async fn can_user_access_channel(
3217 &self,
3218 user_id: UserId,
3219 channel_id: ChannelId,
3220 ) -> Result<bool> {
3221 self.background.simulate_random_delay().await;
3222 Ok(self
3223 .channel_memberships
3224 .lock()
3225 .contains_key(&(channel_id, user_id)))
3226 }
3227
3228 async fn add_channel_member(
3229 &self,
3230 channel_id: ChannelId,
3231 user_id: UserId,
3232 is_admin: bool,
3233 ) -> Result<()> {
3234 self.background.simulate_random_delay().await;
3235 if !self.channels.lock().contains_key(&channel_id) {
3236 Err(anyhow!("channel does not exist"))?;
3237 }
3238 if !self.users.lock().contains_key(&user_id) {
3239 Err(anyhow!("user does not exist"))?;
3240 }
3241
3242 self.channel_memberships
3243 .lock()
3244 .entry((channel_id, user_id))
3245 .or_insert(is_admin);
3246 Ok(())
3247 }
3248
3249 async fn create_channel_message(
3250 &self,
3251 channel_id: ChannelId,
3252 sender_id: UserId,
3253 body: &str,
3254 timestamp: OffsetDateTime,
3255 nonce: u128,
3256 ) -> Result<MessageId> {
3257 self.background.simulate_random_delay().await;
3258 if !self.channels.lock().contains_key(&channel_id) {
3259 Err(anyhow!("channel does not exist"))?;
3260 }
3261 if !self.users.lock().contains_key(&sender_id) {
3262 Err(anyhow!("user does not exist"))?;
3263 }
3264
3265 let mut messages = self.channel_messages.lock();
3266 if let Some(message) = messages
3267 .values()
3268 .find(|message| message.nonce.as_u128() == nonce)
3269 {
3270 Ok(message.id)
3271 } else {
3272 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
3273 messages.insert(
3274 message_id,
3275 ChannelMessage {
3276 id: message_id,
3277 channel_id,
3278 sender_id,
3279 body: body.to_string(),
3280 sent_at: timestamp,
3281 nonce: Uuid::from_u128(nonce),
3282 },
3283 );
3284 Ok(message_id)
3285 }
3286 }
3287
3288 async fn get_channel_messages(
3289 &self,
3290 channel_id: ChannelId,
3291 count: usize,
3292 before_id: Option<MessageId>,
3293 ) -> Result<Vec<ChannelMessage>> {
3294 self.background.simulate_random_delay().await;
3295 let mut messages = self
3296 .channel_messages
3297 .lock()
3298 .values()
3299 .rev()
3300 .filter(|message| {
3301 message.channel_id == channel_id
3302 && message.id < before_id.unwrap_or(MessageId::MAX)
3303 })
3304 .take(count)
3305 .cloned()
3306 .collect::<Vec<_>>();
3307 messages.sort_unstable_by_key(|message| message.id);
3308 Ok(messages)
3309 }
3310
3311 async fn teardown(&self, _: &str) {}
3312
3313 #[cfg(test)]
3314 fn as_fake(&self) -> Option<&FakeDb> {
3315 Some(self)
3316 }
3317 }
3318
3319 fn build_background_executor() -> Arc<Background> {
3320 Deterministic::new(0).build_background()
3321 }
3322}