1use std::{ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use nanoid::nanoid;
10use serde::Serialize;
11pub use sqlx::postgres::PgPoolOptions as DbOptions;
12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
13use time::OffsetDateTime;
14
/// Persistence interface used by the collaboration server.
///
/// [`PostgresDb`] is the production implementation. Under `cfg(test)` a fake
/// implementation can be reached through [`Db::as_fake`]. Items gated on
/// `any(test, feature = "seed-support")` only exist in tests and seeding builds.
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user with the given GitHub login and returns its id.
    async fn create_user(
        &self,
        github_login: &str,
        email_address: Option<&str>,
        admin: bool,
    ) -> Result<UserId>;
    /// Returns page `page` of all users, with `limit` users per page.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Bulk-creates users from `(github_login, email_address, invite_count)` tuples.
    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
    /// Returns up to `limit` users whose GitHub login loosely matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Returns the user with the given id, if any.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose ids appear in `ids`.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns the user with exactly this GitHub login, if any.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets the user's `admin` flag.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets the user's `connected_once` flag.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes the user (the Postgres implementation also deletes their access tokens).
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Sets how many invites the user can still hand out.
    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, if they have a code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user who owns the invite `code`.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Consumes one invite from the owner of `code`, creates the invitee, and
    /// returns the new user's id.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, usize>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_project_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn summarize_project_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Returns every contact of the user, whether accepted or pending in either direction.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users are accepted contacts.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Sends a contact request from `requester_id` to `responder_id`.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Removes the contact relationship between the two users.
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Clears the pending notification the first user has about the second.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    /// Accepts (`accept == true`) or declines a pending contact request.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    /// Stores a new access-token hash for the user, keeping at most
    /// `max_access_token_count` of their tokens.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the user's access-token hashes.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    #[cfg(any(test, feature = "seed-support"))]
    /// Looks up an org by its slug.
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    #[cfg(any(test, feature = "seed-support"))]
    /// Creates an org and returns its id.
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    #[cfg(any(test, feature = "seed-support"))]
    /// Adds a user to an org.
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    #[cfg(any(test, feature = "seed-support"))]
    /// Creates a channel owned by an org and returns its id.
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    #[cfg(any(test, feature = "seed-support"))]
    /// Returns the channels owned by an org.
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns the channels the user is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether the user is a member of the channel.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    #[cfg(any(test, feature = "seed-support"))]
    /// Adds a user to a channel.
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a message and returns its id; `nonce` lets a retried send map to
    /// the same message.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` of the newest messages with ids below `before_id`
    /// (all messages when `None`), in ascending id order.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    #[cfg(test)]
    /// Tears down the test database at `url`.
    async fn teardown(&self, url: &str);
    #[cfg(test)]
    /// Returns the fake implementation when `self` is one, otherwise `None`.
    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
}
143
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    /// Shared pool that every query and transaction in this impl runs on.
    pool: sqlx::PgPool,
}
147
148impl PostgresDb {
149 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
150 let pool = DbOptions::new()
151 .max_connections(max_connections)
152 .connect(&url)
153 .await
154 .context("failed to connect to postgres database")?;
155 Ok(Self { pool })
156 }
157}
158
159#[async_trait]
160impl Db for PostgresDb {
161 // users
162
163 async fn create_user(
164 &self,
165 github_login: &str,
166 email_address: Option<&str>,
167 admin: bool,
168 ) -> Result<UserId> {
169 let query = "
170 INSERT INTO users (github_login, email_address, admin)
171 VALUES ($1, $2, $3)
172 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
173 RETURNING id
174 ";
175 Ok(sqlx::query_scalar(query)
176 .bind(github_login)
177 .bind(email_address)
178 .bind(admin)
179 .fetch_one(&self.pool)
180 .await
181 .map(UserId)?)
182 }
183
184 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
185 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
186 Ok(sqlx::query_as(query)
187 .bind(limit as i32)
188 .bind((page * limit) as i32)
189 .fetch_all(&self.pool)
190 .await?)
191 }
192
193 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
194 let mut query = QueryBuilder::new(
195 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
196 );
197 query.push_values(
198 users,
199 |mut query, (github_login, email_address, invite_count)| {
200 query
201 .push_bind(github_login)
202 .push_bind(email_address)
203 .push_bind(false)
204 .push_bind(nanoid!(16))
205 .push_bind(invite_count as i32);
206 },
207 );
208 query.push(
209 "
210 ON CONFLICT (github_login) DO UPDATE SET
211 github_login = excluded.github_login,
212 invite_count = excluded.invite_count,
213 invite_code = CASE WHEN users.invite_code IS NULL
214 THEN excluded.invite_code
215 ELSE users.invite_code
216 END
217 RETURNING id
218 ",
219 );
220
221 let rows = query.build().fetch_all(&self.pool).await?;
222 Ok(rows
223 .into_iter()
224 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
225 .collect())
226 }
227
228 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
229 let like_string = fuzzy_like_string(name_query);
230 let query = "
231 SELECT users.*
232 FROM users
233 WHERE github_login ILIKE $1
234 ORDER BY github_login <-> $2
235 LIMIT $3
236 ";
237 Ok(sqlx::query_as(query)
238 .bind(like_string)
239 .bind(name_query)
240 .bind(limit as i32)
241 .fetch_all(&self.pool)
242 .await?)
243 }
244
245 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
246 let users = self.get_users_by_ids(vec![id]).await?;
247 Ok(users.into_iter().next())
248 }
249
250 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
251 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
252 let query = "
253 SELECT users.*
254 FROM users
255 WHERE users.id = ANY ($1)
256 ";
257 Ok(sqlx::query_as(query)
258 .bind(&ids)
259 .fetch_all(&self.pool)
260 .await?)
261 }
262
263 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
264 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
265 Ok(sqlx::query_as(query)
266 .bind(github_login)
267 .fetch_optional(&self.pool)
268 .await?)
269 }
270
271 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
272 let query = "UPDATE users SET admin = $1 WHERE id = $2";
273 Ok(sqlx::query(query)
274 .bind(is_admin)
275 .bind(id.0)
276 .execute(&self.pool)
277 .await
278 .map(drop)?)
279 }
280
281 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
282 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
283 Ok(sqlx::query(query)
284 .bind(connected_once)
285 .bind(id.0)
286 .execute(&self.pool)
287 .await
288 .map(drop)?)
289 }
290
291 async fn destroy_user(&self, id: UserId) -> Result<()> {
292 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
293 sqlx::query(query)
294 .bind(id.0)
295 .execute(&self.pool)
296 .await
297 .map(drop)?;
298 let query = "DELETE FROM users WHERE id = $1;";
299 Ok(sqlx::query(query)
300 .bind(id.0)
301 .execute(&self.pool)
302 .await
303 .map(drop)?)
304 }
305
306 // invite codes
307
308 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
309 let mut tx = self.pool.begin().await?;
310 if count > 0 {
311 sqlx::query(
312 "
313 UPDATE users
314 SET invite_code = $1
315 WHERE id = $2 AND invite_code IS NULL
316 ",
317 )
318 .bind(nanoid!(16))
319 .bind(id)
320 .execute(&mut tx)
321 .await?;
322 }
323
324 sqlx::query(
325 "
326 UPDATE users
327 SET invite_count = $1
328 WHERE id = $2
329 ",
330 )
331 .bind(count as i32)
332 .bind(id)
333 .execute(&mut tx)
334 .await?;
335 tx.commit().await?;
336 Ok(())
337 }
338
339 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
340 let result: Option<(String, i32)> = sqlx::query_as(
341 "
342 SELECT invite_code, invite_count
343 FROM users
344 WHERE id = $1 AND invite_code IS NOT NULL
345 ",
346 )
347 .bind(id)
348 .fetch_optional(&self.pool)
349 .await?;
350 if let Some((code, count)) = result {
351 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
352 } else {
353 Ok(None)
354 }
355 }
356
357 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
358 sqlx::query_as(
359 "
360 SELECT *
361 FROM users
362 WHERE invite_code = $1
363 ",
364 )
365 .bind(code)
366 .fetch_optional(&self.pool)
367 .await?
368 .ok_or_else(|| {
369 Error::Http(
370 StatusCode::NOT_FOUND,
371 "that invite code does not exist".to_string(),
372 )
373 })
374 }
375
376 async fn redeem_invite_code(
377 &self,
378 code: &str,
379 login: &str,
380 email_address: Option<&str>,
381 ) -> Result<UserId> {
382 let mut tx = self.pool.begin().await?;
383
384 let inviter_id: Option<UserId> = sqlx::query_scalar(
385 "
386 UPDATE users
387 SET invite_count = invite_count - 1
388 WHERE
389 invite_code = $1 AND
390 invite_count > 0
391 RETURNING id
392 ",
393 )
394 .bind(code)
395 .fetch_optional(&mut tx)
396 .await?;
397
398 let inviter_id = match inviter_id {
399 Some(inviter_id) => inviter_id,
400 None => {
401 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
402 .bind(code)
403 .fetch_optional(&mut tx)
404 .await?
405 .is_some()
406 {
407 Err(Error::Http(
408 StatusCode::UNAUTHORIZED,
409 "no invites remaining".to_string(),
410 ))?
411 } else {
412 Err(Error::Http(
413 StatusCode::NOT_FOUND,
414 "invite code not found".to_string(),
415 ))?
416 }
417 }
418 };
419
420 let invitee_id = sqlx::query_scalar(
421 "
422 INSERT INTO users
423 (github_login, email_address, admin, inviter_id)
424 VALUES
425 ($1, $2, 'f', $3)
426 RETURNING id
427 ",
428 )
429 .bind(login)
430 .bind(email_address)
431 .bind(inviter_id)
432 .fetch_one(&mut tx)
433 .await
434 .map(UserId)?;
435
436 sqlx::query(
437 "
438 INSERT INTO contacts
439 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
440 VALUES
441 ($1, $2, 't', 't', 't')
442 ",
443 )
444 .bind(inviter_id)
445 .bind(invitee_id)
446 .execute(&mut tx)
447 .await?;
448
449 tx.commit().await?;
450 Ok(invitee_id)
451 }
452
453 // projects
454
455 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
456 Ok(sqlx::query_scalar(
457 "
458 INSERT INTO projects(host_user_id)
459 VALUES ($1)
460 RETURNING id
461 ",
462 )
463 .bind(host_user_id)
464 .fetch_one(&self.pool)
465 .await
466 .map(ProjectId)?)
467 }
468
469 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
470 sqlx::query(
471 "
472 UPDATE projects
473 SET unregistered = 't'
474 WHERE id = $1
475 ",
476 )
477 .bind(project_id)
478 .execute(&self.pool)
479 .await?;
480 Ok(())
481 }
482
483 async fn update_worktree_extensions(
484 &self,
485 project_id: ProjectId,
486 worktree_id: u64,
487 extensions: HashMap<String, usize>,
488 ) -> Result<()> {
489 let mut query = QueryBuilder::new(
490 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
491 );
492 query.push_values(extensions, |mut query, (extension, count)| {
493 query
494 .push_bind(project_id)
495 .push_bind(worktree_id as i32)
496 .push_bind(extension)
497 .push_bind(count as i32);
498 });
499 query.push(
500 "
501 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
502 count = excluded.count
503 ",
504 );
505 query.build().execute(&self.pool).await?;
506
507 Ok(())
508 }
509
510 async fn get_project_extensions(
511 &self,
512 project_id: ProjectId,
513 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
514 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
515 struct WorktreeExtension {
516 worktree_id: i32,
517 extension: String,
518 count: i32,
519 }
520
521 let query = "
522 SELECT worktree_id, extension, count
523 FROM worktree_extensions
524 WHERE project_id = $1
525 ";
526 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
527 .bind(&project_id)
528 .fetch_all(&self.pool)
529 .await?;
530
531 let mut extension_counts = HashMap::default();
532 for count in counts {
533 extension_counts
534 .entry(count.worktree_id as u64)
535 .or_insert(HashMap::default())
536 .insert(count.extension, count.count as usize);
537 }
538 Ok(extension_counts)
539 }
540
541 async fn record_project_activity(
542 &self,
543 time_period: Range<OffsetDateTime>,
544 projects: &[(UserId, ProjectId)],
545 ) -> Result<()> {
546 let query = "
547 INSERT INTO project_activity_periods
548 (ended_at, duration_millis, user_id, project_id)
549 VALUES
550 ($1, $2, $3, $4);
551 ";
552
553 let mut tx = self.pool.begin().await?;
554 let duration_millis =
555 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
556 for (user_id, project_id) in projects {
557 sqlx::query(query)
558 .bind(time_period.end)
559 .bind(duration_millis)
560 .bind(user_id)
561 .bind(project_id)
562 .execute(&mut tx)
563 .await?;
564 }
565 tx.commit().await?;
566
567 Ok(())
568 }
569
570 async fn summarize_project_activity(
571 &self,
572 time_period: Range<OffsetDateTime>,
573 max_user_count: usize,
574 ) -> Result<Vec<UserActivitySummary>> {
575 let query = "
576 WITH
577 project_durations AS (
578 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
579 FROM project_activity_periods
580 WHERE $1 <= ended_at AND ended_at <= $2
581 GROUP BY user_id, project_id
582 ),
583 user_durations AS (
584 SELECT user_id, SUM(project_duration) as total_duration
585 FROM project_durations
586 GROUP BY user_id
587 ORDER BY total_duration DESC
588 LIMIT $3
589 )
590 SELECT user_durations.user_id, users.github_login, project_id, project_duration
591 FROM user_durations, project_durations, users
592 WHERE
593 user_durations.user_id = project_durations.user_id AND
594 user_durations.user_id = users.id
595 ORDER BY user_id ASC, project_duration DESC
596 ";
597
598 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64)>(query)
599 .bind(time_period.start)
600 .bind(time_period.end)
601 .bind(max_user_count as i32)
602 .fetch(&self.pool);
603
604 let mut result = Vec::<UserActivitySummary>::new();
605 while let Some(row) = rows.next().await {
606 let (user_id, github_login, project_id, duration_millis) = row?;
607 let project_id = project_id;
608 let duration = Duration::from_millis(duration_millis as u64);
609 if let Some(last_summary) = result.last_mut() {
610 if last_summary.id == user_id {
611 last_summary.project_activity.push((project_id, duration));
612 continue;
613 }
614 }
615 result.push(UserActivitySummary {
616 id: user_id,
617 project_activity: vec![(project_id, duration)],
618 github_login,
619 });
620 }
621
622 Ok(result)
623 }
624
625 // contacts
626
627 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
628 let query = "
629 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
630 FROM contacts
631 WHERE user_id_a = $1 OR user_id_b = $1;
632 ";
633
634 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
635 .bind(user_id)
636 .fetch(&self.pool);
637
638 let mut contacts = vec![Contact::Accepted {
639 user_id,
640 should_notify: false,
641 }];
642 while let Some(row) = rows.next().await {
643 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
644
645 if user_id_a == user_id {
646 if accepted {
647 contacts.push(Contact::Accepted {
648 user_id: user_id_b,
649 should_notify: should_notify && a_to_b,
650 });
651 } else if a_to_b {
652 contacts.push(Contact::Outgoing { user_id: user_id_b })
653 } else {
654 contacts.push(Contact::Incoming {
655 user_id: user_id_b,
656 should_notify,
657 });
658 }
659 } else {
660 if accepted {
661 contacts.push(Contact::Accepted {
662 user_id: user_id_a,
663 should_notify: should_notify && !a_to_b,
664 });
665 } else if a_to_b {
666 contacts.push(Contact::Incoming {
667 user_id: user_id_a,
668 should_notify,
669 });
670 } else {
671 contacts.push(Contact::Outgoing { user_id: user_id_a });
672 }
673 }
674 }
675
676 contacts.sort_unstable_by_key(|contact| contact.user_id());
677
678 Ok(contacts)
679 }
680
681 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
682 let (id_a, id_b) = if user_id_1 < user_id_2 {
683 (user_id_1, user_id_2)
684 } else {
685 (user_id_2, user_id_1)
686 };
687
688 let query = "
689 SELECT 1 FROM contacts
690 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
691 LIMIT 1
692 ";
693 Ok(sqlx::query_scalar::<_, i32>(query)
694 .bind(id_a.0)
695 .bind(id_b.0)
696 .fetch_optional(&self.pool)
697 .await?
698 .is_some())
699 }
700
701 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
702 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
703 (sender_id, receiver_id, true)
704 } else {
705 (receiver_id, sender_id, false)
706 };
707 let query = "
708 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
709 VALUES ($1, $2, $3, 'f', 't')
710 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
711 SET
712 accepted = 't',
713 should_notify = 'f'
714 WHERE
715 NOT contacts.accepted AND
716 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
717 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
718 ";
719 let result = sqlx::query(query)
720 .bind(id_a.0)
721 .bind(id_b.0)
722 .bind(a_to_b)
723 .execute(&self.pool)
724 .await?;
725
726 if result.rows_affected() == 1 {
727 Ok(())
728 } else {
729 Err(anyhow!("contact already requested"))?
730 }
731 }
732
733 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
734 let (id_a, id_b) = if responder_id < requester_id {
735 (responder_id, requester_id)
736 } else {
737 (requester_id, responder_id)
738 };
739 let query = "
740 DELETE FROM contacts
741 WHERE user_id_a = $1 AND user_id_b = $2;
742 ";
743 let result = sqlx::query(query)
744 .bind(id_a.0)
745 .bind(id_b.0)
746 .execute(&self.pool)
747 .await?;
748
749 if result.rows_affected() == 1 {
750 Ok(())
751 } else {
752 Err(anyhow!("no such contact"))?
753 }
754 }
755
756 async fn dismiss_contact_notification(
757 &self,
758 user_id: UserId,
759 contact_user_id: UserId,
760 ) -> Result<()> {
761 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
762 (user_id, contact_user_id, true)
763 } else {
764 (contact_user_id, user_id, false)
765 };
766
767 let query = "
768 UPDATE contacts
769 SET should_notify = 'f'
770 WHERE
771 user_id_a = $1 AND user_id_b = $2 AND
772 (
773 (a_to_b = $3 AND accepted) OR
774 (a_to_b != $3 AND NOT accepted)
775 );
776 ";
777
778 let result = sqlx::query(query)
779 .bind(id_a.0)
780 .bind(id_b.0)
781 .bind(a_to_b)
782 .execute(&self.pool)
783 .await?;
784
785 if result.rows_affected() == 0 {
786 Err(anyhow!("no such contact request"))?;
787 }
788
789 Ok(())
790 }
791
792 async fn respond_to_contact_request(
793 &self,
794 responder_id: UserId,
795 requester_id: UserId,
796 accept: bool,
797 ) -> Result<()> {
798 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
799 (responder_id, requester_id, false)
800 } else {
801 (requester_id, responder_id, true)
802 };
803 let result = if accept {
804 let query = "
805 UPDATE contacts
806 SET accepted = 't', should_notify = 't'
807 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
808 ";
809 sqlx::query(query)
810 .bind(id_a.0)
811 .bind(id_b.0)
812 .bind(a_to_b)
813 .execute(&self.pool)
814 .await?
815 } else {
816 let query = "
817 DELETE FROM contacts
818 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
819 ";
820 sqlx::query(query)
821 .bind(id_a.0)
822 .bind(id_b.0)
823 .bind(a_to_b)
824 .execute(&self.pool)
825 .await?
826 };
827 if result.rows_affected() == 1 {
828 Ok(())
829 } else {
830 Err(anyhow!("no such contact request"))?
831 }
832 }
833
834 // access tokens
835
836 async fn create_access_token_hash(
837 &self,
838 user_id: UserId,
839 access_token_hash: &str,
840 max_access_token_count: usize,
841 ) -> Result<()> {
842 let insert_query = "
843 INSERT INTO access_tokens (user_id, hash)
844 VALUES ($1, $2);
845 ";
846 let cleanup_query = "
847 DELETE FROM access_tokens
848 WHERE id IN (
849 SELECT id from access_tokens
850 WHERE user_id = $1
851 ORDER BY id DESC
852 OFFSET $3
853 )
854 ";
855
856 let mut tx = self.pool.begin().await?;
857 sqlx::query(insert_query)
858 .bind(user_id.0)
859 .bind(access_token_hash)
860 .execute(&mut tx)
861 .await?;
862 sqlx::query(cleanup_query)
863 .bind(user_id.0)
864 .bind(access_token_hash)
865 .bind(max_access_token_count as i32)
866 .execute(&mut tx)
867 .await?;
868 Ok(tx.commit().await?)
869 }
870
871 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
872 let query = "
873 SELECT hash
874 FROM access_tokens
875 WHERE user_id = $1
876 ORDER BY id DESC
877 ";
878 Ok(sqlx::query_scalar(query)
879 .bind(user_id.0)
880 .fetch_all(&self.pool)
881 .await?)
882 }
883
884 // orgs
885
886 #[allow(unused)] // Help rust-analyzer
887 #[cfg(any(test, feature = "seed-support"))]
888 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
889 let query = "
890 SELECT *
891 FROM orgs
892 WHERE slug = $1
893 ";
894 Ok(sqlx::query_as(query)
895 .bind(slug)
896 .fetch_optional(&self.pool)
897 .await?)
898 }
899
900 #[cfg(any(test, feature = "seed-support"))]
901 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
902 let query = "
903 INSERT INTO orgs (name, slug)
904 VALUES ($1, $2)
905 RETURNING id
906 ";
907 Ok(sqlx::query_scalar(query)
908 .bind(name)
909 .bind(slug)
910 .fetch_one(&self.pool)
911 .await
912 .map(OrgId)?)
913 }
914
915 #[cfg(any(test, feature = "seed-support"))]
916 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
917 let query = "
918 INSERT INTO org_memberships (org_id, user_id, admin)
919 VALUES ($1, $2, $3)
920 ON CONFLICT DO NOTHING
921 ";
922 Ok(sqlx::query(query)
923 .bind(org_id.0)
924 .bind(user_id.0)
925 .bind(is_admin)
926 .execute(&self.pool)
927 .await
928 .map(drop)?)
929 }
930
931 // channels
932
933 #[cfg(any(test, feature = "seed-support"))]
934 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
935 let query = "
936 INSERT INTO channels (owner_id, owner_is_user, name)
937 VALUES ($1, false, $2)
938 RETURNING id
939 ";
940 Ok(sqlx::query_scalar(query)
941 .bind(org_id.0)
942 .bind(name)
943 .fetch_one(&self.pool)
944 .await
945 .map(ChannelId)?)
946 }
947
948 #[allow(unused)] // Help rust-analyzer
949 #[cfg(any(test, feature = "seed-support"))]
950 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
951 let query = "
952 SELECT *
953 FROM channels
954 WHERE
955 channels.owner_is_user = false AND
956 channels.owner_id = $1
957 ";
958 Ok(sqlx::query_as(query)
959 .bind(org_id.0)
960 .fetch_all(&self.pool)
961 .await?)
962 }
963
964 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
965 let query = "
966 SELECT
967 channels.*
968 FROM
969 channel_memberships, channels
970 WHERE
971 channel_memberships.user_id = $1 AND
972 channel_memberships.channel_id = channels.id
973 ";
974 Ok(sqlx::query_as(query)
975 .bind(user_id.0)
976 .fetch_all(&self.pool)
977 .await?)
978 }
979
980 async fn can_user_access_channel(
981 &self,
982 user_id: UserId,
983 channel_id: ChannelId,
984 ) -> Result<bool> {
985 let query = "
986 SELECT id
987 FROM channel_memberships
988 WHERE user_id = $1 AND channel_id = $2
989 LIMIT 1
990 ";
991 Ok(sqlx::query_scalar::<_, i32>(query)
992 .bind(user_id.0)
993 .bind(channel_id.0)
994 .fetch_optional(&self.pool)
995 .await
996 .map(|e| e.is_some())?)
997 }
998
999 #[cfg(any(test, feature = "seed-support"))]
1000 async fn add_channel_member(
1001 &self,
1002 channel_id: ChannelId,
1003 user_id: UserId,
1004 is_admin: bool,
1005 ) -> Result<()> {
1006 let query = "
1007 INSERT INTO channel_memberships (channel_id, user_id, admin)
1008 VALUES ($1, $2, $3)
1009 ON CONFLICT DO NOTHING
1010 ";
1011 Ok(sqlx::query(query)
1012 .bind(channel_id.0)
1013 .bind(user_id.0)
1014 .bind(is_admin)
1015 .execute(&self.pool)
1016 .await
1017 .map(drop)?)
1018 }
1019
1020 // messages
1021
1022 async fn create_channel_message(
1023 &self,
1024 channel_id: ChannelId,
1025 sender_id: UserId,
1026 body: &str,
1027 timestamp: OffsetDateTime,
1028 nonce: u128,
1029 ) -> Result<MessageId> {
1030 let query = "
1031 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1032 VALUES ($1, $2, $3, $4, $5)
1033 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1034 RETURNING id
1035 ";
1036 Ok(sqlx::query_scalar(query)
1037 .bind(channel_id.0)
1038 .bind(sender_id.0)
1039 .bind(body)
1040 .bind(timestamp)
1041 .bind(Uuid::from_u128(nonce))
1042 .fetch_one(&self.pool)
1043 .await
1044 .map(MessageId)?)
1045 }
1046
1047 async fn get_channel_messages(
1048 &self,
1049 channel_id: ChannelId,
1050 count: usize,
1051 before_id: Option<MessageId>,
1052 ) -> Result<Vec<ChannelMessage>> {
1053 let query = r#"
1054 SELECT * FROM (
1055 SELECT
1056 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1057 FROM
1058 channel_messages
1059 WHERE
1060 channel_id = $1 AND
1061 id < $2
1062 ORDER BY id DESC
1063 LIMIT $3
1064 ) as recent_messages
1065 ORDER BY id ASC
1066 "#;
1067 Ok(sqlx::query_as(query)
1068 .bind(channel_id.0)
1069 .bind(before_id.unwrap_or(MessageId::MAX))
1070 .bind(count as i64)
1071 .fetch_all(&self.pool)
1072 .await?)
1073 }
1074
1075 #[cfg(test)]
1076 async fn teardown(&self, url: &str) {
1077 use util::ResultExt;
1078
1079 let query = "
1080 SELECT pg_terminate_backend(pg_stat_activity.pid)
1081 FROM pg_stat_activity
1082 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1083 ";
1084 sqlx::query(query).execute(&self.pool).await.log_err();
1085 self.pool.close().await;
1086 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1087 .await
1088 .log_err();
1089 }
1090
1091 #[cfg(test)]
1092 fn as_fake(&self) -> Option<&tests::FakeDb> {
1093 None
1094 }
1095}
1096
/// Declares a `Copy` newtype `$name` wrapping an `i32` database id.
///
/// The type is `transparent` for both sqlx and serde, so it is stored and
/// serialized as the bare integer. It also converts to and from `u64` through
/// `from_proto`/`to_proto` (NOTE(review): presumably the RPC protocol's id
/// type — confirm).
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest representable id, e.g. an open upper bound for `id < $n` queries.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Wraps a protocol id. The `as` cast silently truncates values above `i32::MAX`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Widens the id to `u64`. Negative ids wrap because of the `as` cast.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Delegate to the integer's Display so width/fill flags are honored.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1128
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    /// Admin flag, written by `Db::set_user_is_admin`.
    pub admin: bool,
    /// Code other people redeem to sign up through this user; `None` until one is generated.
    pub invite_code: Option<String>,
    /// Invites this user can still hand out; each redemption decrements it.
    pub invite_count: i32,
    /// Flag written by `Db::set_user_connected_once`. NOTE(review): presumably
    /// records whether the user has ever connected — confirm at the call site.
    pub connected_once: bool,
}
1140
id_type!(ProjectId);
/// A row of the `projects` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// User who registered (hosts) the project.
    pub host_user_id: UserId,
    /// Set by `Db::unregister_project`; the row itself is never deleted there.
    pub unregistered: bool,
}
1148
/// One user's activity over a time window, as produced by
/// `Db::summarize_project_activity`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// Total active time per project within the window (Postgres impl: longest first).
    pub project_activity: Vec<(ProjectId, Duration)>,
}
1155
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Identifier used by `Db::find_org_by_slug`.
    pub slug: String,
}
1163
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owner. It is an org id when `owner_is_user` is false (see
    /// `create_org_channel`); presumably a user id otherwise.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1172
id_type!(MessageId);
/// A row of the `channel_messages` table.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; `get_channel_messages` reads it `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce (a `u128` stored as a UUID) used to deduplicate resends.
    pub nonce: Uuid,
}
1183
/// One contact of a user, seen from that user's point of view.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both sides are contacts. `should_notify` is only set for the user who
    /// originally sent the request.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// This user sent a request that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// `user_id` sent this user a request that is still pending.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1198
