1use std::{ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use nanoid::nanoid;
10use serde::Serialize;
11pub use sqlx::postgres::PgPoolOptions as DbOptions;
12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
13use time::{OffsetDateTime, PrimitiveDateTime};
14
15#[async_trait]
16pub trait Db: Send + Sync {
17 async fn create_user(
18 &self,
19 github_login: &str,
20 email_address: Option<&str>,
21 admin: bool,
22 ) -> Result<UserId>;
23 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
24 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
25 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
26 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
27 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
28 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
29 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
30 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
31 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
32 async fn destroy_user(&self, id: UserId) -> Result<()>;
33
34 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
35 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
36 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
37 async fn redeem_invite_code(
38 &self,
39 code: &str,
40 login: &str,
41 email_address: Option<&str>,
42 ) -> Result<UserId>;
43
44 /// Registers a new project for the given user.
45 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
46
47 /// Unregisters a project for the given project id.
48 async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
49
50 /// Update file counts by extension for the given project and worktree.
51 async fn update_worktree_extensions(
52 &self,
53 project_id: ProjectId,
54 worktree_id: u64,
55 extensions: HashMap<String, usize>,
56 ) -> Result<()>;
57
58 /// Get the file counts on the given project keyed by their worktree and extension.
59 async fn get_project_extensions(
60 &self,
61 project_id: ProjectId,
62 ) -> Result<HashMap<u64, HashMap<String, usize>>>;
63
64 /// Record which users have been active in which projects during
65 /// a given period of time.
66 async fn record_project_activity(
67 &self,
68 time_period: Range<OffsetDateTime>,
69 active_projects: &[(UserId, ProjectId)],
70 ) -> Result<()>;
71
72 /// Get the users that have been most active during the given time period,
73 /// along with the amount of time they have been active in each project.
74 async fn summarize_project_activity(
75 &self,
76 time_period: Range<OffsetDateTime>,
77 max_user_count: usize,
78 ) -> Result<Vec<UserActivitySummary>>;
79
80 /// Get the project activity for the given user and time period.
81 async fn summarize_user_activity(
82 &self,
83 user_id: UserId,
84 time_period: Range<OffsetDateTime>,
85 ) -> Result<Vec<UserActivityDuration>>;
86
87 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
88 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
89 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
90 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
91 async fn dismiss_contact_notification(
92 &self,
93 responder_id: UserId,
94 requester_id: UserId,
95 ) -> Result<()>;
96 async fn respond_to_contact_request(
97 &self,
98 responder_id: UserId,
99 requester_id: UserId,
100 accept: bool,
101 ) -> Result<()>;
102
103 async fn create_access_token_hash(
104 &self,
105 user_id: UserId,
106 access_token_hash: &str,
107 max_access_token_count: usize,
108 ) -> Result<()>;
109 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
110 #[cfg(any(test, feature = "seed-support"))]
111
112 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
113 #[cfg(any(test, feature = "seed-support"))]
114 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
115 #[cfg(any(test, feature = "seed-support"))]
116 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
117 #[cfg(any(test, feature = "seed-support"))]
118 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
119 #[cfg(any(test, feature = "seed-support"))]
120
121 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
122 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
123 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
124 -> Result<bool>;
125 #[cfg(any(test, feature = "seed-support"))]
126 async fn add_channel_member(
127 &self,
128 channel_id: ChannelId,
129 user_id: UserId,
130 is_admin: bool,
131 ) -> Result<()>;
132 async fn create_channel_message(
133 &self,
134 channel_id: ChannelId,
135 sender_id: UserId,
136 body: &str,
137 timestamp: OffsetDateTime,
138 nonce: u128,
139 ) -> Result<MessageId>;
140 async fn get_channel_messages(
141 &self,
142 channel_id: ChannelId,
143 count: usize,
144 before_id: Option<MessageId>,
145 ) -> Result<Vec<ChannelMessage>>;
146 #[cfg(test)]
147 async fn teardown(&self, url: &str);
148 #[cfg(test)]
149 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
150}
151
152pub struct PostgresDb {
153 pool: sqlx::PgPool,
154}
155
156impl PostgresDb {
157 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
158 let pool = DbOptions::new()
159 .max_connections(max_connections)
160 .connect(&url)
161 .await
162 .context("failed to connect to postgres database")?;
163 Ok(Self { pool })
164 }
165}
166
167#[async_trait]
168impl Db for PostgresDb {
169 // users
170
171 async fn create_user(
172 &self,
173 github_login: &str,
174 email_address: Option<&str>,
175 admin: bool,
176 ) -> Result<UserId> {
177 let query = "
178 INSERT INTO users (github_login, email_address, admin)
179 VALUES ($1, $2, $3)
180 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
181 RETURNING id
182 ";
183 Ok(sqlx::query_scalar(query)
184 .bind(github_login)
185 .bind(email_address)
186 .bind(admin)
187 .fetch_one(&self.pool)
188 .await
189 .map(UserId)?)
190 }
191
192 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
193 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
194 Ok(sqlx::query_as(query)
195 .bind(limit as i32)
196 .bind((page * limit) as i32)
197 .fetch_all(&self.pool)
198 .await?)
199 }
200
201 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
202 let mut query = QueryBuilder::new(
203 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
204 );
205 query.push_values(
206 users,
207 |mut query, (github_login, email_address, invite_count)| {
208 query
209 .push_bind(github_login)
210 .push_bind(email_address)
211 .push_bind(false)
212 .push_bind(nanoid!(16))
213 .push_bind(invite_count as i32);
214 },
215 );
216 query.push(
217 "
218 ON CONFLICT (github_login) DO UPDATE SET
219 github_login = excluded.github_login,
220 invite_count = excluded.invite_count,
221 invite_code = CASE WHEN users.invite_code IS NULL
222 THEN excluded.invite_code
223 ELSE users.invite_code
224 END
225 RETURNING id
226 ",
227 );
228
229 let rows = query.build().fetch_all(&self.pool).await?;
230 Ok(rows
231 .into_iter()
232 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
233 .collect())
234 }
235
236 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
237 let like_string = fuzzy_like_string(name_query);
238 let query = "
239 SELECT users.*
240 FROM users
241 WHERE github_login ILIKE $1
242 ORDER BY github_login <-> $2
243 LIMIT $3
244 ";
245 Ok(sqlx::query_as(query)
246 .bind(like_string)
247 .bind(name_query)
248 .bind(limit as i32)
249 .fetch_all(&self.pool)
250 .await?)
251 }
252
253 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
254 let users = self.get_users_by_ids(vec![id]).await?;
255 Ok(users.into_iter().next())
256 }
257
258 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
259 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
260 let query = "
261 SELECT users.*
262 FROM users
263 WHERE users.id = ANY ($1)
264 ";
265 Ok(sqlx::query_as(query)
266 .bind(&ids)
267 .fetch_all(&self.pool)
268 .await?)
269 }
270
271 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
272 let query = format!(
273 "
274 SELECT users.*
275 FROM users
276 WHERE invite_count = 0
277 AND inviter_id IS{} NULL
278 ",
279 if invited_by_another_user { " NOT" } else { "" }
280 );
281
282 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
283 }
284
285 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
286 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
287 Ok(sqlx::query_as(query)
288 .bind(github_login)
289 .fetch_optional(&self.pool)
290 .await?)
291 }
292
293 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
294 let query = "UPDATE users SET admin = $1 WHERE id = $2";
295 Ok(sqlx::query(query)
296 .bind(is_admin)
297 .bind(id.0)
298 .execute(&self.pool)
299 .await
300 .map(drop)?)
301 }
302
303 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
304 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
305 Ok(sqlx::query(query)
306 .bind(connected_once)
307 .bind(id.0)
308 .execute(&self.pool)
309 .await
310 .map(drop)?)
311 }
312
313 async fn destroy_user(&self, id: UserId) -> Result<()> {
314 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
315 sqlx::query(query)
316 .bind(id.0)
317 .execute(&self.pool)
318 .await
319 .map(drop)?;
320 let query = "DELETE FROM users WHERE id = $1;";
321 Ok(sqlx::query(query)
322 .bind(id.0)
323 .execute(&self.pool)
324 .await
325 .map(drop)?)
326 }
327
328 // invite codes
329
330 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
331 let mut tx = self.pool.begin().await?;
332 if count > 0 {
333 sqlx::query(
334 "
335 UPDATE users
336 SET invite_code = $1
337 WHERE id = $2 AND invite_code IS NULL
338 ",
339 )
340 .bind(nanoid!(16))
341 .bind(id)
342 .execute(&mut tx)
343 .await?;
344 }
345
346 sqlx::query(
347 "
348 UPDATE users
349 SET invite_count = $1
350 WHERE id = $2
351 ",
352 )
353 .bind(count as i32)
354 .bind(id)
355 .execute(&mut tx)
356 .await?;
357 tx.commit().await?;
358 Ok(())
359 }
360
361 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
362 let result: Option<(String, i32)> = sqlx::query_as(
363 "
364 SELECT invite_code, invite_count
365 FROM users
366 WHERE id = $1 AND invite_code IS NOT NULL
367 ",
368 )
369 .bind(id)
370 .fetch_optional(&self.pool)
371 .await?;
372 if let Some((code, count)) = result {
373 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
374 } else {
375 Ok(None)
376 }
377 }
378
379 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
380 sqlx::query_as(
381 "
382 SELECT *
383 FROM users
384 WHERE invite_code = $1
385 ",
386 )
387 .bind(code)
388 .fetch_optional(&self.pool)
389 .await?
390 .ok_or_else(|| {
391 Error::Http(
392 StatusCode::NOT_FOUND,
393 "that invite code does not exist".to_string(),
394 )
395 })
396 }
397
398 async fn redeem_invite_code(
399 &self,
400 code: &str,
401 login: &str,
402 email_address: Option<&str>,
403 ) -> Result<UserId> {
404 let mut tx = self.pool.begin().await?;
405
406 let inviter_id: Option<UserId> = sqlx::query_scalar(
407 "
408 UPDATE users
409 SET invite_count = invite_count - 1
410 WHERE
411 invite_code = $1 AND
412 invite_count > 0
413 RETURNING id
414 ",
415 )
416 .bind(code)
417 .fetch_optional(&mut tx)
418 .await?;
419
420 let inviter_id = match inviter_id {
421 Some(inviter_id) => inviter_id,
422 None => {
423 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
424 .bind(code)
425 .fetch_optional(&mut tx)
426 .await?
427 .is_some()
428 {
429 Err(Error::Http(
430 StatusCode::UNAUTHORIZED,
431 "no invites remaining".to_string(),
432 ))?
433 } else {
434 Err(Error::Http(
435 StatusCode::NOT_FOUND,
436 "invite code not found".to_string(),
437 ))?
438 }
439 }
440 };
441
442 let invitee_id = sqlx::query_scalar(
443 "
444 INSERT INTO users
445 (github_login, email_address, admin, inviter_id)
446 VALUES
447 ($1, $2, 'f', $3)
448 RETURNING id
449 ",
450 )
451 .bind(login)
452 .bind(email_address)
453 .bind(inviter_id)
454 .fetch_one(&mut tx)
455 .await
456 .map(UserId)?;
457
458 sqlx::query(
459 "
460 INSERT INTO contacts
461 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
462 VALUES
463 ($1, $2, 't', 't', 't')
464 ",
465 )
466 .bind(inviter_id)
467 .bind(invitee_id)
468 .execute(&mut tx)
469 .await?;
470
471 tx.commit().await?;
472 Ok(invitee_id)
473 }
474
475 // projects
476
477 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
478 Ok(sqlx::query_scalar(
479 "
480 INSERT INTO projects(host_user_id)
481 VALUES ($1)
482 RETURNING id
483 ",
484 )
485 .bind(host_user_id)
486 .fetch_one(&self.pool)
487 .await
488 .map(ProjectId)?)
489 }
490
491 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
492 sqlx::query(
493 "
494 UPDATE projects
495 SET unregistered = 't'
496 WHERE id = $1
497 ",
498 )
499 .bind(project_id)
500 .execute(&self.pool)
501 .await?;
502 Ok(())
503 }
504
505 async fn update_worktree_extensions(
506 &self,
507 project_id: ProjectId,
508 worktree_id: u64,
509 extensions: HashMap<String, usize>,
510 ) -> Result<()> {
511 if extensions.is_empty() {
512 return Ok(());
513 }
514
515 let mut query = QueryBuilder::new(
516 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
517 );
518 query.push_values(extensions, |mut query, (extension, count)| {
519 query
520 .push_bind(project_id)
521 .push_bind(worktree_id as i32)
522 .push_bind(extension)
523 .push_bind(count as i32);
524 });
525 query.push(
526 "
527 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
528 count = excluded.count
529 ",
530 );
531 query.build().execute(&self.pool).await?;
532
533 Ok(())
534 }
535
536 async fn get_project_extensions(
537 &self,
538 project_id: ProjectId,
539 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
540 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
541 struct WorktreeExtension {
542 worktree_id: i32,
543 extension: String,
544 count: i32,
545 }
546
547 let query = "
548 SELECT worktree_id, extension, count
549 FROM worktree_extensions
550 WHERE project_id = $1
551 ";
552 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
553 .bind(&project_id)
554 .fetch_all(&self.pool)
555 .await?;
556
557 let mut extension_counts = HashMap::default();
558 for count in counts {
559 extension_counts
560 .entry(count.worktree_id as u64)
561 .or_insert(HashMap::default())
562 .insert(count.extension, count.count as usize);
563 }
564 Ok(extension_counts)
565 }
566
567 async fn record_project_activity(
568 &self,
569 time_period: Range<OffsetDateTime>,
570 projects: &[(UserId, ProjectId)],
571 ) -> Result<()> {
572 let query = "
573 INSERT INTO project_activity_periods
574 (ended_at, duration_millis, user_id, project_id)
575 VALUES
576 ($1, $2, $3, $4);
577 ";
578
579 let mut tx = self.pool.begin().await?;
580 let duration_millis =
581 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
582 for (user_id, project_id) in projects {
583 sqlx::query(query)
584 .bind(time_period.end)
585 .bind(duration_millis)
586 .bind(user_id)
587 .bind(project_id)
588 .execute(&mut tx)
589 .await?;
590 }
591 tx.commit().await?;
592
593 Ok(())
594 }
595
596 async fn summarize_project_activity(
597 &self,
598 time_period: Range<OffsetDateTime>,
599 max_user_count: usize,
600 ) -> Result<Vec<UserActivitySummary>> {
601 let query = "
602 WITH
603 project_durations AS (
604 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
605 FROM project_activity_periods
606 WHERE $1 < ended_at AND ended_at <= $2
607 GROUP BY user_id, project_id
608 ),
609 user_durations AS (
610 SELECT user_id, SUM(project_duration) as total_duration
611 FROM project_durations
612 GROUP BY user_id
613 ORDER BY total_duration DESC
614 LIMIT $3
615 )
616 SELECT user_durations.user_id, users.github_login, project_id, project_duration
617 FROM user_durations, project_durations, users
618 WHERE
619 user_durations.user_id = project_durations.user_id AND
620 user_durations.user_id = users.id
621 ORDER BY user_id ASC, project_duration DESC
622 ";
623
624 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64)>(query)
625 .bind(time_period.start)
626 .bind(time_period.end)
627 .bind(max_user_count as i32)
628 .fetch(&self.pool);
629
630 let mut result = Vec::<UserActivitySummary>::new();
631 while let Some(row) = rows.next().await {
632 let (user_id, github_login, project_id, duration_millis) = row?;
633 let project_id = project_id;
634 let duration = Duration::from_millis(duration_millis as u64);
635 if let Some(last_summary) = result.last_mut() {
636 if last_summary.id == user_id {
637 last_summary.project_activity.push((project_id, duration));
638 continue;
639 }
640 }
641 result.push(UserActivitySummary {
642 id: user_id,
643 project_activity: vec![(project_id, duration)],
644 github_login,
645 });
646 }
647
648 Ok(result)
649 }
650
651 async fn summarize_user_activity(
652 &self,
653 user_id: UserId,
654 time_period: Range<OffsetDateTime>,
655 ) -> Result<Vec<UserActivityDuration>> {
656 const COALESCE_THRESHOLD: Duration = Duration::from_secs(5);
657
658 let query = "
659 SELECT
660 project_activity_periods.ended_at,
661 project_activity_periods.duration_millis,
662 project_activity_periods.project_id,
663 worktree_extensions.extension,
664 worktree_extensions.count
665 FROM project_activity_periods
666 LEFT OUTER JOIN
667 worktree_extensions
668 ON
669 project_activity_periods.project_id = worktree_extensions.project_id
670 WHERE
671 project_activity_periods.user_id = $1 AND
672 $2 < project_activity_periods.ended_at AND
673 project_activity_periods.ended_at <= $3
674 ORDER BY project_activity_periods.id ASC
675 ";
676
677 let mut rows = sqlx::query_as::<
678 _,
679 (
680 PrimitiveDateTime,
681 i32,
682 ProjectId,
683 Option<String>,
684 Option<i32>,
685 ),
686 >(query)
687 .bind(user_id)
688 .bind(time_period.start)
689 .bind(time_period.end)
690 .fetch(&self.pool);
691
692 let mut durations: HashMap<ProjectId, Vec<UserActivityDuration>> = Default::default();
693 while let Some(row) = rows.next().await {
694 let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
695 let ended_at = ended_at.assume_utc();
696 let duration = Duration::from_millis(duration_millis as u64);
697 let started_at = ended_at - duration;
698 let project_durations = durations.entry(project_id).or_default();
699
700 if let Some(prev_duration) = project_durations.last_mut() {
701 if started_at - prev_duration.end <= COALESCE_THRESHOLD {
702 prev_duration.end = ended_at;
703 } else {
704 project_durations.push(UserActivityDuration {
705 project_id,
706 start: started_at,
707 end: ended_at,
708 extensions: Default::default(),
709 });
710 }
711 } else {
712 project_durations.push(UserActivityDuration {
713 project_id,
714 start: started_at,
715 end: ended_at,
716 extensions: Default::default(),
717 });
718 }
719
720 if let Some((extension, extension_count)) = extension.zip(extension_count) {
721 project_durations
722 .last_mut()
723 .unwrap()
724 .extensions
725 .insert(extension, extension_count as usize);
726 }
727 }
728
729 let mut durations = durations.into_values().flatten().collect::<Vec<_>>();
730 durations.sort_unstable_by_key(|duration| duration.start);
731 Ok(durations)
732 }
733
734 // contacts
735
736 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
737 let query = "
738 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
739 FROM contacts
740 WHERE user_id_a = $1 OR user_id_b = $1;
741 ";
742
743 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
744 .bind(user_id)
745 .fetch(&self.pool);
746
747 let mut contacts = vec![Contact::Accepted {
748 user_id,
749 should_notify: false,
750 }];
751 while let Some(row) = rows.next().await {
752 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
753
754 if user_id_a == user_id {
755 if accepted {
756 contacts.push(Contact::Accepted {
757 user_id: user_id_b,
758 should_notify: should_notify && a_to_b,
759 });
760 } else if a_to_b {
761 contacts.push(Contact::Outgoing { user_id: user_id_b })
762 } else {
763 contacts.push(Contact::Incoming {
764 user_id: user_id_b,
765 should_notify,
766 });
767 }
768 } else {
769 if accepted {
770 contacts.push(Contact::Accepted {
771 user_id: user_id_a,
772 should_notify: should_notify && !a_to_b,
773 });
774 } else if a_to_b {
775 contacts.push(Contact::Incoming {
776 user_id: user_id_a,
777 should_notify,
778 });
779 } else {
780 contacts.push(Contact::Outgoing { user_id: user_id_a });
781 }
782 }
783 }
784
785 contacts.sort_unstable_by_key(|contact| contact.user_id());
786
787 Ok(contacts)
788 }
789
790 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
791 let (id_a, id_b) = if user_id_1 < user_id_2 {
792 (user_id_1, user_id_2)
793 } else {
794 (user_id_2, user_id_1)
795 };
796
797 let query = "
798 SELECT 1 FROM contacts
799 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
800 LIMIT 1
801 ";
802 Ok(sqlx::query_scalar::<_, i32>(query)
803 .bind(id_a.0)
804 .bind(id_b.0)
805 .fetch_optional(&self.pool)
806 .await?
807 .is_some())
808 }
809
810 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
811 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
812 (sender_id, receiver_id, true)
813 } else {
814 (receiver_id, sender_id, false)
815 };
816 let query = "
817 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
818 VALUES ($1, $2, $3, 'f', 't')
819 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
820 SET
821 accepted = 't',
822 should_notify = 'f'
823 WHERE
824 NOT contacts.accepted AND
825 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
826 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
827 ";
828 let result = sqlx::query(query)
829 .bind(id_a.0)
830 .bind(id_b.0)
831 .bind(a_to_b)
832 .execute(&self.pool)
833 .await?;
834
835 if result.rows_affected() == 1 {
836 Ok(())
837 } else {
838 Err(anyhow!("contact already requested"))?
839 }
840 }
841
842 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
843 let (id_a, id_b) = if responder_id < requester_id {
844 (responder_id, requester_id)
845 } else {
846 (requester_id, responder_id)
847 };
848 let query = "
849 DELETE FROM contacts
850 WHERE user_id_a = $1 AND user_id_b = $2;
851 ";
852 let result = sqlx::query(query)
853 .bind(id_a.0)
854 .bind(id_b.0)
855 .execute(&self.pool)
856 .await?;
857
858 if result.rows_affected() == 1 {
859 Ok(())
860 } else {
861 Err(anyhow!("no such contact"))?
862 }
863 }
864
865 async fn dismiss_contact_notification(
866 &self,
867 user_id: UserId,
868 contact_user_id: UserId,
869 ) -> Result<()> {
870 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
871 (user_id, contact_user_id, true)
872 } else {
873 (contact_user_id, user_id, false)
874 };
875
876 let query = "
877 UPDATE contacts
878 SET should_notify = 'f'
879 WHERE
880 user_id_a = $1 AND user_id_b = $2 AND
881 (
882 (a_to_b = $3 AND accepted) OR
883 (a_to_b != $3 AND NOT accepted)
884 );
885 ";
886
887 let result = sqlx::query(query)
888 .bind(id_a.0)
889 .bind(id_b.0)
890 .bind(a_to_b)
891 .execute(&self.pool)
892 .await?;
893
894 if result.rows_affected() == 0 {
895 Err(anyhow!("no such contact request"))?;
896 }
897
898 Ok(())
899 }
900
901 async fn respond_to_contact_request(
902 &self,
903 responder_id: UserId,
904 requester_id: UserId,
905 accept: bool,
906 ) -> Result<()> {
907 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
908 (responder_id, requester_id, false)
909 } else {
910 (requester_id, responder_id, true)
911 };
912 let result = if accept {
913 let query = "
914 UPDATE contacts
915 SET accepted = 't', should_notify = 't'
916 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
917 ";
918 sqlx::query(query)
919 .bind(id_a.0)
920 .bind(id_b.0)
921 .bind(a_to_b)
922 .execute(&self.pool)
923 .await?
924 } else {
925 let query = "
926 DELETE FROM contacts
927 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
928 ";
929 sqlx::query(query)
930 .bind(id_a.0)
931 .bind(id_b.0)
932 .bind(a_to_b)
933 .execute(&self.pool)
934 .await?
935 };
936 if result.rows_affected() == 1 {
937 Ok(())
938 } else {
939 Err(anyhow!("no such contact request"))?
940 }
941 }
942
943 // access tokens
944
945 async fn create_access_token_hash(
946 &self,
947 user_id: UserId,
948 access_token_hash: &str,
949 max_access_token_count: usize,
950 ) -> Result<()> {
951 let insert_query = "
952 INSERT INTO access_tokens (user_id, hash)
953 VALUES ($1, $2);
954 ";
955 let cleanup_query = "
956 DELETE FROM access_tokens
957 WHERE id IN (
958 SELECT id from access_tokens
959 WHERE user_id = $1
960 ORDER BY id DESC
961 OFFSET $3
962 )
963 ";
964
965 let mut tx = self.pool.begin().await?;
966 sqlx::query(insert_query)
967 .bind(user_id.0)
968 .bind(access_token_hash)
969 .execute(&mut tx)
970 .await?;
971 sqlx::query(cleanup_query)
972 .bind(user_id.0)
973 .bind(access_token_hash)
974 .bind(max_access_token_count as i32)
975 .execute(&mut tx)
976 .await?;
977 Ok(tx.commit().await?)
978 }
979
980 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
981 let query = "
982 SELECT hash
983 FROM access_tokens
984 WHERE user_id = $1
985 ORDER BY id DESC
986 ";
987 Ok(sqlx::query_scalar(query)
988 .bind(user_id.0)
989 .fetch_all(&self.pool)
990 .await?)
991 }
992
993 // orgs
994
995 #[allow(unused)] // Help rust-analyzer
996 #[cfg(any(test, feature = "seed-support"))]
997 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
998 let query = "
999 SELECT *
1000 FROM orgs
1001 WHERE slug = $1
1002 ";
1003 Ok(sqlx::query_as(query)
1004 .bind(slug)
1005 .fetch_optional(&self.pool)
1006 .await?)
1007 }
1008
1009 #[cfg(any(test, feature = "seed-support"))]
1010 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1011 let query = "
1012 INSERT INTO orgs (name, slug)
1013 VALUES ($1, $2)
1014 RETURNING id
1015 ";
1016 Ok(sqlx::query_scalar(query)
1017 .bind(name)
1018 .bind(slug)
1019 .fetch_one(&self.pool)
1020 .await
1021 .map(OrgId)?)
1022 }
1023
1024 #[cfg(any(test, feature = "seed-support"))]
1025 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1026 let query = "
1027 INSERT INTO org_memberships (org_id, user_id, admin)
1028 VALUES ($1, $2, $3)
1029 ON CONFLICT DO NOTHING
1030 ";
1031 Ok(sqlx::query(query)
1032 .bind(org_id.0)
1033 .bind(user_id.0)
1034 .bind(is_admin)
1035 .execute(&self.pool)
1036 .await
1037 .map(drop)?)
1038 }
1039
1040 // channels
1041
1042 #[cfg(any(test, feature = "seed-support"))]
1043 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1044 let query = "
1045 INSERT INTO channels (owner_id, owner_is_user, name)
1046 VALUES ($1, false, $2)
1047 RETURNING id
1048 ";
1049 Ok(sqlx::query_scalar(query)
1050 .bind(org_id.0)
1051 .bind(name)
1052 .fetch_one(&self.pool)
1053 .await
1054 .map(ChannelId)?)
1055 }
1056
1057 #[allow(unused)] // Help rust-analyzer
1058 #[cfg(any(test, feature = "seed-support"))]
1059 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1060 let query = "
1061 SELECT *
1062 FROM channels
1063 WHERE
1064 channels.owner_is_user = false AND
1065 channels.owner_id = $1
1066 ";
1067 Ok(sqlx::query_as(query)
1068 .bind(org_id.0)
1069 .fetch_all(&self.pool)
1070 .await?)
1071 }
1072
1073 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1074 let query = "
1075 SELECT
1076 channels.*
1077 FROM
1078 channel_memberships, channels
1079 WHERE
1080 channel_memberships.user_id = $1 AND
1081 channel_memberships.channel_id = channels.id
1082 ";
1083 Ok(sqlx::query_as(query)
1084 .bind(user_id.0)
1085 .fetch_all(&self.pool)
1086 .await?)
1087 }
1088
1089 async fn can_user_access_channel(
1090 &self,
1091 user_id: UserId,
1092 channel_id: ChannelId,
1093 ) -> Result<bool> {
1094 let query = "
1095 SELECT id
1096 FROM channel_memberships
1097 WHERE user_id = $1 AND channel_id = $2
1098 LIMIT 1
1099 ";
1100 Ok(sqlx::query_scalar::<_, i32>(query)
1101 .bind(user_id.0)
1102 .bind(channel_id.0)
1103 .fetch_optional(&self.pool)
1104 .await
1105 .map(|e| e.is_some())?)
1106 }
1107
1108 #[cfg(any(test, feature = "seed-support"))]
1109 async fn add_channel_member(
1110 &self,
1111 channel_id: ChannelId,
1112 user_id: UserId,
1113 is_admin: bool,
1114 ) -> Result<()> {
1115 let query = "
1116 INSERT INTO channel_memberships (channel_id, user_id, admin)
1117 VALUES ($1, $2, $3)
1118 ON CONFLICT DO NOTHING
1119 ";
1120 Ok(sqlx::query(query)
1121 .bind(channel_id.0)
1122 .bind(user_id.0)
1123 .bind(is_admin)
1124 .execute(&self.pool)
1125 .await
1126 .map(drop)?)
1127 }
1128
1129 // messages
1130
1131 async fn create_channel_message(
1132 &self,
1133 channel_id: ChannelId,
1134 sender_id: UserId,
1135 body: &str,
1136 timestamp: OffsetDateTime,
1137 nonce: u128,
1138 ) -> Result<MessageId> {
1139 let query = "
1140 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1141 VALUES ($1, $2, $3, $4, $5)
1142 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1143 RETURNING id
1144 ";
1145 Ok(sqlx::query_scalar(query)
1146 .bind(channel_id.0)
1147 .bind(sender_id.0)
1148 .bind(body)
1149 .bind(timestamp)
1150 .bind(Uuid::from_u128(nonce))
1151 .fetch_one(&self.pool)
1152 .await
1153 .map(MessageId)?)
1154 }
1155
1156 async fn get_channel_messages(
1157 &self,
1158 channel_id: ChannelId,
1159 count: usize,
1160 before_id: Option<MessageId>,
1161 ) -> Result<Vec<ChannelMessage>> {
1162 let query = r#"
1163 SELECT * FROM (
1164 SELECT
1165 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1166 FROM
1167 channel_messages
1168 WHERE
1169 channel_id = $1 AND
1170 id < $2
1171 ORDER BY id DESC
1172 LIMIT $3
1173 ) as recent_messages
1174 ORDER BY id ASC
1175 "#;
1176 Ok(sqlx::query_as(query)
1177 .bind(channel_id.0)
1178 .bind(before_id.unwrap_or(MessageId::MAX))
1179 .bind(count as i64)
1180 .fetch_all(&self.pool)
1181 .await?)
1182 }
1183
1184 #[cfg(test)]
1185 async fn teardown(&self, url: &str) {
1186 use util::ResultExt;
1187
1188 let query = "
1189 SELECT pg_terminate_backend(pg_stat_activity.pid)
1190 FROM pg_stat_activity
1191 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1192 ";
1193 sqlx::query(query).execute(&self.pool).await.log_err();
1194 self.pool.close().await;
1195 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1196 .await
1197 .log_err();
1198 }
1199
1200 #[cfg(test)]
1201 fn as_fake(&self) -> Option<&tests::FakeDb> {
1202 None
1203 }
1204}
1205
/// Declares a newtype around an `i32` database id.
///
/// The generated type encodes and decodes as the bare integer in sqlx and serde
/// (via `transparent`). It provides `MAX`, `u64` conversions (`from_proto` /
/// `to_proto`), and a `Display` impl that prints the inner integer.
macro_rules! id_type {
    ($id:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $id(pub i32);

        impl $id {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = $id(i32::MAX);

            /// Builds an id from its `u64` form. Uses an `as` cast, so values
            /// outside `i32` range wrap.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $id(value as i32)
            }

            /// Widens the id to `u64` with an `as` cast.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $id {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1237
id_type!(UserId);
/// A user record, decodable from a SQL row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub login. `create_user` and `get_user_by_github_login` look users up by it.
    pub github_login: String,
    pub email_address: Option<String>,
    /// Admin flag, as passed to `create_user` / `set_user_is_admin`.
    pub admin: bool,
    /// The invite code owned by this user, if one has been assigned.
    pub invite_code: Option<String>,
    /// Remaining number of times `invite_code` can be redeemed.
    pub invite_count: i32,
    /// Flag written by `set_user_connected_once`. Presumably it records whether the
    /// user has ever connected; confirm against callers.
    pub connected_once: bool,
}
1249
id_type!(ProjectId);
/// A project registered by a host user.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user passed to `register_project`.
    pub host_user_id: UserId,
    /// Set by `unregister_project`. The fake DB marks the record instead of removing it.
    pub unregistered: bool,
}
1257
/// Per-user rollup returned by `summarize_project_activity`: the total time the user
/// spent in each project during the queried window.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// `(project, total active duration)` pairs.
    pub project_activity: Vec<(ProjectId, Duration)>,
}
1264
/// One contiguous interval of a user's activity in a project, as returned by
/// `summarize_user_activity`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityDuration {
    project_id: ProjectId,
    start: OffsetDateTime,
    end: OffsetDateTime,
    /// File-extension counts recorded for the project via `update_worktree_extensions`.
    /// Empty when none were recorded.
    extensions: HashMap<String, usize>,
}
1272
id_type!(OrgId);
/// An organization, decodable from a SQL row via `FromRow`.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Presumably a URL-safe identifier (passed as the second argument of `create_org`);
    /// confirm.
    pub slug: String,
}
1280
id_type!(ChannelId);
/// A chat channel.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Raw id of the owner. It is a user id when `owner_is_user` is set; otherwise it
    /// presumably refers to an org (see `create_org_channel`). Confirm.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1289
