1use std::{ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use nanoid::nanoid;
10use serde::Serialize;
11pub use sqlx::postgres::PgPoolOptions as DbOptions;
12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
13use time::OffsetDateTime;
14
15#[async_trait]
16pub trait Db: Send + Sync {
17 async fn create_user(
18 &self,
19 github_login: &str,
20 email_address: Option<&str>,
21 admin: bool,
22 ) -> Result<UserId>;
23 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
24 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
25 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
26 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
27 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
28 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
29 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
30 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
31 async fn destroy_user(&self, id: UserId) -> Result<()>;
32
33 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
34 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
35 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
36 async fn redeem_invite_code(
37 &self,
38 code: &str,
39 login: &str,
40 email_address: Option<&str>,
41 ) -> Result<UserId>;
42
43 /// Registers a new project for the given user.
44 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
45
46 /// Unregisters a project for the given project id.
47 async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
48
49 /// Update file counts by extension for the given project and worktree.
50 async fn update_worktree_extensions(
51 &self,
52 project_id: ProjectId,
53 worktree_id: u64,
54 extensions: HashMap<String, usize>,
55 ) -> Result<()>;
56
57 /// Get the file counts on the given project keyed by their worktree and extension.
58 async fn get_project_extensions(
59 &self,
60 project_id: ProjectId,
61 ) -> Result<HashMap<u64, HashMap<String, usize>>>;
62
63 /// Record which users have been active in which projects during
64 /// a given period of time.
65 async fn record_project_activity(
66 &self,
67 time_period: Range<OffsetDateTime>,
68 active_projects: &[(UserId, ProjectId)],
69 ) -> Result<()>;
70
71 /// Get the users that have been most active during the given time period,
72 /// along with the amount of time they have been active in each project.
73 async fn summarize_project_activity(
74 &self,
75 time_period: Range<OffsetDateTime>,
76 max_user_count: usize,
77 ) -> Result<Vec<UserActivitySummary>>;
78
79 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
80 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
81 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
82 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
83 async fn dismiss_contact_notification(
84 &self,
85 responder_id: UserId,
86 requester_id: UserId,
87 ) -> Result<()>;
88 async fn respond_to_contact_request(
89 &self,
90 responder_id: UserId,
91 requester_id: UserId,
92 accept: bool,
93 ) -> Result<()>;
94
95 async fn create_access_token_hash(
96 &self,
97 user_id: UserId,
98 access_token_hash: &str,
99 max_access_token_count: usize,
100 ) -> Result<()>;
101 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
102 #[cfg(any(test, feature = "seed-support"))]
103
104 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
105 #[cfg(any(test, feature = "seed-support"))]
106 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
107 #[cfg(any(test, feature = "seed-support"))]
108 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
109 #[cfg(any(test, feature = "seed-support"))]
110 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
111 #[cfg(any(test, feature = "seed-support"))]
112
113 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
114 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
115 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
116 -> Result<bool>;
117 #[cfg(any(test, feature = "seed-support"))]
118 async fn add_channel_member(
119 &self,
120 channel_id: ChannelId,
121 user_id: UserId,
122 is_admin: bool,
123 ) -> Result<()>;
124 async fn create_channel_message(
125 &self,
126 channel_id: ChannelId,
127 sender_id: UserId,
128 body: &str,
129 timestamp: OffsetDateTime,
130 nonce: u128,
131 ) -> Result<MessageId>;
132 async fn get_channel_messages(
133 &self,
134 channel_id: ChannelId,
135 count: usize,
136 before_id: Option<MessageId>,
137 ) -> Result<Vec<ChannelMessage>>;
138 #[cfg(test)]
139 async fn teardown(&self, url: &str);
140 #[cfg(test)]
141 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
142}
143
144pub struct PostgresDb {
145 pool: sqlx::PgPool,
146}
147
148impl PostgresDb {
149 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
150 let pool = DbOptions::new()
151 .max_connections(max_connections)
152 .connect(&url)
153 .await
154 .context("failed to connect to postgres database")?;
155 Ok(Self { pool })
156 }
157}
158
159#[async_trait]
160impl Db for PostgresDb {
161 // users
162
163 async fn create_user(
164 &self,
165 github_login: &str,
166 email_address: Option<&str>,
167 admin: bool,
168 ) -> Result<UserId> {
169 let query = "
170 INSERT INTO users (github_login, email_address, admin)
171 VALUES ($1, $2, $3)
172 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
173 RETURNING id
174 ";
175 Ok(sqlx::query_scalar(query)
176 .bind(github_login)
177 .bind(email_address)
178 .bind(admin)
179 .fetch_one(&self.pool)
180 .await
181 .map(UserId)?)
182 }
183
184 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
185 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
186 Ok(sqlx::query_as(query)
187 .bind(limit as i32)
188 .bind((page * limit) as i32)
189 .fetch_all(&self.pool)
190 .await?)
191 }
192
193 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
194 let mut query = QueryBuilder::new(
195 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
196 );
197 query.push_values(
198 users,
199 |mut query, (github_login, email_address, invite_count)| {
200 query
201 .push_bind(github_login)
202 .push_bind(email_address)
203 .push_bind(false)
204 .push_bind(nanoid!(16))
205 .push_bind(invite_count as i32);
206 },
207 );
208 query.push(
209 "
210 ON CONFLICT (github_login) DO UPDATE SET
211 github_login = excluded.github_login,
212 invite_count = excluded.invite_count,
213 invite_code = CASE WHEN users.invite_code IS NULL
214 THEN excluded.invite_code
215 ELSE users.invite_code
216 END
217 RETURNING id
218 ",
219 );
220
221 let rows = query.build().fetch_all(&self.pool).await?;
222 Ok(rows
223 .into_iter()
224 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
225 .collect())
226 }
227
228 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
229 let like_string = fuzzy_like_string(name_query);
230 let query = "
231 SELECT users.*
232 FROM users
233 WHERE github_login ILIKE $1
234 ORDER BY github_login <-> $2
235 LIMIT $3
236 ";
237 Ok(sqlx::query_as(query)
238 .bind(like_string)
239 .bind(name_query)
240 .bind(limit as i32)
241 .fetch_all(&self.pool)
242 .await?)
243 }
244
245 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
246 let users = self.get_users_by_ids(vec![id]).await?;
247 Ok(users.into_iter().next())
248 }
249
250 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
251 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
252 let query = "
253 SELECT users.*
254 FROM users
255 WHERE users.id = ANY ($1)
256 ";
257 Ok(sqlx::query_as(query)
258 .bind(&ids)
259 .fetch_all(&self.pool)
260 .await?)
261 }
262
263 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
264 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
265 Ok(sqlx::query_as(query)
266 .bind(github_login)
267 .fetch_optional(&self.pool)
268 .await?)
269 }
270
271 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
272 let query = "UPDATE users SET admin = $1 WHERE id = $2";
273 Ok(sqlx::query(query)
274 .bind(is_admin)
275 .bind(id.0)
276 .execute(&self.pool)
277 .await
278 .map(drop)?)
279 }
280
281 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
282 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
283 Ok(sqlx::query(query)
284 .bind(connected_once)
285 .bind(id.0)
286 .execute(&self.pool)
287 .await
288 .map(drop)?)
289 }
290
291 async fn destroy_user(&self, id: UserId) -> Result<()> {
292 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
293 sqlx::query(query)
294 .bind(id.0)
295 .execute(&self.pool)
296 .await
297 .map(drop)?;
298 let query = "DELETE FROM users WHERE id = $1;";
299 Ok(sqlx::query(query)
300 .bind(id.0)
301 .execute(&self.pool)
302 .await
303 .map(drop)?)
304 }
305
306 // invite codes
307
308 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
309 let mut tx = self.pool.begin().await?;
310 if count > 0 {
311 sqlx::query(
312 "
313 UPDATE users
314 SET invite_code = $1
315 WHERE id = $2 AND invite_code IS NULL
316 ",
317 )
318 .bind(nanoid!(16))
319 .bind(id)
320 .execute(&mut tx)
321 .await?;
322 }
323
324 sqlx::query(
325 "
326 UPDATE users
327 SET invite_count = $1
328 WHERE id = $2
329 ",
330 )
331 .bind(count as i32)
332 .bind(id)
333 .execute(&mut tx)
334 .await?;
335 tx.commit().await?;
336 Ok(())
337 }
338
339 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
340 let result: Option<(String, i32)> = sqlx::query_as(
341 "
342 SELECT invite_code, invite_count
343 FROM users
344 WHERE id = $1 AND invite_code IS NOT NULL
345 ",
346 )
347 .bind(id)
348 .fetch_optional(&self.pool)
349 .await?;
350 if let Some((code, count)) = result {
351 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
352 } else {
353 Ok(None)
354 }
355 }
356
357 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
358 sqlx::query_as(
359 "
360 SELECT *
361 FROM users
362 WHERE invite_code = $1
363 ",
364 )
365 .bind(code)
366 .fetch_optional(&self.pool)
367 .await?
368 .ok_or_else(|| {
369 Error::Http(
370 StatusCode::NOT_FOUND,
371 "that invite code does not exist".to_string(),
372 )
373 })
374 }
375
376 async fn redeem_invite_code(
377 &self,
378 code: &str,
379 login: &str,
380 email_address: Option<&str>,
381 ) -> Result<UserId> {
382 let mut tx = self.pool.begin().await?;
383
384 let inviter_id: Option<UserId> = sqlx::query_scalar(
385 "
386 UPDATE users
387 SET invite_count = invite_count - 1
388 WHERE
389 invite_code = $1 AND
390 invite_count > 0
391 RETURNING id
392 ",
393 )
394 .bind(code)
395 .fetch_optional(&mut tx)
396 .await?;
397
398 let inviter_id = match inviter_id {
399 Some(inviter_id) => inviter_id,
400 None => {
401 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
402 .bind(code)
403 .fetch_optional(&mut tx)
404 .await?
405 .is_some()
406 {
407 Err(Error::Http(
408 StatusCode::UNAUTHORIZED,
409 "no invites remaining".to_string(),
410 ))?
411 } else {
412 Err(Error::Http(
413 StatusCode::NOT_FOUND,
414 "invite code not found".to_string(),
415 ))?
416 }
417 }
418 };
419
420 let invitee_id = sqlx::query_scalar(
421 "
422 INSERT INTO users
423 (github_login, email_address, admin, inviter_id)
424 VALUES
425 ($1, $2, 'f', $3)
426 RETURNING id
427 ",
428 )
429 .bind(login)
430 .bind(email_address)
431 .bind(inviter_id)
432 .fetch_one(&mut tx)
433 .await
434 .map(UserId)?;
435
436 sqlx::query(
437 "
438 INSERT INTO contacts
439 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
440 VALUES
441 ($1, $2, 't', 't', 't')
442 ",
443 )
444 .bind(inviter_id)
445 .bind(invitee_id)
446 .execute(&mut tx)
447 .await?;
448
449 tx.commit().await?;
450 Ok(invitee_id)
451 }
452
453 // projects
454
455 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
456 Ok(sqlx::query_scalar(
457 "
458 INSERT INTO projects(host_user_id)
459 VALUES ($1)
460 RETURNING id
461 ",
462 )
463 .bind(host_user_id)
464 .fetch_one(&self.pool)
465 .await
466 .map(ProjectId)?)
467 }
468
469 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
470 sqlx::query(
471 "
472 UPDATE projects
473 SET unregistered = 't'
474 WHERE id = $1
475 ",
476 )
477 .bind(project_id)
478 .execute(&self.pool)
479 .await?;
480 Ok(())
481 }
482
483 async fn update_worktree_extensions(
484 &self,
485 project_id: ProjectId,
486 worktree_id: u64,
487 extensions: HashMap<String, usize>,
488 ) -> Result<()> {
489 if extensions.is_empty() {
490 return Ok(());
491 }
492
493 let mut query = QueryBuilder::new(
494 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
495 );
496 query.push_values(extensions, |mut query, (extension, count)| {
497 query
498 .push_bind(project_id)
499 .push_bind(worktree_id as i32)
500 .push_bind(extension)
501 .push_bind(count as i32);
502 });
503 query.push(
504 "
505 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
506 count = excluded.count
507 ",
508 );
509 query.build().execute(&self.pool).await?;
510
511 Ok(())
512 }
513
514 async fn get_project_extensions(
515 &self,
516 project_id: ProjectId,
517 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
518 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
519 struct WorktreeExtension {
520 worktree_id: i32,
521 extension: String,
522 count: i32,
523 }
524
525 let query = "
526 SELECT worktree_id, extension, count
527 FROM worktree_extensions
528 WHERE project_id = $1
529 ";
530 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
531 .bind(&project_id)
532 .fetch_all(&self.pool)
533 .await?;
534
535 let mut extension_counts = HashMap::default();
536 for count in counts {
537 extension_counts
538 .entry(count.worktree_id as u64)
539 .or_insert(HashMap::default())
540 .insert(count.extension, count.count as usize);
541 }
542 Ok(extension_counts)
543 }
544
545 async fn record_project_activity(
546 &self,
547 time_period: Range<OffsetDateTime>,
548 projects: &[(UserId, ProjectId)],
549 ) -> Result<()> {
550 let query = "
551 INSERT INTO project_activity_periods
552 (ended_at, duration_millis, user_id, project_id)
553 VALUES
554 ($1, $2, $3, $4);
555 ";
556
557 let mut tx = self.pool.begin().await?;
558 let duration_millis =
559 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
560 for (user_id, project_id) in projects {
561 sqlx::query(query)
562 .bind(time_period.end)
563 .bind(duration_millis)
564 .bind(user_id)
565 .bind(project_id)
566 .execute(&mut tx)
567 .await?;
568 }
569 tx.commit().await?;
570
571 Ok(())
572 }
573
574 async fn summarize_project_activity(
575 &self,
576 time_period: Range<OffsetDateTime>,
577 max_user_count: usize,
578 ) -> Result<Vec<UserActivitySummary>> {
579 let query = "
580 WITH
581 project_durations AS (
582 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
583 FROM project_activity_periods
584 WHERE $1 <= ended_at AND ended_at <= $2
585 GROUP BY user_id, project_id
586 ),
587 user_durations AS (
588 SELECT user_id, SUM(project_duration) as total_duration
589 FROM project_durations
590 GROUP BY user_id
591 ORDER BY total_duration DESC
592 LIMIT $3
593 )
594 SELECT user_durations.user_id, users.github_login, project_id, project_duration
595 FROM user_durations, project_durations, users
596 WHERE
597 user_durations.user_id = project_durations.user_id AND
598 user_durations.user_id = users.id
599 ORDER BY user_id ASC, project_duration DESC
600 ";
601
602 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64)>(query)
603 .bind(time_period.start)
604 .bind(time_period.end)
605 .bind(max_user_count as i32)
606 .fetch(&self.pool);
607
608 let mut result = Vec::<UserActivitySummary>::new();
609 while let Some(row) = rows.next().await {
610 let (user_id, github_login, project_id, duration_millis) = row?;
611 let project_id = project_id;
612 let duration = Duration::from_millis(duration_millis as u64);
613 if let Some(last_summary) = result.last_mut() {
614 if last_summary.id == user_id {
615 last_summary.project_activity.push((project_id, duration));
616 continue;
617 }
618 }
619 result.push(UserActivitySummary {
620 id: user_id,
621 project_activity: vec![(project_id, duration)],
622 github_login,
623 });
624 }
625
626 Ok(result)
627 }
628
629 // contacts
630
631 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
632 let query = "
633 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
634 FROM contacts
635 WHERE user_id_a = $1 OR user_id_b = $1;
636 ";
637
638 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
639 .bind(user_id)
640 .fetch(&self.pool);
641
642 let mut contacts = vec![Contact::Accepted {
643 user_id,
644 should_notify: false,
645 }];
646 while let Some(row) = rows.next().await {
647 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
648
649 if user_id_a == user_id {
650 if accepted {
651 contacts.push(Contact::Accepted {
652 user_id: user_id_b,
653 should_notify: should_notify && a_to_b,
654 });
655 } else if a_to_b {
656 contacts.push(Contact::Outgoing { user_id: user_id_b })
657 } else {
658 contacts.push(Contact::Incoming {
659 user_id: user_id_b,
660 should_notify,
661 });
662 }
663 } else {
664 if accepted {
665 contacts.push(Contact::Accepted {
666 user_id: user_id_a,
667 should_notify: should_notify && !a_to_b,
668 });
669 } else if a_to_b {
670 contacts.push(Contact::Incoming {
671 user_id: user_id_a,
672 should_notify,
673 });
674 } else {
675 contacts.push(Contact::Outgoing { user_id: user_id_a });
676 }
677 }
678 }
679
680 contacts.sort_unstable_by_key(|contact| contact.user_id());
681
682 Ok(contacts)
683 }
684
685 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
686 let (id_a, id_b) = if user_id_1 < user_id_2 {
687 (user_id_1, user_id_2)
688 } else {
689 (user_id_2, user_id_1)
690 };
691
692 let query = "
693 SELECT 1 FROM contacts
694 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
695 LIMIT 1
696 ";
697 Ok(sqlx::query_scalar::<_, i32>(query)
698 .bind(id_a.0)
699 .bind(id_b.0)
700 .fetch_optional(&self.pool)
701 .await?
702 .is_some())
703 }
704
705 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
706 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
707 (sender_id, receiver_id, true)
708 } else {
709 (receiver_id, sender_id, false)
710 };
711 let query = "
712 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
713 VALUES ($1, $2, $3, 'f', 't')
714 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
715 SET
716 accepted = 't',
717 should_notify = 'f'
718 WHERE
719 NOT contacts.accepted AND
720 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
721 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
722 ";
723 let result = sqlx::query(query)
724 .bind(id_a.0)
725 .bind(id_b.0)
726 .bind(a_to_b)
727 .execute(&self.pool)
728 .await?;
729
730 if result.rows_affected() == 1 {
731 Ok(())
732 } else {
733 Err(anyhow!("contact already requested"))?
734 }
735 }
736
737 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
738 let (id_a, id_b) = if responder_id < requester_id {
739 (responder_id, requester_id)
740 } else {
741 (requester_id, responder_id)
742 };
743 let query = "
744 DELETE FROM contacts
745 WHERE user_id_a = $1 AND user_id_b = $2;
746 ";
747 let result = sqlx::query(query)
748 .bind(id_a.0)
749 .bind(id_b.0)
750 .execute(&self.pool)
751 .await?;
752
753 if result.rows_affected() == 1 {
754 Ok(())
755 } else {
756 Err(anyhow!("no such contact"))?
757 }
758 }
759
760 async fn dismiss_contact_notification(
761 &self,
762 user_id: UserId,
763 contact_user_id: UserId,
764 ) -> Result<()> {
765 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
766 (user_id, contact_user_id, true)
767 } else {
768 (contact_user_id, user_id, false)
769 };
770
771 let query = "
772 UPDATE contacts
773 SET should_notify = 'f'
774 WHERE
775 user_id_a = $1 AND user_id_b = $2 AND
776 (
777 (a_to_b = $3 AND accepted) OR
778 (a_to_b != $3 AND NOT accepted)
779 );
780 ";
781
782 let result = sqlx::query(query)
783 .bind(id_a.0)
784 .bind(id_b.0)
785 .bind(a_to_b)
786 .execute(&self.pool)
787 .await?;
788
789 if result.rows_affected() == 0 {
790 Err(anyhow!("no such contact request"))?;
791 }
792
793 Ok(())
794 }
795
796 async fn respond_to_contact_request(
797 &self,
798 responder_id: UserId,
799 requester_id: UserId,
800 accept: bool,
801 ) -> Result<()> {
802 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
803 (responder_id, requester_id, false)
804 } else {
805 (requester_id, responder_id, true)
806 };
807 let result = if accept {
808 let query = "
809 UPDATE contacts
810 SET accepted = 't', should_notify = 't'
811 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
812 ";
813 sqlx::query(query)
814 .bind(id_a.0)
815 .bind(id_b.0)
816 .bind(a_to_b)
817 .execute(&self.pool)
818 .await?
819 } else {
820 let query = "
821 DELETE FROM contacts
822 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
823 ";
824 sqlx::query(query)
825 .bind(id_a.0)
826 .bind(id_b.0)
827 .bind(a_to_b)
828 .execute(&self.pool)
829 .await?
830 };
831 if result.rows_affected() == 1 {
832 Ok(())
833 } else {
834 Err(anyhow!("no such contact request"))?
835 }
836 }
837
838 // access tokens
839
840 async fn create_access_token_hash(
841 &self,
842 user_id: UserId,
843 access_token_hash: &str,
844 max_access_token_count: usize,
845 ) -> Result<()> {
846 let insert_query = "
847 INSERT INTO access_tokens (user_id, hash)
848 VALUES ($1, $2);
849 ";
850 let cleanup_query = "
851 DELETE FROM access_tokens
852 WHERE id IN (
853 SELECT id from access_tokens
854 WHERE user_id = $1
855 ORDER BY id DESC
856 OFFSET $3
857 )
858 ";
859
860 let mut tx = self.pool.begin().await?;
861 sqlx::query(insert_query)
862 .bind(user_id.0)
863 .bind(access_token_hash)
864 .execute(&mut tx)
865 .await?;
866 sqlx::query(cleanup_query)
867 .bind(user_id.0)
868 .bind(access_token_hash)
869 .bind(max_access_token_count as i32)
870 .execute(&mut tx)
871 .await?;
872 Ok(tx.commit().await?)
873 }
874
875 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
876 let query = "
877 SELECT hash
878 FROM access_tokens
879 WHERE user_id = $1
880 ORDER BY id DESC
881 ";
882 Ok(sqlx::query_scalar(query)
883 .bind(user_id.0)
884 .fetch_all(&self.pool)
885 .await?)
886 }
887
888 // orgs
889
890 #[allow(unused)] // Help rust-analyzer
891 #[cfg(any(test, feature = "seed-support"))]
892 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
893 let query = "
894 SELECT *
895 FROM orgs
896 WHERE slug = $1
897 ";
898 Ok(sqlx::query_as(query)
899 .bind(slug)
900 .fetch_optional(&self.pool)
901 .await?)
902 }
903
904 #[cfg(any(test, feature = "seed-support"))]
905 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
906 let query = "
907 INSERT INTO orgs (name, slug)
908 VALUES ($1, $2)
909 RETURNING id
910 ";
911 Ok(sqlx::query_scalar(query)
912 .bind(name)
913 .bind(slug)
914 .fetch_one(&self.pool)
915 .await
916 .map(OrgId)?)
917 }
918
919 #[cfg(any(test, feature = "seed-support"))]
920 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
921 let query = "
922 INSERT INTO org_memberships (org_id, user_id, admin)
923 VALUES ($1, $2, $3)
924 ON CONFLICT DO NOTHING
925 ";
926 Ok(sqlx::query(query)
927 .bind(org_id.0)
928 .bind(user_id.0)
929 .bind(is_admin)
930 .execute(&self.pool)
931 .await
932 .map(drop)?)
933 }
934
935 // channels
936
937 #[cfg(any(test, feature = "seed-support"))]
938 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
939 let query = "
940 INSERT INTO channels (owner_id, owner_is_user, name)
941 VALUES ($1, false, $2)
942 RETURNING id
943 ";
944 Ok(sqlx::query_scalar(query)
945 .bind(org_id.0)
946 .bind(name)
947 .fetch_one(&self.pool)
948 .await
949 .map(ChannelId)?)
950 }
951
952 #[allow(unused)] // Help rust-analyzer
953 #[cfg(any(test, feature = "seed-support"))]
954 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
955 let query = "
956 SELECT *
957 FROM channels
958 WHERE
959 channels.owner_is_user = false AND
960 channels.owner_id = $1
961 ";
962 Ok(sqlx::query_as(query)
963 .bind(org_id.0)
964 .fetch_all(&self.pool)
965 .await?)
966 }
967
968 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
969 let query = "
970 SELECT
971 channels.*
972 FROM
973 channel_memberships, channels
974 WHERE
975 channel_memberships.user_id = $1 AND
976 channel_memberships.channel_id = channels.id
977 ";
978 Ok(sqlx::query_as(query)
979 .bind(user_id.0)
980 .fetch_all(&self.pool)
981 .await?)
982 }
983
984 async fn can_user_access_channel(
985 &self,
986 user_id: UserId,
987 channel_id: ChannelId,
988 ) -> Result<bool> {
989 let query = "
990 SELECT id
991 FROM channel_memberships
992 WHERE user_id = $1 AND channel_id = $2
993 LIMIT 1
994 ";
995 Ok(sqlx::query_scalar::<_, i32>(query)
996 .bind(user_id.0)
997 .bind(channel_id.0)
998 .fetch_optional(&self.pool)
999 .await
1000 .map(|e| e.is_some())?)
1001 }
1002
1003 #[cfg(any(test, feature = "seed-support"))]
1004 async fn add_channel_member(
1005 &self,
1006 channel_id: ChannelId,
1007 user_id: UserId,
1008 is_admin: bool,
1009 ) -> Result<()> {
1010 let query = "
1011 INSERT INTO channel_memberships (channel_id, user_id, admin)
1012 VALUES ($1, $2, $3)
1013 ON CONFLICT DO NOTHING
1014 ";
1015 Ok(sqlx::query(query)
1016 .bind(channel_id.0)
1017 .bind(user_id.0)
1018 .bind(is_admin)
1019 .execute(&self.pool)
1020 .await
1021 .map(drop)?)
1022 }
1023
1024 // messages
1025
1026 async fn create_channel_message(
1027 &self,
1028 channel_id: ChannelId,
1029 sender_id: UserId,
1030 body: &str,
1031 timestamp: OffsetDateTime,
1032 nonce: u128,
1033 ) -> Result<MessageId> {
1034 let query = "
1035 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1036 VALUES ($1, $2, $3, $4, $5)
1037 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1038 RETURNING id
1039 ";
1040 Ok(sqlx::query_scalar(query)
1041 .bind(channel_id.0)
1042 .bind(sender_id.0)
1043 .bind(body)
1044 .bind(timestamp)
1045 .bind(Uuid::from_u128(nonce))
1046 .fetch_one(&self.pool)
1047 .await
1048 .map(MessageId)?)
1049 }
1050
1051 async fn get_channel_messages(
1052 &self,
1053 channel_id: ChannelId,
1054 count: usize,
1055 before_id: Option<MessageId>,
1056 ) -> Result<Vec<ChannelMessage>> {
1057 let query = r#"
1058 SELECT * FROM (
1059 SELECT
1060 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1061 FROM
1062 channel_messages
1063 WHERE
1064 channel_id = $1 AND
1065 id < $2
1066 ORDER BY id DESC
1067 LIMIT $3
1068 ) as recent_messages
1069 ORDER BY id ASC
1070 "#;
1071 Ok(sqlx::query_as(query)
1072 .bind(channel_id.0)
1073 .bind(before_id.unwrap_or(MessageId::MAX))
1074 .bind(count as i64)
1075 .fetch_all(&self.pool)
1076 .await?)
1077 }
1078
1079 #[cfg(test)]
1080 async fn teardown(&self, url: &str) {
1081 use util::ResultExt;
1082
1083 let query = "
1084 SELECT pg_terminate_backend(pg_stat_activity.pid)
1085 FROM pg_stat_activity
1086 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1087 ";
1088 sqlx::query(query).execute(&self.pool).await.log_err();
1089 self.pool.close().await;
1090 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1091 .await
1092 .log_err();
1093 }
1094
1095 #[cfg(test)]
1096 fn as_fake(&self) -> Option<&tests::FakeDb> {
1097 None
1098 }
1099}
1100
/// Declares a newtype wrapper around an `i32` database id.
///
/// The generated type is transparent to both sqlx and serde. It provides a
/// `MAX` constant, conversions to and from the `u64` ids used in protocol
/// messages, and a `Display` impl that prints the raw integer.
macro_rules! id_type {
    ($id_ty:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $id_ty(pub i32);

        impl $id_ty {
            /// Largest representable id (`i32::MAX`).
            // `allow(unused)`: not every generated id type uses every helper.
            #[allow(unused)]
            pub const MAX: Self = $id_ty(i32::MAX);

            /// Builds an id from a protocol `u64` using a truncating `as` cast.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $id_ty(raw)
            }

            /// Converts the id to a protocol `u64` using an `as` cast.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let raw = self.0;
                raw as u64
            }
        }

        impl std::fmt::Display for $id_ty {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1132
id_type!(UserId);
/// A row of the `users` table, as loaded by `SELECT * FROM users ...` queries.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    /// Written by `set_user_is_admin`.
    pub admin: bool,
    /// Code other people can redeem to sign up; `None` until one is assigned.
    pub invite_code: Option<String>,
    /// Invites left; `redeem_invite_code` decrements it while it is positive.
    pub invite_count: i32,
    /// Written by `set_user_connected_once`.
    pub connected_once: bool,
}
1144
id_type!(ProjectId);
/// A row of the `projects` table, created by `register_project`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    /// Set to true by `unregister_project`; the row itself is kept.
    pub unregistered: bool,
}
1152
/// One user's entry in the result of `summarize_project_activity`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// Time spent per project. `PostgresDb` returns these longest first.
    pub project_activity: Vec<(ProjectId, Duration)>,
}
1159
1160id_type!(OrgId);
1161#[derive(FromRow)]
1162pub struct Org {
1163 pub id: OrgId,
1164 pub name: String,
1165 pub slug: String,
1166}
1167
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owning org when `owner_is_user` is false (see
    /// `create_org_channel`); presumably a user id otherwise — TODO confirm.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1176
id_type!(MessageId);
/// A channel message as returned by `get_channel_messages`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Selected as `sent_at AT TIME ZONE 'UTC'` by `get_channel_messages`.
    pub sent_at: OffsetDateTime,
    /// The `nonce` passed to `create_channel_message`, which deduplicates
    /// inserts on this column.
    pub nonce: Uuid,
}
1187
1188#[derive(Clone, Debug, PartialEq, Eq)]
1189pub enum Contact {
1190 Accepted {
1191 user_id: UserId,
1192 should_notify: bool,
1193 },
1194 Outgoing {
1195 user_id: UserId,
1196 },
1197 Incoming {
1198 user_id: UserId,
1199 should_notify: bool,
1200 },
1201}
1202
1203impl Contact {
1204 pub fn user_id(&self) -> UserId {
1205 match self {
1206 Contact::Accepted { user_id, .. } => *user_id,
1207 Contact::Outgoing { user_id } => *user_id,
1208 Contact::Incoming { user_id, .. } => *user_id,
1209 }
1210 }
1211}
1212
/// A contact request identified by its sender, with a notification flag.
/// NOTE(review): not referenced by the code visible here; presumably used
/// elsewhere in the crate — confirm before changing.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1218
/// Builds an `ILIKE` pattern that matches any text containing the
/// alphanumeric characters of `string`, in order, with anything in between.
///
/// Non-alphanumeric characters, including the `%` and `_` wildcards, are
/// dropped. For example, `"a-b"` becomes `"%a%b%"` and `""` becomes `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|ch| ch.is_alphanumeric())
            .flat_map(|ch| ['%', ch]),
    );
    pattern.push('%');
    pattern
}
1230
1231#[cfg(test)]
1232pub mod tests {
1233 use super::*;
1234 use anyhow::anyhow;
1235 use collections::BTreeMap;
1236 use gpui::executor::Background;
1237 use lazy_static::lazy_static;
1238 use parking_lot::Mutex;
1239 use rand::prelude::*;
1240 use sqlx::{
1241 migrate::{MigrateDatabase, Migrator},
1242 Postgres,
1243 };
1244 use std::{path::Path, sync::Arc};
1245 use util::post_inc;
1246
1247 #[tokio::test(flavor = "multi_thread")]
1248 async fn test_get_users_by_ids() {
1249 for test_db in [
1250 TestDb::postgres().await,
1251 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1252 ] {
1253 let db = test_db.db();
1254
1255 let user = db.create_user("user", None, false).await.unwrap();
1256 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1257 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1258 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1259
1260 assert_eq!(
1261 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1262 .await
1263 .unwrap(),
1264 vec![
1265 User {
1266 id: user,
1267 github_login: "user".to_string(),
1268 admin: false,
1269 ..Default::default()
1270 },
1271 User {
1272 id: friend1,
1273 github_login: "friend-1".to_string(),
1274 admin: false,
1275 ..Default::default()
1276 },
1277 User {
1278 id: friend2,
1279 github_login: "friend-2".to_string(),
1280 admin: false,
1281 ..Default::default()
1282 },
1283 User {
1284 id: friend3,
1285 github_login: "friend-3".to_string(),
1286 admin: false,
1287 ..Default::default()
1288 }
1289 ]
1290 );
1291 }
1292 }
1293
    /// Batch user creation: new users get the requested invite counts and
    /// pairwise-distinct invite codes; re-creating an existing login returns
    /// the same id, updates its invite count, and keeps its invite code.
    ///
    /// Postgres only — the fake's `create_users` is unimplemented.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_users() {
        // The shadowed `TestDb` stays alive until the end of the test, so its
        // teardown still runs last.
        let db = TestDb::postgres().await;
        let db = db.db();

        // Create the first batch of users, ensuring invite counts are assigned
        // correctly and the respective invite codes are unique.
        let user_ids_batch_1 = db
            .create_users(vec![
                ("user1".to_string(), "hi@user1.com".to_string(), 5),
                ("user2".to_string(), "hi@user2.com".to_string(), 4),
                ("user3".to_string(), "hi@user3.com".to_string(), 3),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_1.len(), 3);

        // Users are expected back in the order their ids were requested.
        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
        assert_eq!(users.len(), 3);
        assert_eq!(users[0].github_login, "user1");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
        assert_eq!(users[0].invite_count, 5);
        assert_eq!(users[1].github_login, "user2");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[1].invite_count, 4);
        assert_eq!(users[2].github_login, "user3");
        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
        assert_eq!(users[2].invite_count, 3);

        // Every newly created user received an invite code, all distinct.
        let invite_code_1 = users[0].invite_code.clone().unwrap();
        let invite_code_2 = users[1].invite_code.clone().unwrap();
        let invite_code_3 = users[2].invite_code.clone().unwrap();
        assert_ne!(invite_code_1, invite_code_2);
        assert_ne!(invite_code_1, invite_code_3);
        assert_ne!(invite_code_2, invite_code_3);

        // Create the second batch of users and include a user that is already in the database, ensuring
        // the invite count for the existing user is updated without changing their invite code.
        let user_ids_batch_2 = db
            .create_users(vec![
                ("user2".to_string(), "hi@user2.com".to_string(), 10),
                ("user4".to_string(), "hi@user4.com".to_string(), 2),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_2.len(), 2);
        // The existing user keeps their original id.
        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);

        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0].github_login, "user2");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[0].invite_count, 10);
        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
        assert_eq!(users[1].github_login, "user4");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
        assert_eq!(users[1].invite_count, 2);

        // The genuinely new user gets a code distinct from all earlier ones.
        let invite_code_4 = users[1].invite_code.clone().unwrap();
        assert_ne!(invite_code_4, invite_code_1);
        assert_ne!(invite_code_4, invite_code_2);
        assert_ne!(invite_code_4, invite_code_3);
    }