1199impl Contact {
1200 pub fn user_id(&self) -> UserId {
1201 match self {
1202 Contact::Accepted { user_id, .. } => *user_id,
1203 Contact::Outgoing { user_id } => *user_id,
1204 Contact::Incoming { user_id, .. } => *user_id,
1205 }
1206 }
1207}
1208
/// A pending contact request received from `requester_id`.
///
/// The derived ordering compares `requester_id` first, then `should_notify`.
/// NOTE(review): not used anywhere in the visible part of this file — confirm callers.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1214
/// Builds an `ILIKE` pattern that matches any string containing the
/// alphanumeric characters of `string` in order, with anything in between.
///
/// Non-alphanumeric characters are dropped. This includes the `%` and `_`
/// wildcards, so user input cannot inject its own wildcards. For example,
/// `"a-b"` becomes `"%a%b%"` and `""` becomes `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern: String = string
        .chars()
        .filter(|c| c.is_alphanumeric())
        .flat_map(|c| ['%', c])
        .collect();
    pattern.push('%');
    pattern
}
1226
1227#[cfg(test)]
1228pub mod tests {
1229 use super::*;
1230 use anyhow::anyhow;
1231 use collections::BTreeMap;
1232 use gpui::executor::Background;
1233 use lazy_static::lazy_static;
1234 use parking_lot::Mutex;
1235 use rand::prelude::*;
1236 use sqlx::{
1237 migrate::{MigrateDatabase, Migrator},
1238 Postgres,
1239 };
1240 use std::{path::Path, sync::Arc};
1241 use util::post_inc;
1242
1243 #[tokio::test(flavor = "multi_thread")]
1244 async fn test_get_users_by_ids() {
1245 for test_db in [
1246 TestDb::postgres().await,
1247 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1248 ] {
1249 let db = test_db.db();
1250
1251 let user = db.create_user("user", None, false).await.unwrap();
1252 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1253 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1254 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1255
1256 assert_eq!(
1257 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1258 .await
1259 .unwrap(),
1260 vec![
1261 User {
1262 id: user,
1263 github_login: "user".to_string(),
1264 admin: false,
1265 ..Default::default()
1266 },
1267 User {
1268 id: friend1,
1269 github_login: "friend-1".to_string(),
1270 admin: false,
1271 ..Default::default()
1272 },
1273 User {
1274 id: friend2,
1275 github_login: "friend-2".to_string(),
1276 admin: false,
1277 ..Default::default()
1278 },
1279 User {
1280 id: friend3,
1281 github_login: "friend-3".to_string(),
1282 admin: false,
1283 ..Default::default()
1284 }
1285 ]
1286 );
1287 }
1288 }
1289
    /// Exercises `create_users` against Postgres only (the fake leaves it
    /// unimplemented). A batch insert assigns the requested invite counts and
    /// distinct invite codes. Re-inserting an existing login updates its
    /// invite count but keeps its invite code.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_users() {
        let db = TestDb::postgres().await;
        let db = db.db();

        // Create the first batch of users, ensuring invite counts are assigned
        // correctly and the respective invite codes are unique.
        let user_ids_batch_1 = db
            .create_users(vec![
                ("user1".to_string(), "hi@user1.com".to_string(), 5),
                ("user2".to_string(), "hi@user2.com".to_string(), 4),
                ("user3".to_string(), "hi@user3.com".to_string(), 3),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_1.len(), 3);

        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
        assert_eq!(users.len(), 3);
        assert_eq!(users[0].github_login, "user1");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
        assert_eq!(users[0].invite_count, 5);
        assert_eq!(users[1].github_login, "user2");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[1].invite_count, 4);
        assert_eq!(users[2].github_login, "user3");
        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
        assert_eq!(users[2].invite_count, 3);

        let invite_code_1 = users[0].invite_code.clone().unwrap();
        let invite_code_2 = users[1].invite_code.clone().unwrap();
        let invite_code_3 = users[2].invite_code.clone().unwrap();
        assert_ne!(invite_code_1, invite_code_2);
        assert_ne!(invite_code_1, invite_code_3);
        assert_ne!(invite_code_2, invite_code_3);

        // Create the second batch of users and include a user that is already in the database, ensuring
        // the invite count for the existing user is updated without changing their invite code.
        let user_ids_batch_2 = db
            .create_users(vec![
                ("user2".to_string(), "hi@user2.com".to_string(), 10),
                ("user4".to_string(), "hi@user4.com".to_string(), 2),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_2.len(), 2);
        // The pre-existing "user2" keeps the id it was given in the first batch.
        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);

        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0].github_login, "user2");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[0].invite_count, 10);
        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
        assert_eq!(users[1].github_login, "user4");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
        assert_eq!(users[1].invite_count, 2);

        let invite_code_4 = users[1].invite_code.clone().unwrap();
        assert_ne!(invite_code_4, invite_code_1);
        assert_ne!(invite_code_4, invite_code_2);
        assert_ne!(invite_code_4, invite_code_3);
    }