id_type!(MessageId);
/// A message posted to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time. The Postgres query converts it with `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce. Re-sending a message with the same nonce yields the
    /// existing message id instead of a duplicate.
    pub nonce: Uuid,
}
1300
/// One entry in a user's contact list, as returned by `get_contacts`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// An established contact. `should_notify` is set for the original requester once
    /// the request is accepted, and cleared by `dismiss_contact_notification`.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A request this user sent that has not been answered yet.
    Outgoing {
        user_id: UserId,
    },
    /// A request this user received. `should_notify` stays set until the notification
    /// is dismissed.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1315
1316impl Contact {
1317 pub fn user_id(&self) -> UserId {
1318 match self {
1319 Contact::Accepted { user_id, .. } => *user_id,
1320 Contact::Outgoing { user_id } => *user_id,
1321 Contact::Incoming { user_id, .. } => *user_id,
1322 }
1323 }
1324}
1325
/// A pending contact request received from `requester_id`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    /// Whether the recipient should still be notified about this request.
    pub should_notify: bool,
}
1331
/// Builds a SQL `LIKE` pattern that matches any string containing the alphanumeric
/// characters of `string` in order, with anything in between (`"ab"` becomes `"%a%b%"`).
///
/// Non-alphanumeric characters are dropped. This also discards the `LIKE` wildcards
/// `%` and `_`, so user input cannot inject its own wildcards.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|ch| ch.is_alphanumeric())
            .flat_map(|ch| ['%', ch]),
    );
    pattern.push('%');
    pattern
}
1343
1344#[cfg(test)]
1345pub mod tests {
1346 use super::*;
1347 use anyhow::anyhow;
1348 use collections::BTreeMap;
1349 use gpui::executor::Background;
1350 use lazy_static::lazy_static;
1351 use parking_lot::Mutex;
1352 use rand::prelude::*;
1353 use sqlx::{
1354 migrate::{MigrateDatabase, Migrator},
1355 Postgres,
1356 };
1357 use std::{path::Path, sync::Arc};
1358 use util::post_inc;
1359
1360 #[tokio::test(flavor = "multi_thread")]
1361 async fn test_get_users_by_ids() {
1362 for test_db in [
1363 TestDb::postgres().await,
1364 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1365 ] {
1366 let db = test_db.db();
1367
1368 let user = db.create_user("user", None, false).await.unwrap();
1369 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1370 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1371 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1372
1373 assert_eq!(
1374 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1375 .await
1376 .unwrap(),
1377 vec![
1378 User {
1379 id: user,
1380 github_login: "user".to_string(),
1381 admin: false,
1382 ..Default::default()
1383 },
1384 User {
1385 id: friend1,
1386 github_login: "friend-1".to_string(),
1387 admin: false,
1388 ..Default::default()
1389 },
1390 User {
1391 id: friend2,
1392 github_login: "friend-2".to_string(),
1393 admin: false,
1394 ..Default::default()
1395 },
1396 User {
1397 id: friend3,
1398 github_login: "friend-3".to_string(),
1399 admin: false,
1400 ..Default::default()
1401 }
1402 ]
1403 );
1404 }
1405 }
1406
/// Exercises `create_users` against Postgres:
/// - each new user gets the requested email and invite count, plus a unique invite code;
/// - passing an already-existing login again updates that user's invite count while
///   keeping their id and invite code.
///
/// The assertions index `users` positionally, so they rely on `get_users_by_ids`
/// returning users in the order their ids were requested.
#[tokio::test(flavor = "multi_thread")]
async fn test_create_users() {
    let db = TestDb::postgres().await;
    // Shadowing does not drop the owning `TestDb`. It lives until the end of the test,
    // which keeps the borrowed `db` valid.
    let db = db.db();

    // Create the first batch of users, ensuring invite counts are assigned
    // correctly and the respective invite codes are unique.
    let user_ids_batch_1 = db
        .create_users(vec![
            ("user1".to_string(), "hi@user1.com".to_string(), 5),
            ("user2".to_string(), "hi@user2.com".to_string(), 4),
            ("user3".to_string(), "hi@user3.com".to_string(), 3),
        ])
        .await
        .unwrap();
    assert_eq!(user_ids_batch_1.len(), 3);

    let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
    assert_eq!(users.len(), 3);
    assert_eq!(users[0].github_login, "user1");
    assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
    assert_eq!(users[0].invite_count, 5);
    assert_eq!(users[1].github_login, "user2");
    assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
    assert_eq!(users[1].invite_count, 4);
    assert_eq!(users[2].github_login, "user3");
    assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
    assert_eq!(users[2].invite_count, 3);

    let invite_code_1 = users[0].invite_code.clone().unwrap();
    let invite_code_2 = users[1].invite_code.clone().unwrap();
    let invite_code_3 = users[2].invite_code.clone().unwrap();
    assert_ne!(invite_code_1, invite_code_2);
    assert_ne!(invite_code_1, invite_code_3);
    assert_ne!(invite_code_2, invite_code_3);

    // Create the second batch of users and include a user that is already in the database, ensuring
    // the invite count for the existing user is updated without changing their invite code.
    let user_ids_batch_2 = db
        .create_users(vec![
            ("user2".to_string(), "hi@user2.com".to_string(), 10),
            ("user4".to_string(), "hi@user4.com".to_string(), 2),
        ])
        .await
        .unwrap();
    assert_eq!(user_ids_batch_2.len(), 2);
    // The existing login keeps its original id.
    assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);

    let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
    assert_eq!(users.len(), 2);
    assert_eq!(users[0].github_login, "user2");
    assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
    assert_eq!(users[0].invite_count, 10);
    assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
    assert_eq!(users[1].github_login, "user4");
    assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
    assert_eq!(users[1].invite_count, 2);

    // The newly created user's code must not collide with any earlier code.
    let invite_code_4 = users[1].invite_code.clone().unwrap();
    assert_ne!(invite_code_4, invite_code_1);
    assert_ne!(invite_code_4, invite_code_2);
    assert_ne!(invite_code_4, invite_code_3);
}
1470
/// Checks that `update_worktree_extensions` overwrites the counts for extensions
/// already recorded for a worktree, keeps worktrees separate, and that
/// `get_project_extensions` returns the latest counts per worktree.
#[tokio::test(flavor = "multi_thread")]
async fn test_worktree_extensions() {
    let test_db = TestDb::postgres().await;
    let db = test_db.db();

    let user = db.create_user("user_1", None, false).await.unwrap();
    let project = db.register_project(user).await.unwrap();

    // An empty update must be accepted.
    db.update_worktree_extensions(project, 100, Default::default())
        .await
        .unwrap();
    db.update_worktree_extensions(
        project,
        100,
        [("rs".to_string(), 5), ("md".to_string(), 3)]
            .into_iter()
            .collect(),
    )
    .await
    .unwrap();
    // Same worktree again: these counts replace the previous ones.
    db.update_worktree_extensions(
        project,
        100,
        [("rs".to_string(), 6), ("md".to_string(), 5)]
            .into_iter()
            .collect(),
    )
    .await
    .unwrap();
    // A second worktree of the same project is tracked independently.
    db.update_worktree_extensions(
        project,
        101,
        [("ts".to_string(), 2), ("md".to_string(), 1)]
            .into_iter()
            .collect(),
    )
    .await
    .unwrap();

    assert_eq!(
        db.get_project_extensions(project).await.unwrap(),
        [
            (
                100,
                [("rs".into(), 6), ("md".into(), 5),]
                    .into_iter()
                    .collect::<HashMap<_, _>>()
            ),
            (
                101,
                [("ts".into(), 2), ("md".into(), 1),]
                    .into_iter()
                    .collect::<HashMap<_, _>>()
            )
        ]
        .into_iter()
        .collect()
    );
}
1530
/// Records a timeline of `(user, project)` activity intervals, then checks:
/// - `summarize_project_activity`: total time per user per project in a window;
/// - `summarize_user_activity`: one user's contiguous intervals, including the
///   project's recorded worktree extensions.
#[tokio::test(flavor = "multi_thread")]
async fn test_project_activity() {
    let test_db = TestDb::postgres().await;
    let db = test_db.db();

    let user_1 = db.create_user("user_1", None, false).await.unwrap();
    let user_2 = db.create_user("user_2", None, false).await.unwrap();
    let user_3 = db.create_user("user_3", None, false).await.unwrap();
    let project_1 = db.register_project(user_1).await.unwrap();
    // Only project 1 gets extension counts; project 2's intervals expect an empty map.
    db.update_worktree_extensions(
        project_1,
        1,
        HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
    )
    .await
    .unwrap();
    let project_2 = db.register_project(user_2).await.unwrap();
    // Start the timeline an hour in the past.
    let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);

    // User 2 opens a project
    let t1 = t0 + Duration::from_secs(10);
    db.record_project_activity(t0..t1, &[(user_2, project_2)])
        .await
        .unwrap();

    let t2 = t1 + Duration::from_secs(10);
    db.record_project_activity(t1..t2, &[(user_2, project_2)])
        .await
        .unwrap();

    // User 1 joins the project
    let t3 = t2 + Duration::from_secs(10);
    db.record_project_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
        .await
        .unwrap();

    // User 1 opens another project
    let t4 = t3 + Duration::from_secs(10);
    db.record_project_activity(
        t3..t4,
        &[
            (user_2, project_2),
            (user_1, project_2),
            (user_1, project_1),
        ],
    )
    .await
    .unwrap();

    // User 3 joins that project
    let t5 = t4 + Duration::from_secs(10);
    db.record_project_activity(
        t4..t5,
        &[
            (user_2, project_2),
            (user_1, project_2),
            (user_1, project_1),
            (user_3, project_1),
        ],
    )
    .await
    .unwrap();

    // User 2 leaves
    let t6 = t5 + Duration::from_secs(5);
    db.record_project_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
        .await
        .unwrap();

    // After a 60s gap with no recorded activity, user 1 is active in project 1 again.
    // This shows up below as a separate interval.
    let t7 = t6 + Duration::from_secs(60);
    let t8 = t7 + Duration::from_secs(10);
    db.record_project_activity(t7..t8, &[(user_1, project_1)])
        .await
        .unwrap();

    // Totals over t0..t6:
    // - user 1: project 2 for t2..t5 (30s) and project 1 for t3..t6 (25s);
    // - user 2: project 2 for t0..t5 (50s);
    // - user 3: project 1 for t4..t6 (15s).
    assert_eq!(
        db.summarize_project_activity(t0..t6, 10).await.unwrap(),
        &[
            UserActivitySummary {
                id: user_1,
                github_login: "user_1".to_string(),
                project_activity: vec![
                    (project_2, Duration::from_secs(30)),
                    (project_1, Duration::from_secs(25))
                ]
            },
            UserActivitySummary {
                id: user_2,
                github_login: "user_2".to_string(),
                project_activity: vec![(project_2, Duration::from_secs(50))]
            },
            UserActivitySummary {
                id: user_3,
                github_login: "user_3".to_string(),
                project_activity: vec![(project_1, Duration::from_secs(15))]
            },
        ]
    );
    // Within t3..t6, project 2's interval (t2..t5) is reported starting at the window start t3.
    assert_eq!(
        db.summarize_user_activity(user_1, t3..t6).await.unwrap(),
        &[
            UserActivityDuration {
                project_id: project_1,
                start: t3,
                end: t6,
                extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
            },
            UserActivityDuration {
                project_id: project_2,
                start: t3,
                end: t5,
                extensions: Default::default(),
            },
        ]
    );
    assert_eq!(
        db.summarize_user_activity(user_1, t0..t8).await.unwrap(),
        &[
            UserActivityDuration {
                project_id: project_2,
                start: t2,
                end: t5,
                extensions: Default::default(),
            },
            UserActivityDuration {
                project_id: project_1,
                start: t3,
                end: t6,
                extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
            },
            UserActivityDuration {
                project_id: project_1,
                start: t7,
                end: t8,
                extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
            },
        ]
    );
}
1670
1671 #[tokio::test(flavor = "multi_thread")]
1672 async fn test_recent_channel_messages() {
1673 for test_db in [
1674 TestDb::postgres().await,
1675 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1676 ] {
1677 let db = test_db.db();
1678 let user = db.create_user("user", None, false).await.unwrap();
1679 let org = db.create_org("org", "org").await.unwrap();
1680 let channel = db.create_org_channel(org, "channel").await.unwrap();
1681 for i in 0..10 {
1682 db.create_channel_message(
1683 channel,
1684 user,
1685 &i.to_string(),
1686 OffsetDateTime::now_utc(),
1687 i,
1688 )
1689 .await
1690 .unwrap();
1691 }
1692
1693 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1694 assert_eq!(
1695 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1696 ["5", "6", "7", "8", "9"]
1697 );
1698
1699 let prev_messages = db
1700 .get_channel_messages(channel, 4, Some(messages[0].id))
1701 .await
1702 .unwrap();
1703 assert_eq!(
1704 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1705 ["1", "2", "3", "4"]
1706 );
1707 }
1708 }
1709
1710 #[tokio::test(flavor = "multi_thread")]
1711 async fn test_channel_message_nonces() {
1712 for test_db in [
1713 TestDb::postgres().await,
1714 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1715 ] {
1716 let db = test_db.db();
1717 let user = db.create_user("user", None, false).await.unwrap();
1718 let org = db.create_org("org", "org").await.unwrap();
1719 let channel = db.create_org_channel(org, "channel").await.unwrap();
1720
1721 let msg1_id = db
1722 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1723 .await
1724 .unwrap();
1725 let msg2_id = db
1726 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1727 .await
1728 .unwrap();
1729 let msg3_id = db
1730 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1731 .await
1732 .unwrap();
1733 let msg4_id = db
1734 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1735 .await
1736 .unwrap();
1737
1738 assert_ne!(msg1_id, msg2_id);
1739 assert_eq!(msg1_id, msg3_id);
1740 assert_eq!(msg2_id, msg4_id);
1741 }
1742 }
1743
1744 #[tokio::test(flavor = "multi_thread")]
1745 async fn test_create_access_tokens() {
1746 let test_db = TestDb::postgres().await;
1747 let db = test_db.db();
1748 let user = db.create_user("the-user", None, false).await.unwrap();
1749
1750 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1751 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1752 assert_eq!(
1753 db.get_access_token_hashes(user).await.unwrap(),
1754 &["h2".to_string(), "h1".to_string()]
1755 );
1756
1757 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1758 assert_eq!(
1759 db.get_access_token_hashes(user).await.unwrap(),
1760 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1761 );
1762
1763 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1764 assert_eq!(
1765 db.get_access_token_hashes(user).await.unwrap(),
1766 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1767 );
1768
1769 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1770 assert_eq!(
1771 db.get_access_token_hashes(user).await.unwrap(),
1772 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1773 );
1774 }
1775
/// Pins `fuzzy_like_string`: alphanumerics are interleaved with `%`, everything else is
/// dropped. That includes the `LIKE` wildcards `%` and `_`, which must never pass
/// through from user input.
#[test]
fn test_fuzzy_like_string() {
    assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
    assert_eq!(fuzzy_like_string("x y"), "%x%y%");
    assert_eq!(fuzzy_like_string(" z "), "%z%");
    // Empty or all-punctuation input still yields a valid match-anything pattern.
    assert_eq!(fuzzy_like_string(""), "%");
    assert_eq!(fuzzy_like_string("-_-"), "%");
    // User-supplied LIKE wildcards are stripped rather than forwarded to SQL.
    assert_eq!(fuzzy_like_string("a%b_c"), "%a%b%c%");
    // Non-ASCII alphanumerics are kept.
    assert_eq!(fuzzy_like_string("é1"), "%é%1%");
}
1782
1783 #[tokio::test(flavor = "multi_thread")]
1784 async fn test_fuzzy_search_users() {
1785 let test_db = TestDb::postgres().await;
1786 let db = test_db.db();
1787 for github_login in [
1788 "California",
1789 "colorado",
1790 "oregon",
1791 "washington",
1792 "florida",
1793 "delaware",
1794 "rhode-island",
1795 ] {
1796 db.create_user(github_login, None, false).await.unwrap();
1797 }
1798
1799 assert_eq!(
1800 fuzzy_search_user_names(db, "clr").await,
1801 &["colorado", "California"]
1802 );
1803 assert_eq!(
1804 fuzzy_search_user_names(db, "ro").await,
1805 &["rhode-island", "colorado", "oregon"],
1806 );
1807
1808 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1809 db.fuzzy_search_users(query, 10)
1810 .await
1811 .unwrap()
1812 .into_iter()
1813 .map(|user| user.github_login)
1814 .collect::<Vec<_>>()
1815 }
1816 }
1817
/// Walks the contact-request lifecycle on both backends:
/// - sending a request, and dismissing its notification;
/// - accepting a request, and the requester dismissing the acceptance notification;
/// - mutual requests, which are accepted automatically;
/// - declining a request.
///
/// Every user always appears in their own contact list as an accepted contact.
#[tokio::test(flavor = "multi_thread")]
async fn test_add_contacts() {
    for test_db in [
        TestDb::postgres().await,
        TestDb::fake(Arc::new(gpui::executor::Background::new())),
    ] {
        let db = test_db.db();

        let user_1 = db.create_user("user1", None, false).await.unwrap();
        let user_2 = db.create_user("user2", None, false).await.unwrap();
        let user_3 = db.create_user("user3", None, false).await.unwrap();

        // User starts with no contacts other than themselves (every user is listed
        // as their own accepted contact).
        assert_eq!(
            db.get_contacts(user_1).await.unwrap(),
            vec![Contact::Accepted {
                user_id: user_1,
                should_notify: false
            }],
        );

        // User requests a contact. Both users see the pending request.
        db.send_contact_request(user_1, user_2).await.unwrap();
        assert!(!db.has_contact(user_1, user_2).await.unwrap());
        assert!(!db.has_contact(user_2, user_1).await.unwrap());
        assert_eq!(
            db.get_contacts(user_1).await.unwrap(),
            &[
                Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                },
                Contact::Outgoing { user_id: user_2 }
            ],
        );
        assert_eq!(
            db.get_contacts(user_2).await.unwrap(),
            &[
                Contact::Incoming {
                    user_id: user_1,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user_2,
                    should_notify: false
                },
            ]
        );

        // User 2 dismisses the contact request notification without accepting or rejecting.
        // We shouldn't notify them again.
        // The requester cannot dismiss a notification on the recipient's behalf.
        db.dismiss_contact_notification(user_1, user_2)
            .await
            .unwrap_err();
        db.dismiss_contact_notification(user_2, user_1)
            .await
            .unwrap();
        assert_eq!(
            db.get_contacts(user_2).await.unwrap(),
            &[
                Contact::Incoming {
                    user_id: user_1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user_2,
                    should_notify: false
                },
            ]
        );

        // User can't accept their own contact request
        db.respond_to_contact_request(user_1, user_2, true)
            .await
            .unwrap_err();

        // User accepts a contact request. Both users see the contact.
        db.respond_to_contact_request(user_2, user_1, true)
            .await
            .unwrap();
        // The requester (user 1) is notified about the acceptance.
        assert_eq!(
            db.get_contacts(user_1).await.unwrap(),
            &[
                Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user_2,
                    should_notify: true
                }
            ],
        );
        assert!(db.has_contact(user_1, user_2).await.unwrap());
        assert!(db.has_contact(user_2, user_1).await.unwrap());
        assert_eq!(
            db.get_contacts(user_2).await.unwrap(),
            &[
                Contact::Accepted {
                    user_id: user_1,
                    should_notify: false,
                },
                Contact::Accepted {
                    user_id: user_2,
                    should_notify: false,
                },
            ]
        );

        // Users cannot re-request existing contacts.
        db.send_contact_request(user_1, user_2).await.unwrap_err();
        db.send_contact_request(user_2, user_1).await.unwrap_err();

        // Users can't dismiss notifications of them accepting other users' requests.
        db.dismiss_contact_notification(user_2, user_1)
            .await
            .unwrap_err();
        assert_eq!(
            db.get_contacts(user_1).await.unwrap(),
            &[
                Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user_2,
                    should_notify: true,
                },
            ]
        );

        // Users can dismiss notifications of other users accepting their requests.
        db.dismiss_contact_notification(user_1, user_2)
            .await
            .unwrap();
        assert_eq!(
            db.get_contacts(user_1).await.unwrap(),
            &[
                Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user_2,
                    should_notify: false,
                },
            ]
        );

        // Users send each other concurrent contact requests and
        // see that they are immediately accepted.
        db.send_contact_request(user_1, user_3).await.unwrap();
        db.send_contact_request(user_3, user_1).await.unwrap();
        assert_eq!(
            db.get_contacts(user_1).await.unwrap(),
            &[
                Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user_2,
                    should_notify: false,
                },
                Contact::Accepted {
                    user_id: user_3,
                    should_notify: false
                },
            ]
        );
        assert_eq!(
            db.get_contacts(user_3).await.unwrap(),
            &[
                Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user_3,
                    should_notify: false
                }
            ],
        );

        // User declines a contact request. Both users see that it is gone.
        db.send_contact_request(user_2, user_3).await.unwrap();
        db.respond_to_contact_request(user_3, user_2, false)
            .await
            .unwrap();
        assert!(!db.has_contact(user_2, user_3).await.unwrap());
        assert!(!db.has_contact(user_3, user_2).await.unwrap());
        assert_eq!(
            db.get_contacts(user_2).await.unwrap(),
            &[
                Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user_2,
                    should_notify: false
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user_3).await.unwrap(),
            &[
                Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user_3,
                    should_notify: false
                }
            ],
        );
    }
}
2037
/// Exercises invite codes on Postgres:
/// - a code is created by raising the invite count above zero;
/// - each redemption creates a user who becomes a mutual contact of the inviter, and
///   decrements the count;
/// - redemption fails once the count reaches zero, and for logins that already exist.
#[tokio::test(flavor = "multi_thread")]
async fn test_invite_codes() {
    let postgres = TestDb::postgres().await;
    let db = postgres.db();
    let user1 = db.create_user("user-1", None, false).await.unwrap();

    // Initially, user 1 has no invite code
    assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

    // Setting invite count to 0 when no code is assigned does not assign a new code
    db.set_invite_count(user1, 0).await.unwrap();
    assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

    // User 1 creates an invite code that can be used twice.
    db.set_invite_count(user1, 2).await.unwrap();
    let (invite_code, invite_count) =
        db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 2);

    // User 2 redeems the invite code and becomes a contact of user 1.
    // The inviter is notified (should_notify: true); the new user is not.
    let user2 = db
        .redeem_invite_code(&invite_code, "user-2", None)
        .await
        .unwrap();
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            }
        ]
    );
    assert_eq!(
        db.get_contacts(user2).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: false
            }
        ]
    );

    // User 3 redeems the invite code and becomes a contact of user 1.
    let user3 = db
        .redeem_invite_code(&invite_code, "user-3", None)
        .await
        .unwrap();
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 0);
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: true
            }
        ]
    );
    assert_eq!(
        db.get_contacts(user3).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: false
            },
        ]
    );

    // Trying to redeem the code for the third time results in an error.
    db.redeem_invite_code(&invite_code, "user-4", None)
        .await
        .unwrap_err();

    // Invite count can be updated after the code has been created.
    db.set_invite_count(user1, 2).await.unwrap();
    let (latest_code, invite_count) =
        db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
    assert_eq!(invite_count, 2);

    // User 4 can now redeem the invite code and becomes a contact of user 1.
    let user4 = db
        .redeem_invite_code(&invite_code, "user-4", None)
        .await
        .unwrap();
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
    assert_eq!(
        db.get_contacts(user1).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user2,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user3,
                should_notify: true
            },
            Contact::Accepted {
                user_id: user4,
                should_notify: true
            }
        ]
    );
    assert_eq!(
        db.get_contacts(user4).await.unwrap(),
        [
            Contact::Accepted {
                user_id: user1,
                should_notify: false
            },
            Contact::Accepted {
                user_id: user4,
                should_notify: false
            },
        ]
    );

    // An existing user cannot redeem invite codes.
    // A failed redemption must not consume an invite.
    db.redeem_invite_code(&invite_code, "user-2", None)
        .await
        .unwrap_err();
    let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
    assert_eq!(invite_count, 1);
}
2190
/// A test database handle, backed either by a throwaway Postgres database or by `FakeDb`.
/// Dropping it runs the backend's `teardown`.
pub struct TestDb {
    /// Always `Some` after construction; `Drop` takes it out to run teardown.
    pub db: Option<Arc<dyn Db>>,
    /// Connection URL of the Postgres database. Empty for the fake backend.
    pub url: String,
}
2195
2196 impl TestDb {
2197 pub async fn postgres() -> Self {
2198 lazy_static! {
2199 static ref LOCK: Mutex<()> = Mutex::new(());
2200 }
2201
2202 let _guard = LOCK.lock();
2203 let mut rng = StdRng::from_entropy();
2204 let name = format!("zed-test-{}", rng.gen::<u128>());
2205 let url = format!("postgres://postgres@localhost/{}", name);
2206 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2207 Postgres::create_database(&url)
2208 .await
2209 .expect("failed to create test db");
2210 let db = PostgresDb::new(&url, 5).await.unwrap();
2211 let migrator = Migrator::new(migrations_path).await.unwrap();
2212 migrator.run(&db.pool).await.unwrap();
2213 Self {
2214 db: Some(Arc::new(db)),
2215 url,
2216 }
2217 }
2218
2219 pub fn fake(background: Arc<Background>) -> Self {
2220 Self {
2221 db: Some(Arc::new(FakeDb::new(background))),
2222 url: Default::default(),
2223 }
2224 }
2225
2226 pub fn db(&self) -> &Arc<dyn Db> {
2227 self.db.as_ref().unwrap()
2228 }
2229 }
2230
2231 impl Drop for TestDb {
2232 fn drop(&mut self) {
2233 if let Some(db) = self.db.take() {
2234 futures::executor::block_on(db.teardown(&self.url));
2235 }
2236 }
2237 }
2238
/// In-memory `Db` implementation for tests. Each table is a separate
/// `parking_lot::Mutex`-guarded map.
pub struct FakeDb {
    /// Executor whose `simulate_random_delay` the async methods await.
    background: Arc<Background>,
    pub users: Mutex<BTreeMap<UserId, User>>,
    pub projects: Mutex<BTreeMap<ProjectId, Project>>,
    /// Extension counts keyed by `(project, worktree id, extension)`.
    pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), usize>>,
    pub orgs: Mutex<BTreeMap<OrgId, Org>>,
    /// Org membership; the meaning of the `bool` is not visible here (presumably an
    /// admin flag). Confirm.
    pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
    pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
    /// Channel membership; the meaning of the `bool` is not visible here (presumably an
    /// admin flag). Confirm.
    pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
    pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
    pub contacts: Mutex<Vec<FakeContact>>,
    // Next id to hand out per table; `new` starts each counter at 1.
    next_channel_message_id: Mutex<i32>,
    next_user_id: Mutex<i32>,
    next_org_id: Mutex<i32>,
    next_channel_id: Mutex<i32>,
    next_project_id: Mutex<i32>,
}
2256
/// A contact relationship stored by `FakeDb`, recorded from the requester's point of view.
#[derive(Debug)]
pub struct FakeContact {
    pub requester_id: UserId,
    pub responder_id: UserId,
    /// Whether the responder has accepted the request.
    pub accepted: bool,
    /// Whether the pending notification has not been dismissed yet.
    pub should_notify: bool,
}
2264
2265 impl FakeDb {
2266 pub fn new(background: Arc<Background>) -> Self {
2267 Self {
2268 background,
2269 users: Default::default(),
2270 next_user_id: Mutex::new(1),
2271 projects: Default::default(),
2272 worktree_extensions: Default::default(),
2273 next_project_id: Mutex::new(1),
2274 orgs: Default::default(),
2275 next_org_id: Mutex::new(1),
2276 org_memberships: Default::default(),
2277 channels: Default::default(),
2278 next_channel_id: Mutex::new(1),
2279 channel_memberships: Default::default(),
2280 channel_messages: Default::default(),
2281 next_channel_message_id: Mutex::new(1),
2282 contacts: Default::default(),
2283 }
2284 }
2285 }
2286
2287 #[async_trait]
2288 impl Db for FakeDb {
2289 async fn create_user(
2290 &self,
2291 github_login: &str,
2292 email_address: Option<&str>,
2293 admin: bool,
2294 ) -> Result<UserId> {
2295 self.background.simulate_random_delay().await;
2296
2297 let mut users = self.users.lock();
2298 if let Some(user) = users
2299 .values()
2300 .find(|user| user.github_login == github_login)
2301 {
2302 Ok(user.id)
2303 } else {
2304 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2305 users.insert(
2306 user_id,
2307 User {
2308 id: user_id,
2309 github_login: github_login.to_string(),
2310 email_address: email_address.map(str::to_string),
2311 admin,
2312 invite_code: None,
2313 invite_count: 0,
2314 connected_once: false,
2315 },
2316 );
2317 Ok(user_id)
2318 }
2319 }
2320
// The fake does not support these queries; calling them panics.
async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
    unimplemented!()
}