1357
1358 #[tokio::test(flavor = "multi_thread")]
1359 async fn test_worktree_extensions() {
1360 let test_db = TestDb::postgres().await;
1361 let db = test_db.db();
1362
1363 let user = db.create_user("user_1", None, false).await.unwrap();
1364 let project = db.register_project(user).await.unwrap();
1365
1366 db.update_worktree_extensions(project, 100, Default::default()).await.unwrap();
1367 db.update_worktree_extensions(project, 100, [("rs".to_string(), 5), ("md".to_string(), 3)].into_iter().collect()).await.unwrap();
1368 db.update_worktree_extensions(project, 100, [("rs".to_string(), 6), ("md".to_string(), 5)].into_iter().collect()).await.unwrap();
1369 db.update_worktree_extensions(project, 101, [("ts".to_string(), 2), ("md".to_string(), 1)].into_iter().collect()).await.unwrap();
1370
1371 assert_eq!(db.get_project_extensions(project).await.unwrap(), [
1372 (100, [
1373 ("rs".into(), 6),
1374 ("md".into(), 5),
1375 ].into_iter().collect::<HashMap<_, _>>()),
1376 (101, [
1377 ("ts".into(), 2),
1378 ("md".into(), 1),
1379 ].into_iter().collect::<HashMap<_, _>>())
1380 ].into_iter().collect());
1381 }
1382
    /// Records overlapping activity periods for three users on two projects and
    /// checks the per-user, per-project totals from `summarize_project_activity`.
    ///
    /// Postgres only — the fake's activity methods are unimplemented.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_project_activity() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();

        let user_1 = db.create_user("user_1", None, false).await.unwrap();
        let user_2 = db.create_user("user_2", None, false).await.unwrap();
        let user_3 = db.create_user("user_3", None, false).await.unwrap();
        let project_1 = db.register_project(user_1).await.unwrap();
        let project_2 = db.register_project(user_2).await.unwrap();
        // Start one hour ago (presumably so every recorded period lies in the past).
        let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);

        // User 2 opens a project (t0..t1, 10s)
        let t1 = t0 + Duration::from_secs(10);
        db.record_project_activity(t0..t1, &[(user_2, project_2)])
            .await
            .unwrap();

        // t1..t2, 10s
        let t2 = t1 + Duration::from_secs(10);
        db.record_project_activity(t1..t2, &[(user_2, project_2)])
            .await
            .unwrap();

        // User 1 joins the project (t2..t3, 10s)
        let t3 = t2 + Duration::from_secs(10);
        db.record_project_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
            .await
            .unwrap();

        // User 1 opens another project (t3..t4, 10s)
        let t4 = t3 + Duration::from_secs(10);
        db.record_project_activity(
            t3..t4,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
            ],
        )
        .await
        .unwrap();

        // User 3 joins that project (t4..t5, 10s)
        let t5 = t4 + Duration::from_secs(10);
        db.record_project_activity(
            t4..t5,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
                (user_3, project_1),
            ],
        )
        .await
        .unwrap();

        // User 2 leaves (t5..t6, only 5s)
        let t6 = t5 + Duration::from_secs(5);
        db.record_project_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
            .await
            .unwrap();

        // Expected totals over t0..t6:
        // - user_1: project_2 for t2..t5 = 30s, project_1 for t3..t6 = 25s
        // - user_2: project_2 for t0..t5 = 50s
        // - user_3: project_1 for t4..t6 = 15s
        let summary = db.summarize_project_activity(t0..t6, 10).await.unwrap();
        assert_eq!(
            summary,
            &[
                UserActivitySummary {
                    id: user_1,
                    github_login: "user_1".to_string(),
                    project_activity: vec![
                        (project_2, Duration::from_secs(30)),
                        (project_1, Duration::from_secs(25))
                    ]
                },
                UserActivitySummary {
                    id: user_2,
                    github_login: "user_2".to_string(),
                    project_activity: vec![(project_2, Duration::from_secs(50))]
                },
                UserActivitySummary {
                    id: user_3,
                    github_login: "user_3".to_string(),
                    project_activity: vec![(project_1, Duration::from_secs(15))]
                },
            ]
        );
    }