1353
    /// Records overlapping (user, project) activity windows starting one hour
    /// ago, then checks the per-user, per-project totals reported by
    /// `summarize_project_activity` (Postgres only).
    #[tokio::test(flavor = "multi_thread")]
    async fn test_project_activity() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();

        let user_1 = db.create_user("user_1", None, false).await.unwrap();
        let user_2 = db.create_user("user_2", None, false).await.unwrap();
        let user_3 = db.create_user("user_3", None, false).await.unwrap();
        let project_1 = db.register_project(user_1).await.unwrap();
        let project_2 = db.register_project(user_2).await.unwrap();
        let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);

        // User 2 opens a project
        let t1 = t0 + Duration::from_secs(10);
        db.record_project_activity(t0..t1, &[(user_2, project_2)])
            .await
            .unwrap();

        let t2 = t1 + Duration::from_secs(10);
        db.record_project_activity(t1..t2, &[(user_2, project_2)])
            .await
            .unwrap();

        // User 1 joins the project
        let t3 = t2 + Duration::from_secs(10);
        db.record_project_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
            .await
            .unwrap();

        // User 1 opens another project
        let t4 = t3 + Duration::from_secs(10);
        db.record_project_activity(
            t3..t4,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
            ],
        )
        .await
        .unwrap();

        // User 3 joins that project
        let t5 = t4 + Duration::from_secs(10);
        db.record_project_activity(
            t4..t5,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
                (user_3, project_1),
            ],
        )
        .await
        .unwrap();

        // User 2 leaves
        let t6 = t5 + Duration::from_secs(5);
        db.record_project_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
            .await
            .unwrap();

        // Expected totals, derived from the windows above:
        // - user_1: project_2 over t2..t5 (30s), project_1 over t3..t6 (10+10+5 = 25s)
        // - user_2: project_2 over t0..t5 (50s)
        // - user_3: project_1 over t4..t6 (10+5 = 15s)
        let summary = db.summarize_project_activity(t0..t6, 10).await.unwrap();
        assert_eq!(
            summary,
            &[
                UserActivitySummary {
                    id: user_1,
                    github_login: "user_1".to_string(),
                    project_activity: vec![
                        (project_2, Duration::from_secs(30)),
                        (project_1, Duration::from_secs(25))
                    ]
                },
                UserActivitySummary {
                    id: user_2,
                    github_login: "user_2".to_string(),
                    project_activity: vec![(project_2, Duration::from_secs(50))]
                },
                UserActivitySummary {
                    id: user_3,
                    github_login: "user_3".to_string(),
                    project_activity: vec![(project_1, Duration::from_secs(15))]
                },
            ]
        );
    }