async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
    unimplemented!()
}

async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
    unimplemented!()
}
2332
2333 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2334 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2335 }
2336
2337 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2338 self.background.simulate_random_delay().await;
2339 let users = self.users.lock();
2340 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2341 }
2342
// Not supported by the fake; panics if called.
async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
    unimplemented!()
}
2346
2347 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2348 Ok(self
2349 .users
2350 .lock()
2351 .values()
2352 .find(|user| user.github_login == github_login)
2353 .cloned())
2354 }
2355
// Not supported by the fake; panics if called.
async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
    unimplemented!()
}
2359
2360 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2361 self.background.simulate_random_delay().await;
2362 let mut users = self.users.lock();
2363 let mut user = users
2364 .get_mut(&id)
2365 .ok_or_else(|| anyhow!("user not found"))?;
2366 user.connected_once = connected_once;
2367 Ok(())
2368 }
2369
        /// Not supported by the fake database; calling it panics.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
2373
2374 // invite codes
2375
        /// Not supported by the fake database; calling it panics.
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
2379
        /// Always returns `Ok(None)` in the fake database, regardless of `_id`.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            Ok(None)
        }
2383
        /// Not supported by the fake database; calling it panics.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }
2387
        /// Not supported by the fake database; calling it panics.
        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
2396
2397 // projects
2398
2399 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2400 self.background.simulate_random_delay().await;
2401 if !self.users.lock().contains_key(&host_user_id) {
2402 Err(anyhow!("no such user"))?;
2403 }
2404
2405 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2406 self.projects.lock().insert(
2407 project_id,
2408 Project {
2409 id: project_id,
2410 host_user_id,
2411 unregistered: false,
2412 },
2413 );
2414 Ok(project_id)
2415 }
2416
2417 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2418 self.projects
2419 .lock()
2420 .get_mut(&project_id)
2421 .ok_or_else(|| anyhow!("no such project"))?
2422 .unregistered = true;
2423 Ok(())
2424 }
2425
2426 async fn update_worktree_extensions(
2427 &self,
2428 project_id: ProjectId,
2429 worktree_id: u64,
2430 extensions: HashMap<String, usize>,
2431 ) -> Result<()> {
2432 self.background.simulate_random_delay().await;
2433 if !self.projects.lock().contains_key(&project_id) {
2434 Err(anyhow!("no such project"))?;
2435 }
2436
2437 for (extension, count) in extensions {
2438 self.worktree_extensions
2439 .lock()
2440 .insert((project_id, worktree_id, extension), count);
2441 }
2442
2443 Ok(())
2444 }
2445
        /// Not supported by the fake database; calling it panics.
        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }
2452
        /// Not supported by the fake database; calling it panics.
        async fn record_project_activity(
            &self,
            _period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }
2460
        /// Not supported by the fake database; calling it panics.
        async fn summarize_project_activity(
            &self,
            _period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }
2468
        /// Not supported by the fake database; calling it panics.
        async fn summarize_user_activity(
            &self,
            _user_id: UserId,
            _time_period: Range<OffsetDateTime>,
        ) -> Result<Vec<UserActivityDuration>> {
            unimplemented!()
        }
2476
2477 // contacts
2478
2479 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2480 self.background.simulate_random_delay().await;
2481 let mut contacts = vec![Contact::Accepted {
2482 user_id: id,
2483 should_notify: false,
2484 }];
2485
2486 for contact in self.contacts.lock().iter() {
2487 if contact.requester_id == id {
2488 if contact.accepted {
2489 contacts.push(Contact::Accepted {
2490 user_id: contact.responder_id,
2491 should_notify: contact.should_notify,
2492 });
2493 } else {
2494 contacts.push(Contact::Outgoing {
2495 user_id: contact.responder_id,
2496 });
2497 }
2498 } else if contact.responder_id == id {
2499 if contact.accepted {
2500 contacts.push(Contact::Accepted {
2501 user_id: contact.requester_id,
2502 should_notify: false,
2503 });
2504 } else {
2505 contacts.push(Contact::Incoming {
2506 user_id: contact.requester_id,
2507 should_notify: contact.should_notify,
2508 });
2509 }
2510 }
2511 }
2512
2513 contacts.sort_unstable_by_key(|contact| contact.user_id());
2514 Ok(contacts)
2515 }
2516
2517 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2518 self.background.simulate_random_delay().await;
2519 Ok(self.contacts.lock().iter().any(|contact| {
2520 contact.accepted
2521 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2522 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2523 }))
2524 }
2525
2526 async fn send_contact_request(
2527 &self,
2528 requester_id: UserId,
2529 responder_id: UserId,
2530 ) -> Result<()> {
2531 let mut contacts = self.contacts.lock();
2532 for contact in contacts.iter_mut() {
2533 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2534 if contact.accepted {
2535 Err(anyhow!("contact already exists"))?;
2536 } else {
2537 Err(anyhow!("contact already requested"))?;
2538 }
2539 }
2540 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2541 if contact.accepted {
2542 Err(anyhow!("contact already exists"))?;
2543 } else {
2544 contact.accepted = true;
2545 contact.should_notify = false;
2546 return Ok(());
2547 }
2548 }
2549 }
2550 contacts.push(FakeContact {
2551 requester_id,
2552 responder_id,
2553 accepted: false,
2554 should_notify: true,
2555 });
2556 Ok(())
2557 }
2558
2559 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2560 self.contacts.lock().retain(|contact| {
2561 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2562 });
2563 Ok(())
2564 }
2565
2566 async fn dismiss_contact_notification(
2567 &self,
2568 user_id: UserId,
2569 contact_user_id: UserId,
2570 ) -> Result<()> {
2571 let mut contacts = self.contacts.lock();
2572 for contact in contacts.iter_mut() {
2573 if contact.requester_id == contact_user_id
2574 && contact.responder_id == user_id
2575 && !contact.accepted
2576 {
2577 contact.should_notify = false;
2578 return Ok(());
2579 }
2580 if contact.requester_id == user_id
2581 && contact.responder_id == contact_user_id
2582 && contact.accepted
2583 {
2584 contact.should_notify = false;
2585 return Ok(());
2586 }
2587 }
2588 Err(anyhow!("no such notification"))?
2589 }
2590
2591 async fn respond_to_contact_request(
2592 &self,
2593 responder_id: UserId,
2594 requester_id: UserId,
2595 accept: bool,
2596 ) -> Result<()> {
2597 let mut contacts = self.contacts.lock();
2598 for (ix, contact) in contacts.iter_mut().enumerate() {
2599 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2600 if contact.accepted {
2601 Err(anyhow!("contact already confirmed"))?;
2602 }
2603 if accept {
2604 contact.accepted = true;
2605 contact.should_notify = true;
2606 } else {
2607 contacts.remove(ix);
2608 }
2609 return Ok(());
2610 }
2611 }
2612 Err(anyhow!("no such contact request"))?
2613 }
2614
        /// Not supported by the fake database; calling it panics.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