1470
1471 #[tokio::test(flavor = "multi_thread")]
1472 async fn test_recent_channel_messages() {
1473 for test_db in [
1474 TestDb::postgres().await,
1475 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1476 ] {
1477 let db = test_db.db();
1478 let user = db.create_user("user", None, false).await.unwrap();
1479 let org = db.create_org("org", "org").await.unwrap();
1480 let channel = db.create_org_channel(org, "channel").await.unwrap();
1481 for i in 0..10 {
1482 db.create_channel_message(
1483 channel,
1484 user,
1485 &i.to_string(),
1486 OffsetDateTime::now_utc(),
1487 i,
1488 )
1489 .await
1490 .unwrap();
1491 }
1492
1493 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1494 assert_eq!(
1495 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1496 ["5", "6", "7", "8", "9"]
1497 );
1498
1499 let prev_messages = db
1500 .get_channel_messages(channel, 4, Some(messages[0].id))
1501 .await
1502 .unwrap();
1503 assert_eq!(
1504 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1505 ["1", "2", "3", "4"]
1506 );
1507 }
1508 }
1509
1510 #[tokio::test(flavor = "multi_thread")]
1511 async fn test_channel_message_nonces() {
1512 for test_db in [
1513 TestDb::postgres().await,
1514 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1515 ] {
1516 let db = test_db.db();
1517 let user = db.create_user("user", None, false).await.unwrap();
1518 let org = db.create_org("org", "org").await.unwrap();
1519 let channel = db.create_org_channel(org, "channel").await.unwrap();
1520
1521 let msg1_id = db
1522 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1523 .await
1524 .unwrap();
1525 let msg2_id = db
1526 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1527 .await
1528 .unwrap();
1529 let msg3_id = db
1530 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1531 .await
1532 .unwrap();
1533 let msg4_id = db
1534 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1535 .await
1536 .unwrap();
1537
1538 assert_ne!(msg1_id, msg2_id);
1539 assert_eq!(msg1_id, msg3_id);
1540 assert_eq!(msg2_id, msg4_id);
1541 }
1542 }
1543
1544 #[tokio::test(flavor = "multi_thread")]
1545 async fn test_create_access_tokens() {
1546 let test_db = TestDb::postgres().await;
1547 let db = test_db.db();
1548 let user = db.create_user("the-user", None, false).await.unwrap();
1549
1550 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1551 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1552 assert_eq!(
1553 db.get_access_token_hashes(user).await.unwrap(),
1554 &["h2".to_string(), "h1".to_string()]
1555 );
1556
1557 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1558 assert_eq!(
1559 db.get_access_token_hashes(user).await.unwrap(),
1560 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1561 );
1562
1563 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1564 assert_eq!(
1565 db.get_access_token_hashes(user).await.unwrap(),
1566 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1567 );
1568
1569 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1570 assert_eq!(
1571 db.get_access_token_hashes(user).await.unwrap(),
1572 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1573 );
1574 }
1575
1576 #[test]
1577 fn test_fuzzy_like_string() {
1578 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1579 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1580 assert_eq!(fuzzy_like_string(" z "), "%z%");
1581 }
1582
1583 #[tokio::test(flavor = "multi_thread")]
1584 async fn test_fuzzy_search_users() {
1585 let test_db = TestDb::postgres().await;
1586 let db = test_db.db();
1587 for github_login in [
1588 "California",
1589 "colorado",
1590 "oregon",
1591 "washington",
1592 "florida",
1593 "delaware",
1594 "rhode-island",
1595 ] {
1596 db.create_user(github_login, None, false).await.unwrap();
1597 }
1598
1599 assert_eq!(
1600 fuzzy_search_user_names(db, "clr").await,
1601 &["colorado", "California"]
1602 );
1603 assert_eq!(
1604 fuzzy_search_user_names(db, "ro").await,
1605 &["rhode-island", "colorado", "oregon"],
1606 );
1607
1608 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1609 db.fuzzy_search_users(query, 10)
1610 .await
1611 .unwrap()
1612 .into_iter()
1613 .map(|user| user.github_login)
1614 .collect::<Vec<_>>()
1615 }
1616 }
1617
    /// Walks the contact-request lifecycle on both backends:
    /// - request;
    /// - notification dismissal;
    /// - accept;
    /// - duplicate-request rejection;
    /// - mutual requests auto-accepting;
    /// - decline.
    ///
    /// Each step depends on the state left by the previous ones.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // A new user has no contacts other than themselves.
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again. (The requester, user 1, has no
            // notification to dismiss while the request is pending.)
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact; only
            // the original requester (user 1) is notified of the acceptance.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted, with no pending notification.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
1837
    /// Invite codes: each redemption decrements the inviter's remaining count,
    /// creates a new user, and makes that user an accepted contact of the
    /// inviter. Codes stop working at zero, can be topped up without changing
    /// the code, and cannot be redeemed by existing users.
    ///
    /// Postgres only — the fake does not implement invite codes.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        // The inviter is notified; the new user is not.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes, and a failed redemption
        // does not consume a use.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
1990
1991 pub struct TestDb {
1992 pub db: Option<Arc<dyn Db>>,
1993 pub url: String,
1994 }
1995
1996 impl TestDb {
1997 pub async fn postgres() -> Self {
1998 lazy_static! {
1999 static ref LOCK: Mutex<()> = Mutex::new(());
2000 }
2001
2002 let _guard = LOCK.lock();
2003 let mut rng = StdRng::from_entropy();
2004 let name = format!("zed-test-{}", rng.gen::<u128>());
2005 let url = format!("postgres://postgres@localhost/{}", name);
2006 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2007 Postgres::create_database(&url)
2008 .await
2009 .expect("failed to create test db");
2010 let db = PostgresDb::new(&url, 5).await.unwrap();
2011 let migrator = Migrator::new(migrations_path).await.unwrap();
2012 migrator.run(&db.pool).await.unwrap();
2013 Self {
2014 db: Some(Arc::new(db)),
2015 url,
2016 }
2017 }
2018
2019 pub fn fake(background: Arc<Background>) -> Self {
2020 Self {
2021 db: Some(Arc::new(FakeDb::new(background))),
2022 url: Default::default(),
2023 }
2024 }
2025
2026 pub fn db(&self) -> &Arc<dyn Db> {
2027 self.db.as_ref().unwrap()
2028 }
2029 }
2030
2031 impl Drop for TestDb {
2032 fn drop(&mut self) {
2033 if let Some(db) = self.db.take() {
2034 futures::executor::block_on(db.teardown(&self.url));
2035 }
2036 }
2037 }
2038
2039 pub struct FakeDb {
2040 background: Arc<Background>,
2041 pub users: Mutex<BTreeMap<UserId, User>>,
2042 pub projects: Mutex<BTreeMap<ProjectId, Project>>,
2043 pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), usize>>,
2044 pub orgs: Mutex<BTreeMap<OrgId, Org>>,
2045 pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
2046 pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
2047 pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
2048 pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
2049 pub contacts: Mutex<Vec<FakeContact>>,
2050 next_channel_message_id: Mutex<i32>,
2051 next_user_id: Mutex<i32>,
2052 next_org_id: Mutex<i32>,
2053 next_channel_id: Mutex<i32>,
2054 next_project_id: Mutex<i32>,
2055 }
2056
2057 #[derive(Debug)]
2058 pub struct FakeContact {
2059 pub requester_id: UserId,
2060 pub responder_id: UserId,
2061 pub accepted: bool,
2062 pub should_notify: bool,
2063 }
2064
2065 impl FakeDb {
2066 pub fn new(background: Arc<Background>) -> Self {
2067 Self {
2068 background,
2069 users: Default::default(),
2070 next_user_id: Mutex::new(1),
2071 projects: Default::default(),
2072 worktree_extensions: Default::default(),
2073 next_project_id: Mutex::new(1),
2074 orgs: Default::default(),
2075 next_org_id: Mutex::new(1),
2076 org_memberships: Default::default(),
2077 channels: Default::default(),
2078 next_channel_id: Mutex::new(1),
2079 channel_memberships: Default::default(),
2080 channel_messages: Default::default(),
2081 next_channel_message_id: Mutex::new(1),
2082 contacts: Default::default(),
2083 }
2084 }
2085 }
2086
2087 #[async_trait]
2088 impl Db for FakeDb {
2089 async fn create_user(
2090 &self,
2091 github_login: &str,
2092 email_address: Option<&str>,
2093 admin: bool,
2094 ) -> Result<UserId> {
2095 self.background.simulate_random_delay().await;
2096
2097 let mut users = self.users.lock();
2098 if let Some(user) = users
2099 .values()
2100 .find(|user| user.github_login == github_login)
2101 {
2102 Ok(user.id)
2103 } else {
2104 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2105 users.insert(
2106 user_id,
2107 User {
2108 id: user_id,
2109 github_login: github_login.to_string(),
2110 email_address: email_address.map(str::to_string),
2111 admin,
2112 invite_code: None,
2113 invite_count: 0,
2114 connected_once: false,
2115 },
2116 );
2117 Ok(user_id)
2118 }
2119 }
2120
        // The following user queries are not supported by the fake and panic
        // if called; tests that need them (`test_create_users`,
        // `test_fuzzy_search_users`) run against `TestDb::postgres()` only.

        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }

        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
            unimplemented!()
        }

        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2132