1441
    /// Paginating channel messages: the first page returns the most recent
    /// messages oldest-first, and passing a message id returns the page
    /// preceding it. Runs against both Postgres and the fake.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_recent_channel_messages() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();
            let user = db.create_user("user", None, false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();
            // Ten messages with bodies "0".."9"; each uses a distinct nonce so none are deduplicated.
            for i in 0..10 {
                db.create_channel_message(
                    channel,
                    user,
                    &i.to_string(),
                    OffsetDateTime::now_utc(),
                    i,
                )
                .await
                .unwrap();
            }

            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
            assert_eq!(
                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["5", "6", "7", "8", "9"]
            );

            // The page before message "5" (presumably ids strictly below it — TODO confirm).
            let prev_messages = db
                .get_channel_messages(channel, 4, Some(messages[0].id))
                .await
                .unwrap();
            assert_eq!(
                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["1", "2", "3", "4"]
            );
        }
    }
1480
    /// Creating a message with a nonce that was already used returns the id of
    /// the existing message instead of inserting a new one.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_channel_message_nonces() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();
            let user = db.create_user("user", None, false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();

            let msg1_id = db
                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg2_id = db
                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();
            // Reuses nonce 1.
            let msg3_id = db
                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            // Reuses nonce 2.
            let msg4_id = db
                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();

            assert_ne!(msg1_id, msg2_id);
            assert_eq!(msg1_id, msg3_id);
            assert_eq!(msg2_id, msg4_id);
        }
    }