2623
        /// Not supported by the fake database; calling it panics.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
2627
        /// Not supported by the fake database; calling it panics.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
2631
2632 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2633 self.background.simulate_random_delay().await;
2634 let mut orgs = self.orgs.lock();
2635 if orgs.values().any(|org| org.slug == slug) {
2636 Err(anyhow!("org already exists"))?
2637 } else {
2638 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2639 orgs.insert(
2640 org_id,
2641 Org {
2642 id: org_id,
2643 name: name.to_string(),
2644 slug: slug.to_string(),
2645 },
2646 );
2647 Ok(org_id)
2648 }
2649 }
2650
2651 async fn add_org_member(
2652 &self,
2653 org_id: OrgId,
2654 user_id: UserId,
2655 is_admin: bool,
2656 ) -> Result<()> {
2657 self.background.simulate_random_delay().await;
2658 if !self.orgs.lock().contains_key(&org_id) {
2659 Err(anyhow!("org does not exist"))?;
2660 }
2661 if !self.users.lock().contains_key(&user_id) {
2662 Err(anyhow!("user does not exist"))?;
2663 }
2664
2665 self.org_memberships
2666 .lock()
2667 .entry((org_id, user_id))
2668 .or_insert(is_admin);
2669 Ok(())
2670 }
2671
2672 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2673 self.background.simulate_random_delay().await;
2674 if !self.orgs.lock().contains_key(&org_id) {
2675 Err(anyhow!("org does not exist"))?;
2676 }
2677
2678 let mut channels = self.channels.lock();
2679 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2680 channels.insert(
2681 channel_id,
2682 Channel {
2683 id: channel_id,
2684 name: name.to_string(),
2685 owner_id: org_id.0,
2686 owner_is_user: false,
2687 },
2688 );
2689 Ok(channel_id)
2690 }
2691
2692 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2693 self.background.simulate_random_delay().await;
2694 Ok(self
2695 .channels
2696 .lock()
2697 .values()
2698 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2699 .cloned()
2700 .collect())
2701 }
2702
2703 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2704 self.background.simulate_random_delay().await;
2705 let channels = self.channels.lock();
2706 let memberships = self.channel_memberships.lock();
2707 Ok(channels
2708 .values()
2709 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2710 .cloned()
2711 .collect())
2712 }
2713
2714 async fn can_user_access_channel(
2715 &self,
2716 user_id: UserId,
2717 channel_id: ChannelId,
2718 ) -> Result<bool> {
2719 self.background.simulate_random_delay().await;
2720 Ok(self
2721 .channel_memberships
2722 .lock()
2723 .contains_key(&(channel_id, user_id)))
2724 }
2725
2726 async fn add_channel_member(
2727 &self,
2728 channel_id: ChannelId,
2729 user_id: UserId,
2730 is_admin: bool,
2731 ) -> Result<()> {
2732 self.background.simulate_random_delay().await;
2733 if !self.channels.lock().contains_key(&channel_id) {
2734 Err(anyhow!("channel does not exist"))?;
2735 }
2736 if !self.users.lock().contains_key(&user_id) {
2737 Err(anyhow!("user does not exist"))?;
2738 }
2739
2740 self.channel_memberships
2741 .lock()
2742 .entry((channel_id, user_id))
2743 .or_insert(is_admin);
2744 Ok(())
2745 }
2746
2747 async fn create_channel_message(
2748 &self,
2749 channel_id: ChannelId,
2750 sender_id: UserId,
2751 body: &str,
2752 timestamp: OffsetDateTime,
2753 nonce: u128,
2754 ) -> Result<MessageId> {
2755 self.background.simulate_random_delay().await;
2756 if !self.channels.lock().contains_key(&channel_id) {
2757 Err(anyhow!("channel does not exist"))?;
2758 }
2759 if !self.users.lock().contains_key(&sender_id) {
2760 Err(anyhow!("user does not exist"))?;
2761 }
2762
2763 let mut messages = self.channel_messages.lock();
2764 if let Some(message) = messages
2765 .values()
2766 .find(|message| message.nonce.as_u128() == nonce)
2767 {
2768 Ok(message.id)
2769 } else {
2770 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2771 messages.insert(
2772 message_id,
2773 ChannelMessage {
2774 id: message_id,
2775 channel_id,
2776 sender_id,
2777 body: body.to_string(),
2778 sent_at: timestamp,
2779 nonce: Uuid::from_u128(nonce),
2780 },
2781 );
2782 Ok(message_id)
2783 }
2784 }
2785
2786 async fn get_channel_messages(
2787 &self,
2788 channel_id: ChannelId,
2789 count: usize,
2790 before_id: Option<MessageId>,
2791 ) -> Result<Vec<ChannelMessage>> {
2792 let mut messages = self
2793 .channel_messages
2794 .lock()
2795 .values()
2796 .rev()
2797 .filter(|message| {
2798 message.channel_id == channel_id
2799 && message.id < before_id.unwrap_or(MessageId::MAX)
2800 })
2801 .take(count)
2802 .cloned()
2803 .collect::<Vec<_>>();
2804 messages.sort_unstable_by_key(|message| message.id);
2805 Ok(messages)
2806 }
2807
        /// Does nothing for the fake database; the argument is ignored.
        async fn teardown(&self, _: &str) {}
2809
        /// Test-only accessor: the fake database always exposes itself as a
        /// `FakeDb`.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
2814 }
2815}