2133 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2134 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2135 }
2136
2137 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2138 self.background.simulate_random_delay().await;
2139 let users = self.users.lock();
2140 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2141 }
2142
2143 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2144 Ok(self
2145 .users
2146 .lock()
2147 .values()
2148 .find(|user| user.github_login == github_login)
2149 .cloned())
2150 }
2151
2152 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
2153 unimplemented!()
2154 }
2155
2156 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2157 self.background.simulate_random_delay().await;
2158 let mut users = self.users.lock();
2159 let mut user = users
2160 .get_mut(&id)
2161 .ok_or_else(|| anyhow!("user not found"))?;
2162 user.connected_once = connected_once;
2163 Ok(())
2164 }
2165
        /// Not supported by the fake; panics if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        // invite codes
        //
        // Invite codes are not modeled by the fake (`test_invite_codes` runs
        // against Postgres only). The read path reports "no code"; the other
        // operations panic if called.

        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }

        /// Always reports that the user has no invite code.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            Ok(None)
        }

        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }

        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
2192
2193 // projects
2194
2195 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2196 self.background.simulate_random_delay().await;
2197 if !self.users.lock().contains_key(&host_user_id) {
2198 Err(anyhow!("no such user"))?;
2199 }
2200
2201 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2202 self.projects.lock().insert(
2203 project_id,
2204 Project {
2205 id: project_id,
2206 host_user_id,
2207 unregistered: false,
2208 },
2209 );
2210 Ok(project_id)
2211 }
2212
2213 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2214 self.projects
2215 .lock()
2216 .get_mut(&project_id)
2217 .ok_or_else(|| anyhow!("no such project"))?
2218 .unregistered = true;
2219 Ok(())
2220 }
2221
2222 async fn update_worktree_extensions(
2223 &self,
2224 project_id: ProjectId,
2225 worktree_id: u64,
2226 extensions: HashMap<String, usize>,
2227 ) -> Result<()> {
2228 self.background.simulate_random_delay().await;
2229 if !self.projects.lock().contains_key(&project_id) {
2230 Err(anyhow!("no such project"))?;
2231 }
2232
2233 for (extension, count) in extensions {
2234 self.worktree_extensions
2235 .lock()
2236 .insert((project_id, worktree_id, extension), count);
2237 }
2238
2239 Ok(())
2240 }
2241
2242 async fn get_project_extensions(
2243 &self,
2244 _project_id: ProjectId,
2245 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
2246 unimplemented!()
2247 }
2248
        // Project activity is not modeled by the fake: both methods panic if
        // called (`test_project_activity` runs against Postgres only).

        async fn record_project_activity(
            &self,
            _period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }

        async fn summarize_project_activity(
            &self,
            _period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }
2264
2265 // contacts
2266
2267 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2268 self.background.simulate_random_delay().await;
2269 let mut contacts = vec![Contact::Accepted {
2270 user_id: id,
2271 should_notify: false,
2272 }];
2273
2274 for contact in self.contacts.lock().iter() {
2275 if contact.requester_id == id {
2276 if contact.accepted {
2277 contacts.push(Contact::Accepted {
2278 user_id: contact.responder_id,
2279 should_notify: contact.should_notify,
2280 });
2281 } else {
2282 contacts.push(Contact::Outgoing {
2283 user_id: contact.responder_id,
2284 });
2285 }
2286 } else if contact.responder_id == id {
2287 if contact.accepted {
2288 contacts.push(Contact::Accepted {
2289 user_id: contact.requester_id,
2290 should_notify: false,
2291 });
2292 } else {
2293 contacts.push(Contact::Incoming {
2294 user_id: contact.requester_id,
2295 should_notify: contact.should_notify,
2296 });
2297 }
2298 }
2299 }
2300
2301 contacts.sort_unstable_by_key(|contact| contact.user_id());
2302 Ok(contacts)
2303 }
2304
2305 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2306 self.background.simulate_random_delay().await;
2307 Ok(self.contacts.lock().iter().any(|contact| {
2308 contact.accepted
2309 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2310 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2311 }))
2312 }
2313
2314 async fn send_contact_request(
2315 &self,
2316 requester_id: UserId,
2317 responder_id: UserId,
2318 ) -> Result<()> {
2319 let mut contacts = self.contacts.lock();
2320 for contact in contacts.iter_mut() {
2321 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2322 if contact.accepted {
2323 Err(anyhow!("contact already exists"))?;
2324 } else {
2325 Err(anyhow!("contact already requested"))?;
2326 }
2327 }
2328 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2329 if contact.accepted {
2330 Err(anyhow!("contact already exists"))?;
2331 } else {
2332 contact.accepted = true;
2333 contact.should_notify = false;
2334 return Ok(());
2335 }
2336 }
2337 }
2338 contacts.push(FakeContact {
2339 requester_id,
2340 responder_id,
2341 accepted: false,
2342 should_notify: true,
2343 });
2344 Ok(())
2345 }
2346
2347 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2348 self.contacts.lock().retain(|contact| {
2349 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2350 });
2351 Ok(())
2352 }
2353
2354 async fn dismiss_contact_notification(
2355 &self,
2356 user_id: UserId,
2357 contact_user_id: UserId,
2358 ) -> Result<()> {
2359 let mut contacts = self.contacts.lock();
2360 for contact in contacts.iter_mut() {
2361 if contact.requester_id == contact_user_id
2362 && contact.responder_id == user_id
2363 && !contact.accepted
2364 {
2365 contact.should_notify = false;
2366 return Ok(());
2367 }
2368 if contact.requester_id == user_id
2369 && contact.responder_id == contact_user_id
2370 && contact.accepted
2371 {
2372 contact.should_notify = false;
2373 return Ok(());
2374 }
2375 }
2376 Err(anyhow!("no such notification"))?
2377 }
2378
2379 async fn respond_to_contact_request(
2380 &self,
2381 responder_id: UserId,
2382 requester_id: UserId,
2383 accept: bool,
2384 ) -> Result<()> {
2385 let mut contacts = self.contacts.lock();
2386 for (ix, contact) in contacts.iter_mut().enumerate() {
2387 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2388 if contact.accepted {
2389 Err(anyhow!("contact already confirmed"))?;
2390 }
2391 if accept {
2392 contact.accepted = true;
2393 contact.should_notify = true;
2394 } else {
2395 contacts.remove(ix);
2396 }
2397 return Ok(());
2398 }
2399 }
2400 Err(anyhow!("no such contact request"))?
2401 }
2402
        /// Not supported by the fake database: calling this panics via `unimplemented!()`.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