1514
    /// With a limit of 3 hashes per user, each new hash beyond the limit evicts
    /// the oldest. Hashes are returned newest-first. Runs on Postgres only.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_access_tokens() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();
        let user = db.create_user("the-user", None, false).await.unwrap();

        db.create_access_token_hash(user, "h1", 3).await.unwrap();
        db.create_access_token_hash(user, "h2", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h3", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
        );

        // Fourth hash: "h1" is evicted.
        db.create_access_token_hash(user, "h4", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
        );

        db.create_access_token_hash(user, "h5", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
        );
    }
1546
1547 #[test]
1548 fn test_fuzzy_like_string() {
1549 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1550 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1551 assert_eq!(fuzzy_like_string(" z "), "%z%");
1552 }
1553
    /// Fuzzy login search against Postgres (the fake leaves it unimplemented).
    /// NOTE(review): the expected results imply case-insensitive matching
    /// ("clr" matches "California") and a ranking that is not alphabetical.
    /// Confirm against `PostgresDb::fuzzy_search_users`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_fuzzy_search_users() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();
        for github_login in [
            "California",
            "colorado",
            "oregon",
            "washington",
            "florida",
            "delaware",
            "rhode-island",
        ] {
            db.create_user(github_login, None, false).await.unwrap();
        }

        assert_eq!(
            fuzzy_search_user_names(db, "clr").await,
            &["colorado", "California"]
        );
        assert_eq!(
            fuzzy_search_user_names(db, "ro").await,
            &["rhode-island", "colorado", "oregon"],
        );

        // Runs a search with a limit of 10 and keeps only the matching logins.
        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
            db.fuzzy_search_users(query, 10)
                .await
                .unwrap()
                .into_iter()
                .map(|user| user.github_login)
                .collect::<Vec<_>>()
        }
    }
1588
    /// Walks through the contact-request lifecycle on both backends: request,
    /// dismiss, accept, reciprocal requests, and decline. Every user always
    /// lists themselves as an accepted contact.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // User starts with no contacts other than themselves.
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again.
            // (The requester has no notification to dismiss while the request is pending.)
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            // Only the requester (user 1) is flagged to be notified of the acceptance.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
1808
    /// Invite-code lifecycle on Postgres. Setting a positive invite count
    /// creates a code. Each redemption creates a new user who becomes an
    /// accepted contact of the inviter, and decrements the count. The code
    /// stays stable when the count is raised again.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        // The inviter is flagged to be notified; the new user is not.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes.
        // A failed redemption must not consume an invite.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
1961
    /// A database handle for tests, backed either by a throwaway Postgres
    /// database or by the in-memory `FakeDb`. It is torn down on drop.
    pub struct TestDb {
        /// The database under test. Only `Drop` takes it out; while the
        /// `TestDb` is in use it is always `Some`.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres database; empty for the fake backend.
        pub url: String,
    }
1966
    impl TestDb {
        /// Creates a uniquely named Postgres database (`zed-test-<random u128>`)
        /// on `postgres://postgres@localhost` and applies the crate's
        /// `migrations` directory to it.
        ///
        /// # Panics
        /// Panics if the database cannot be created, connected to, or migrated.
        pub async fn postgres() -> Self {
            // Serializes database creation and migration across concurrently running tests.
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // NOTE(review): this synchronous guard is held across the `.await`s
            // below. That blocks only this test's thread while another test sets
            // up its database, but it is easy to trip over if this ever moves onto
            // a shared runtime.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            Postgres::create_database(&url)
                .await
                .expect("failed to create test db");
            // Presumably `5` is the pool's max connection count — TODO confirm against `PostgresDb::new`.
            let db = PostgresDb::new(&url, 5).await.unwrap();
            let migrator = Migrator::new(migrations_path).await.unwrap();
            migrator.run(&db.pool).await.unwrap();
            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Wraps a fresh in-memory `FakeDb` whose simulated delays are driven by `background`.
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                url: Default::default(),
            }
        }

        /// Returns the database under test.
        ///
        /// # Panics
        /// Panics if called after the database has been taken (only `Drop` does that).
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }
2001
    impl Drop for TestDb {
        /// Synchronously runs the backend's `teardown` for `self.url`. `Drop`
        /// cannot be async, so this blocks on the future.
        /// NOTE(review): presumably this drops the Postgres test database and is
        /// a no-op for the fake — confirm in the `teardown` implementations.
        fn drop(&mut self) {
            if let Some(db) = self.db.take() {
                futures::executor::block_on(db.teardown(&self.url));
            }
        }
    }
2009
    /// In-memory implementation of `Db` for tests. Each table is a
    /// mutex-guarded map, and most operations first await a random delay from
    /// `background` to shake out ordering assumptions.
    pub struct FakeDb {
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        /// File counts keyed by (project, worktree id, file extension).
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), usize>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// Org membership; the value is the member's admin flag.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// Channel membership; presumably the value is an admin flag, mirroring
        /// `org_memberships` — TODO confirm.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Contact relationships, one entry per pair of users.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Monotonic id counters, starting at 1.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2027
    /// A contact relationship stored by `FakeDb`, directed from the user who
    /// sent the request to the user who received it.
    #[derive(Debug)]
    pub struct FakeContact {
        /// User who sent the contact request.
        pub requester_id: UserId,
        /// User who received the contact request.
        pub responder_id: UserId,
        /// Whether the request has been accepted.
        pub accepted: bool,
        /// While pending: whether the responder should be notified of the request.
        /// Once accepted: whether the requester should be notified of the acceptance.
        pub should_notify: bool,
    }