2411
        /// Not supported by the fake database: calling this panics via `unimplemented!()`.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
2415
        /// Not supported by the fake database: calling this panics via `unimplemented!()`.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
2419
2420 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2421 self.background.simulate_random_delay().await;
2422 let mut orgs = self.orgs.lock();
2423 if orgs.values().any(|org| org.slug == slug) {
2424 Err(anyhow!("org already exists"))?
2425 } else {
2426 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2427 orgs.insert(
2428 org_id,
2429 Org {
2430 id: org_id,
2431 name: name.to_string(),
2432 slug: slug.to_string(),
2433 },
2434 );
2435 Ok(org_id)
2436 }
2437 }
2438
2439 async fn add_org_member(
2440 &self,
2441 org_id: OrgId,
2442 user_id: UserId,
2443 is_admin: bool,
2444 ) -> Result<()> {
2445 self.background.simulate_random_delay().await;
2446 if !self.orgs.lock().contains_key(&org_id) {
2447 Err(anyhow!("org does not exist"))?;
2448 }
2449 if !self.users.lock().contains_key(&user_id) {
2450 Err(anyhow!("user does not exist"))?;
2451 }
2452
2453 self.org_memberships
2454 .lock()
2455 .entry((org_id, user_id))
2456 .or_insert(is_admin);
2457 Ok(())
2458 }
2459
2460 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2461 self.background.simulate_random_delay().await;
2462 if !self.orgs.lock().contains_key(&org_id) {
2463 Err(anyhow!("org does not exist"))?;
2464 }
2465
2466 let mut channels = self.channels.lock();
2467 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2468 channels.insert(
2469 channel_id,
2470 Channel {
2471 id: channel_id,
2472 name: name.to_string(),
2473 owner_id: org_id.0,
2474 owner_is_user: false,
2475 },
2476 );
2477 Ok(channel_id)
2478 }
2479
2480 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2481 self.background.simulate_random_delay().await;
2482 Ok(self
2483 .channels
2484 .lock()
2485 .values()
2486 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2487 .cloned()
2488 .collect())
2489 }
2490
2491 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2492 self.background.simulate_random_delay().await;
2493 let channels = self.channels.lock();
2494 let memberships = self.channel_memberships.lock();
2495 Ok(channels
2496 .values()
2497 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2498 .cloned()
2499 .collect())
2500 }
2501
2502 async fn can_user_access_channel(
2503 &self,
2504 user_id: UserId,
2505 channel_id: ChannelId,
2506 ) -> Result<bool> {
2507 self.background.simulate_random_delay().await;
2508 Ok(self
2509 .channel_memberships
2510 .lock()
2511 .contains_key(&(channel_id, user_id)))
2512 }
2513
2514 async fn add_channel_member(
2515 &self,
2516 channel_id: ChannelId,
2517 user_id: UserId,
2518 is_admin: bool,
2519 ) -> Result<()> {
2520 self.background.simulate_random_delay().await;
2521 if !self.channels.lock().contains_key(&channel_id) {
2522 Err(anyhow!("channel does not exist"))?;
2523 }
2524 if !self.users.lock().contains_key(&user_id) {
2525 Err(anyhow!("user does not exist"))?;
2526 }
2527
2528 self.channel_memberships
2529 .lock()
2530 .entry((channel_id, user_id))
2531 .or_insert(is_admin);
2532 Ok(())
2533 }
2534
2535 async fn create_channel_message(
2536 &self,
2537 channel_id: ChannelId,
2538 sender_id: UserId,
2539 body: &str,
2540 timestamp: OffsetDateTime,
2541 nonce: u128,
2542 ) -> Result<MessageId> {
2543 self.background.simulate_random_delay().await;
2544 if !self.channels.lock().contains_key(&channel_id) {
2545 Err(anyhow!("channel does not exist"))?;
2546 }
2547 if !self.users.lock().contains_key(&sender_id) {
2548 Err(anyhow!("user does not exist"))?;
2549 }
2550
2551 let mut messages = self.channel_messages.lock();
2552 if let Some(message) = messages
2553 .values()
2554 .find(|message| message.nonce.as_u128() == nonce)
2555 {
2556 Ok(message.id)
2557 } else {
2558 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2559 messages.insert(
2560 message_id,
2561 ChannelMessage {
2562 id: message_id,
2563 channel_id,
2564 sender_id,
2565 body: body.to_string(),
2566 sent_at: timestamp,
2567 nonce: Uuid::from_u128(nonce),
2568 },
2569 );
2570 Ok(message_id)
2571 }
2572 }
2573
2574 async fn get_channel_messages(
2575 &self,
2576 channel_id: ChannelId,
2577 count: usize,
2578 before_id: Option<MessageId>,
2579 ) -> Result<Vec<ChannelMessage>> {
2580 let mut messages = self
2581 .channel_messages
2582 .lock()
2583 .values()
2584 .rev()
2585 .filter(|message| {
2586 message.channel_id == channel_id
2587 && message.id < before_id.unwrap_or(MessageId::MAX)
2588 })
2589 .take(count)
2590 .cloned()
2591 .collect::<Vec<_>>();
2592 messages.sort_unstable_by_key(|message| message.id);
2593 Ok(messages)
2594 }
2595
        /// No-op for the fake database; the string argument is ignored.
        async fn teardown(&self, _: &str) {}
2597
        /// Test-only accessor. The fake database always returns itself, so this never
        /// yields `None`.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
2602 }
2603}