2035
2036 impl FakeDb {
2037 pub fn new(background: Arc<Background>) -> Self {
2038 Self {
2039 background,
2040 users: Default::default(),
2041 next_user_id: Mutex::new(1),
2042 projects: Default::default(),
2043 worktree_extensions: Default::default(),
2044 next_project_id: Mutex::new(1),
2045 orgs: Default::default(),
2046 next_org_id: Mutex::new(1),
2047 org_memberships: Default::default(),
2048 channels: Default::default(),
2049 next_channel_id: Mutex::new(1),
2050 channel_memberships: Default::default(),
2051 channel_messages: Default::default(),
2052 next_channel_message_id: Mutex::new(1),
2053 contacts: Default::default(),
2054 }
2055 }
2056 }
2057
2058 #[async_trait]
2059 impl Db for FakeDb {
2060 async fn create_user(
2061 &self,
2062 github_login: &str,
2063 email_address: Option<&str>,
2064 admin: bool,
2065 ) -> Result<UserId> {
2066 self.background.simulate_random_delay().await;
2067
2068 let mut users = self.users.lock();
2069 if let Some(user) = users
2070 .values()
2071 .find(|user| user.github_login == github_login)
2072 {
2073 Ok(user.id)
2074 } else {
2075 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2076 users.insert(
2077 user_id,
2078 User {
2079 id: user_id,
2080 github_login: github_login.to_string(),
2081 email_address: email_address.map(str::to_string),
2082 admin,
2083 invite_code: None,
2084 invite_count: 0,
2085 connected_once: false,
2086 },
2087 );
2088 Ok(user_id)
2089 }
2090 }
2091
        // Not implemented by the fake; panics if called.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2095
        // Not implemented by the fake; panics if called (`test_create_users` runs on Postgres only).
        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
            unimplemented!()
        }
2099
        // Not implemented by the fake; panics if called (`test_fuzzy_search_users` runs on Postgres only).
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2103
2104 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2105 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2106 }
2107
2108 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2109 self.background.simulate_random_delay().await;
2110 let users = self.users.lock();
2111 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2112 }
2113
2114 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2115 Ok(self
2116 .users
2117 .lock()
2118 .values()
2119 .find(|user| user.github_login == github_login)
2120 .cloned())
2121 }
2122
        // Not implemented by the fake; panics if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
2126
2127 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2128 self.background.simulate_random_delay().await;
2129 let mut users = self.users.lock();
2130 let mut user = users
2131 .get_mut(&id)
2132 .ok_or_else(|| anyhow!("user not found"))?;
2133 user.connected_once = connected_once;
2134 Ok(())
2135 }
2136
        // Not implemented by the fake; panics if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
2140
2141 // invite codes
2142
        // Not implemented by the fake; panics if called (`test_invite_codes` runs on Postgres only).
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
2146
        // The fake never assigns invite codes (users are created with `invite_code: None`
        // and `set_invite_count` is unimplemented), so every user reports having none.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            Ok(None)
        }
2150
        // Not implemented by the fake; panics if called.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }
2154
        // Not implemented by the fake; panics if called (`test_invite_codes` runs on Postgres only).
        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
2163
2164 // projects
2165
2166 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2167 self.background.simulate_random_delay().await;
2168 if !self.users.lock().contains_key(&host_user_id) {
2169 Err(anyhow!("no such user"))?;
2170 }
2171
2172 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2173 self.projects.lock().insert(
2174 project_id,
2175 Project {
2176 id: project_id,
2177 host_user_id,
2178 unregistered: false,
2179 },
2180 );
2181 Ok(project_id)
2182 }
2183
2184 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2185 self.projects
2186 .lock()
2187 .get_mut(&project_id)
2188 .ok_or_else(|| anyhow!("no such project"))?
2189 .unregistered = true;
2190 Ok(())
2191 }
2192
2193 async fn update_worktree_extensions(
2194 &self,
2195 project_id: ProjectId,
2196 worktree_id: u64,
2197 extensions: HashMap<String, usize>,
2198 ) -> Result<()> {
2199 self.background.simulate_random_delay().await;
2200 if !self.projects.lock().contains_key(&project_id) {
2201 Err(anyhow!("no such project"))?;
2202 }
2203
2204 for (extension, count) in extensions {
2205 self.worktree_extensions
2206 .lock()
2207 .insert((project_id, worktree_id, extension), count);
2208 }
2209
2210 Ok(())
2211 }
2212
        // Not implemented by the fake; panics if called.
        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }
2219
        // Not implemented by the fake; panics if called (`test_project_activity` runs on Postgres only).
        async fn record_project_activity(
            &self,
            _period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }
2227
        // Not implemented by the fake; panics if called (`test_project_activity` runs on Postgres only).
        async fn summarize_project_activity(
            &self,
            _period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }
2235
2236 // contacts
2237
2238 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2239 self.background.simulate_random_delay().await;
2240 let mut contacts = vec![Contact::Accepted {
2241 user_id: id,
2242 should_notify: false,
2243 }];
2244
2245 for contact in self.contacts.lock().iter() {
2246 if contact.requester_id == id {
2247 if contact.accepted {
2248 contacts.push(Contact::Accepted {
2249 user_id: contact.responder_id,
2250 should_notify: contact.should_notify,
2251 });
2252 } else {
2253 contacts.push(Contact::Outgoing {
2254 user_id: contact.responder_id,
2255 });
2256 }
2257 } else if contact.responder_id == id {
2258 if contact.accepted {
2259 contacts.push(Contact::Accepted {
2260 user_id: contact.requester_id,
2261 should_notify: false,
2262 });
2263 } else {
2264 contacts.push(Contact::Incoming {
2265 user_id: contact.requester_id,
2266 should_notify: contact.should_notify,
2267 });
2268 }
2269 }
2270 }
2271
2272 contacts.sort_unstable_by_key(|contact| contact.user_id());
2273 Ok(contacts)
2274 }
2275
2276 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2277 self.background.simulate_random_delay().await;
2278 Ok(self.contacts.lock().iter().any(|contact| {
2279 contact.accepted
2280 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2281 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2282 }))
2283 }
2284
2285 async fn send_contact_request(
2286 &self,
2287 requester_id: UserId,
2288 responder_id: UserId,
2289 ) -> Result<()> {
2290 let mut contacts = self.contacts.lock();
2291 for contact in contacts.iter_mut() {
2292 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2293 if contact.accepted {
2294 Err(anyhow!("contact already exists"))?;
2295 } else {
2296 Err(anyhow!("contact already requested"))?;
2297 }
2298 }
2299 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2300 if contact.accepted {
2301 Err(anyhow!("contact already exists"))?;
2302 } else {
2303 contact.accepted = true;
2304 contact.should_notify = false;
2305 return Ok(());
2306 }
2307 }
2308 }
2309 contacts.push(FakeContact {
2310 requester_id,
2311 responder_id,
2312 accepted: false,
2313 should_notify: true,
2314 });
2315 Ok(())
2316 }
2317
2318 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2319 self.contacts.lock().retain(|contact| {
2320 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2321 });
2322 Ok(())
2323 }
2324
2325 async fn dismiss_contact_notification(
2326 &self,
2327 user_id: UserId,
2328 contact_user_id: UserId,
2329 ) -> Result<()> {
2330 let mut contacts = self.contacts.lock();
2331 for contact in contacts.iter_mut() {
2332 if contact.requester_id == contact_user_id
2333 && contact.responder_id == user_id
2334 && !contact.accepted
2335 {
2336 contact.should_notify = false;
2337 return Ok(());
2338 }
2339 if contact.requester_id == user_id
2340 && contact.responder_id == contact_user_id
2341 && contact.accepted
2342 {
2343 contact.should_notify = false;
2344 return Ok(());
2345 }
2346 }
2347 Err(anyhow!("no such notification"))?
2348 }
2349
2350 async fn respond_to_contact_request(
2351 &self,
2352 responder_id: UserId,
2353 requester_id: UserId,
2354 accept: bool,
2355 ) -> Result<()> {
2356 let mut contacts = self.contacts.lock();
2357 for (ix, contact) in contacts.iter_mut().enumerate() {
2358 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2359 if contact.accepted {
2360 Err(anyhow!("contact already confirmed"))?;
2361 }
2362 if accept {
2363 contact.accepted = true;
2364 contact.should_notify = true;
2365 } else {
2366 contacts.remove(ix);
2367 }
2368 return Ok(());
2369 }
2370 }
2371 Err(anyhow!("no such contact request"))?
2372 }
2373
2374 async fn create_access_token_hash(
2375 &self,
2376 _user_id: UserId,
2377 _access_token_hash: &str,
2378 _max_access_token_count: usize,
2379 ) -> Result<()> {
2380 unimplemented!()
2381 }
2382
2383 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
2384 unimplemented!()
2385 }
2386
2387 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2388 unimplemented!()
2389 }
2390
2391 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2392 self.background.simulate_random_delay().await;
2393 let mut orgs = self.orgs.lock();
2394 if orgs.values().any(|org| org.slug == slug) {
2395 Err(anyhow!("org already exists"))?
2396 } else {
2397 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2398 orgs.insert(
2399 org_id,
2400 Org {
2401 id: org_id,
2402 name: name.to_string(),
2403 slug: slug.to_string(),
2404 },
2405 );
2406 Ok(org_id)
2407 }
2408 }
2409
2410 async fn add_org_member(
2411 &self,
2412 org_id: OrgId,
2413 user_id: UserId,
2414 is_admin: bool,
2415 ) -> Result<()> {
2416 self.background.simulate_random_delay().await;
2417 if !self.orgs.lock().contains_key(&org_id) {
2418 Err(anyhow!("org does not exist"))?;
2419 }
2420 if !self.users.lock().contains_key(&user_id) {
2421 Err(anyhow!("user does not exist"))?;
2422 }
2423
2424 self.org_memberships
2425 .lock()
2426 .entry((org_id, user_id))
2427 .or_insert(is_admin);
2428 Ok(())
2429 }
2430
2431 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2432 self.background.simulate_random_delay().await;
2433 if !self.orgs.lock().contains_key(&org_id) {
2434 Err(anyhow!("org does not exist"))?;
2435 }
2436
2437 let mut channels = self.channels.lock();
2438 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2439 channels.insert(
2440 channel_id,
2441 Channel {
2442 id: channel_id,
2443 name: name.to_string(),
2444 owner_id: org_id.0,
2445 owner_is_user: false,
2446 },
2447 );
2448 Ok(channel_id)
2449 }
2450
2451 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2452 self.background.simulate_random_delay().await;
2453 Ok(self
2454 .channels
2455 .lock()
2456 .values()
2457 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2458 .cloned()
2459 .collect())
2460 }
2461
2462 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2463 self.background.simulate_random_delay().await;
2464 let channels = self.channels.lock();
2465 let memberships = self.channel_memberships.lock();
2466 Ok(channels
2467 .values()
2468 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2469 .cloned()
2470 .collect())
2471 }
2472
2473 async fn can_user_access_channel(
2474 &self,
2475 user_id: UserId,
2476 channel_id: ChannelId,
2477 ) -> Result<bool> {
2478 self.background.simulate_random_delay().await;
2479 Ok(self
2480 .channel_memberships
2481 .lock()
2482 .contains_key(&(channel_id, user_id)))
2483 }
2484
2485 async fn add_channel_member(
2486 &self,
2487 channel_id: ChannelId,
2488 user_id: UserId,
2489 is_admin: bool,
2490 ) -> Result<()> {
2491 self.background.simulate_random_delay().await;
2492 if !self.channels.lock().contains_key(&channel_id) {
2493 Err(anyhow!("channel does not exist"))?;
2494 }
2495 if !self.users.lock().contains_key(&user_id) {
2496 Err(anyhow!("user does not exist"))?;
2497 }
2498
2499 self.channel_memberships
2500 .lock()
2501 .entry((channel_id, user_id))
2502 .or_insert(is_admin);
2503 Ok(())
2504 }
2505
2506 async fn create_channel_message(
2507 &self,
2508 channel_id: ChannelId,
2509 sender_id: UserId,
2510 body: &str,
2511 timestamp: OffsetDateTime,
2512 nonce: u128,
2513 ) -> Result<MessageId> {
2514 self.background.simulate_random_delay().await;
2515 if !self.channels.lock().contains_key(&channel_id) {
2516 Err(anyhow!("channel does not exist"))?;
2517 }
2518 if !self.users.lock().contains_key(&sender_id) {
2519 Err(anyhow!("user does not exist"))?;
2520 }
2521
2522 let mut messages = self.channel_messages.lock();
2523 if let Some(message) = messages
2524 .values()
2525 .find(|message| message.nonce.as_u128() == nonce)
2526 {
2527 Ok(message.id)
2528 } else {
2529 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2530 messages.insert(
2531 message_id,
2532 ChannelMessage {
2533 id: message_id,
2534 channel_id,
2535 sender_id,
2536 body: body.to_string(),
2537 sent_at: timestamp,
2538 nonce: Uuid::from_u128(nonce),
2539 },
2540 );
2541 Ok(message_id)
2542 }
2543 }
2544
2545 async fn get_channel_messages(
2546 &self,
2547 channel_id: ChannelId,
2548 count: usize,
2549 before_id: Option<MessageId>,
2550 ) -> Result<Vec<ChannelMessage>> {
2551 let mut messages = self
2552 .channel_messages
2553 .lock()
2554 .values()
2555 .rev()
2556 .filter(|message| {
2557 message.channel_id == channel_id
2558 && message.id < before_id.unwrap_or(MessageId::MAX)
2559 })
2560 .take(count)
2561 .cloned()
2562 .collect::<Vec<_>>();
2563 messages.sort_unstable_by_key(|message| message.id);
2564 Ok(messages)
2565 }
2566
2567 async fn teardown(&self, _: &str) {}
2568
2569 #[cfg(test)]
2570 fn as_fake(&self) -> Option<&FakeDb> {
2571 Some(self)
2572 }
2573 }
2574}