1use std::{cmp, ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use nanoid::nanoid;
10use serde::{Deserialize, Serialize};
11pub use sqlx::postgres::PgPoolOptions as DbOptions;
12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
13use time::{OffsetDateTime, PrimitiveDateTime};
14
/// Storage interface for the collaboration server.
///
/// [`PostgresDb`] in this module implements it against Postgres. Items gated on
/// `cfg(any(test, feature = "seed-support"))` are only compiled for tests or with the
/// `seed-support` feature; `teardown` and `as_fake` are test-only.
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user and returns its id (`PostgresDb` returns the existing id when
    /// `github_login` is already taken).
    async fn create_user(
        &self,
        github_login: &str,
        email_address: Option<&str>,
        admin: bool,
    ) -> Result<UserId>;
    /// Returns one page of users (`PostgresDb`: ordered by login, skipping `page * limit` rows).
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Bulk-upserts `(github_login, email_address, invite_count)` tuples (field meaning as
    /// used by `PostgresDb`) and returns the ids of the affected users.
    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
    /// Returns at most `limit` users whose GitHub login loosely matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Returns the user with the given id, if any.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose ids appear in `ids`; ids with no matching user are skipped.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns users whose invite count is zero, restricted to users that were (`true`)
    /// or were not (`false`) invited by another user.
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    /// Returns the user with the given GitHub login, if any.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets or clears the user's admin flag.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets the user's `connected_once` flag.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes the user (`PostgresDb` also deletes their access tokens).
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Sets the user's invite count (`PostgresDb` also assigns an invite code when the
    /// count is positive and the user has none).
    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and invite count, or `None` if they have no code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user who owns the invite `code`.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Consumes one invite from the owner of `code`, creates the invited user, and
    /// returns the new user's id.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, u32>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_user_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the number of users who have been active in the given
    /// time period for at least the given time duration.
    async fn get_active_user_count(
        &self,
        time_period: Range<OffsetDateTime>,
        min_duration: Duration,
        only_collaborative: bool,
    ) -> Result<usize>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn get_top_users_activity_summary(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Get the project activity for the given user and time period.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>>;

    /// Returns the user's contacts, including pending requests in either direction.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users are accepted contacts.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Sends a contact request from `requester_id` to `responder_id`.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Removes the contact relationship between the two users.
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Dismisses the contact notification the first user has about the second.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    /// Accepts (`accept == true`) or declines the request from `requester_id` to `responder_id`.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    /// Stores an access token hash for the user; `PostgresDb` then keeps only the
    /// `max_access_token_count` most recent hashes.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the user's access token hashes (`PostgresDb`: newest first).
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    /// Looks up an org by its slug.
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds a user to an org.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Creates a channel owned by an org and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Returns the channels owned by an org.
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns the channels the user is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether the user is a member of the channel.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds a user to a channel.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a channel message and returns its id (`PostgresDb` dedupes by `nonce`).
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` messages of the channel older than `before_id`
    /// (or the most recent ones when `before_id` is `None`).
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only: tears down the database at `url`.
    #[cfg(test)]
    async fn teardown(&self, url: &str);
    /// Test-only: presumably returns `Some` when `self` is the in-memory fake.
    #[cfg(test)]
    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
}
160
161pub struct PostgresDb {
162 pool: sqlx::PgPool,
163}
164
165impl PostgresDb {
166 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
167 let pool = DbOptions::new()
168 .max_connections(max_connections)
169 .connect(&url)
170 .await
171 .context("failed to connect to postgres database")?;
172 Ok(Self { pool })
173 }
174}
175
176#[async_trait]
177impl Db for PostgresDb {
178 // users
179
180 async fn create_user(
181 &self,
182 github_login: &str,
183 email_address: Option<&str>,
184 admin: bool,
185 ) -> Result<UserId> {
186 let query = "
187 INSERT INTO users (github_login, email_address, admin)
188 VALUES ($1, $2, $3)
189 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
190 RETURNING id
191 ";
192 Ok(sqlx::query_scalar(query)
193 .bind(github_login)
194 .bind(email_address)
195 .bind(admin)
196 .fetch_one(&self.pool)
197 .await
198 .map(UserId)?)
199 }
200
201 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
202 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
203 Ok(sqlx::query_as(query)
204 .bind(limit as i32)
205 .bind((page * limit) as i32)
206 .fetch_all(&self.pool)
207 .await?)
208 }
209
210 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
211 let mut query = QueryBuilder::new(
212 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
213 );
214 query.push_values(
215 users,
216 |mut query, (github_login, email_address, invite_count)| {
217 query
218 .push_bind(github_login)
219 .push_bind(email_address)
220 .push_bind(false)
221 .push_bind(nanoid!(16))
222 .push_bind(invite_count as i32);
223 },
224 );
225 query.push(
226 "
227 ON CONFLICT (github_login) DO UPDATE SET
228 github_login = excluded.github_login,
229 invite_count = excluded.invite_count,
230 invite_code = CASE WHEN users.invite_code IS NULL
231 THEN excluded.invite_code
232 ELSE users.invite_code
233 END
234 RETURNING id
235 ",
236 );
237
238 let rows = query.build().fetch_all(&self.pool).await?;
239 Ok(rows
240 .into_iter()
241 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
242 .collect())
243 }
244
245 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
246 let like_string = fuzzy_like_string(name_query);
247 let query = "
248 SELECT users.*
249 FROM users
250 WHERE github_login ILIKE $1
251 ORDER BY github_login <-> $2
252 LIMIT $3
253 ";
254 Ok(sqlx::query_as(query)
255 .bind(like_string)
256 .bind(name_query)
257 .bind(limit as i32)
258 .fetch_all(&self.pool)
259 .await?)
260 }
261
262 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
263 let users = self.get_users_by_ids(vec![id]).await?;
264 Ok(users.into_iter().next())
265 }
266
267 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
268 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
269 let query = "
270 SELECT users.*
271 FROM users
272 WHERE users.id = ANY ($1)
273 ";
274 Ok(sqlx::query_as(query)
275 .bind(&ids)
276 .fetch_all(&self.pool)
277 .await?)
278 }
279
280 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
281 let query = format!(
282 "
283 SELECT users.*
284 FROM users
285 WHERE invite_count = 0
286 AND inviter_id IS{} NULL
287 ",
288 if invited_by_another_user { " NOT" } else { "" }
289 );
290
291 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
292 }
293
294 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
295 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
296 Ok(sqlx::query_as(query)
297 .bind(github_login)
298 .fetch_optional(&self.pool)
299 .await?)
300 }
301
302 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
303 let query = "UPDATE users SET admin = $1 WHERE id = $2";
304 Ok(sqlx::query(query)
305 .bind(is_admin)
306 .bind(id.0)
307 .execute(&self.pool)
308 .await
309 .map(drop)?)
310 }
311
312 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
313 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
314 Ok(sqlx::query(query)
315 .bind(connected_once)
316 .bind(id.0)
317 .execute(&self.pool)
318 .await
319 .map(drop)?)
320 }
321
322 async fn destroy_user(&self, id: UserId) -> Result<()> {
323 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
324 sqlx::query(query)
325 .bind(id.0)
326 .execute(&self.pool)
327 .await
328 .map(drop)?;
329 let query = "DELETE FROM users WHERE id = $1;";
330 Ok(sqlx::query(query)
331 .bind(id.0)
332 .execute(&self.pool)
333 .await
334 .map(drop)?)
335 }
336
337 // invite codes
338
339 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
340 let mut tx = self.pool.begin().await?;
341 if count > 0 {
342 sqlx::query(
343 "
344 UPDATE users
345 SET invite_code = $1
346 WHERE id = $2 AND invite_code IS NULL
347 ",
348 )
349 .bind(nanoid!(16))
350 .bind(id)
351 .execute(&mut tx)
352 .await?;
353 }
354
355 sqlx::query(
356 "
357 UPDATE users
358 SET invite_count = $1
359 WHERE id = $2
360 ",
361 )
362 .bind(count as i32)
363 .bind(id)
364 .execute(&mut tx)
365 .await?;
366 tx.commit().await?;
367 Ok(())
368 }
369
370 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
371 let result: Option<(String, i32)> = sqlx::query_as(
372 "
373 SELECT invite_code, invite_count
374 FROM users
375 WHERE id = $1 AND invite_code IS NOT NULL
376 ",
377 )
378 .bind(id)
379 .fetch_optional(&self.pool)
380 .await?;
381 if let Some((code, count)) = result {
382 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
383 } else {
384 Ok(None)
385 }
386 }
387
388 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
389 sqlx::query_as(
390 "
391 SELECT *
392 FROM users
393 WHERE invite_code = $1
394 ",
395 )
396 .bind(code)
397 .fetch_optional(&self.pool)
398 .await?
399 .ok_or_else(|| {
400 Error::Http(
401 StatusCode::NOT_FOUND,
402 "that invite code does not exist".to_string(),
403 )
404 })
405 }
406
407 async fn redeem_invite_code(
408 &self,
409 code: &str,
410 login: &str,
411 email_address: Option<&str>,
412 ) -> Result<UserId> {
413 let mut tx = self.pool.begin().await?;
414
415 let inviter_id: Option<UserId> = sqlx::query_scalar(
416 "
417 UPDATE users
418 SET invite_count = invite_count - 1
419 WHERE
420 invite_code = $1 AND
421 invite_count > 0
422 RETURNING id
423 ",
424 )
425 .bind(code)
426 .fetch_optional(&mut tx)
427 .await?;
428
429 let inviter_id = match inviter_id {
430 Some(inviter_id) => inviter_id,
431 None => {
432 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
433 .bind(code)
434 .fetch_optional(&mut tx)
435 .await?
436 .is_some()
437 {
438 Err(Error::Http(
439 StatusCode::UNAUTHORIZED,
440 "no invites remaining".to_string(),
441 ))?
442 } else {
443 Err(Error::Http(
444 StatusCode::NOT_FOUND,
445 "invite code not found".to_string(),
446 ))?
447 }
448 }
449 };
450
451 let invitee_id = sqlx::query_scalar(
452 "
453 INSERT INTO users
454 (github_login, email_address, admin, inviter_id)
455 VALUES
456 ($1, $2, 'f', $3)
457 RETURNING id
458 ",
459 )
460 .bind(login)
461 .bind(email_address)
462 .bind(inviter_id)
463 .fetch_one(&mut tx)
464 .await
465 .map(UserId)?;
466
467 sqlx::query(
468 "
469 INSERT INTO contacts
470 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
471 VALUES
472 ($1, $2, 't', 't', 't')
473 ",
474 )
475 .bind(inviter_id)
476 .bind(invitee_id)
477 .execute(&mut tx)
478 .await?;
479
480 tx.commit().await?;
481 Ok(invitee_id)
482 }
483
484 // projects
485
486 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
487 Ok(sqlx::query_scalar(
488 "
489 INSERT INTO projects(host_user_id)
490 VALUES ($1)
491 RETURNING id
492 ",
493 )
494 .bind(host_user_id)
495 .fetch_one(&self.pool)
496 .await
497 .map(ProjectId)?)
498 }
499
500 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
501 sqlx::query(
502 "
503 UPDATE projects
504 SET unregistered = 't'
505 WHERE id = $1
506 ",
507 )
508 .bind(project_id)
509 .execute(&self.pool)
510 .await?;
511 Ok(())
512 }
513
514 async fn update_worktree_extensions(
515 &self,
516 project_id: ProjectId,
517 worktree_id: u64,
518 extensions: HashMap<String, u32>,
519 ) -> Result<()> {
520 if extensions.is_empty() {
521 return Ok(());
522 }
523
524 let mut query = QueryBuilder::new(
525 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
526 );
527 query.push_values(extensions, |mut query, (extension, count)| {
528 query
529 .push_bind(project_id)
530 .push_bind(worktree_id as i32)
531 .push_bind(extension)
532 .push_bind(count as i32);
533 });
534 query.push(
535 "
536 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
537 count = excluded.count
538 ",
539 );
540 query.build().execute(&self.pool).await?;
541
542 Ok(())
543 }
544
545 async fn get_project_extensions(
546 &self,
547 project_id: ProjectId,
548 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
549 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
550 struct WorktreeExtension {
551 worktree_id: i32,
552 extension: String,
553 count: i32,
554 }
555
556 let query = "
557 SELECT worktree_id, extension, count
558 FROM worktree_extensions
559 WHERE project_id = $1
560 ";
561 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
562 .bind(&project_id)
563 .fetch_all(&self.pool)
564 .await?;
565
566 let mut extension_counts = HashMap::default();
567 for count in counts {
568 extension_counts
569 .entry(count.worktree_id as u64)
570 .or_insert(HashMap::default())
571 .insert(count.extension, count.count as usize);
572 }
573 Ok(extension_counts)
574 }
575
576 async fn record_user_activity(
577 &self,
578 time_period: Range<OffsetDateTime>,
579 projects: &[(UserId, ProjectId)],
580 ) -> Result<()> {
581 let query = "
582 INSERT INTO project_activity_periods
583 (ended_at, duration_millis, user_id, project_id)
584 VALUES
585 ($1, $2, $3, $4);
586 ";
587
588 let mut tx = self.pool.begin().await?;
589 let duration_millis =
590 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
591 for (user_id, project_id) in projects {
592 sqlx::query(query)
593 .bind(time_period.end)
594 .bind(duration_millis)
595 .bind(user_id)
596 .bind(project_id)
597 .execute(&mut tx)
598 .await?;
599 }
600 tx.commit().await?;
601
602 Ok(())
603 }
604
605 async fn get_active_user_count(
606 &self,
607 time_period: Range<OffsetDateTime>,
608 min_duration: Duration,
609 only_collaborative: bool,
610 ) -> Result<usize> {
611 let mut with_clause = String::new();
612 with_clause.push_str("WITH\n");
613 with_clause.push_str(
614 "
615 project_durations AS (
616 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
617 FROM project_activity_periods
618 WHERE $1 < ended_at AND ended_at <= $2
619 GROUP BY user_id, project_id
620 ),
621 ",
622 );
623 with_clause.push_str(
624 "
625 project_collaborators as (
626 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
627 FROM project_durations
628 GROUP BY project_id
629 ),
630 ",
631 );
632
633 if only_collaborative {
634 with_clause.push_str(
635 "
636 user_durations AS (
637 SELECT user_id, SUM(project_duration) as total_duration
638 FROM project_durations, project_collaborators
639 WHERE
640 project_durations.project_id = project_collaborators.project_id AND
641 max_collaborators > 1
642 GROUP BY user_id
643 ORDER BY total_duration DESC
644 LIMIT $3
645 )
646 ",
647 );
648 } else {
649 with_clause.push_str(
650 "
651 user_durations AS (
652 SELECT user_id, SUM(project_duration) as total_duration
653 FROM project_durations
654 GROUP BY user_id
655 ORDER BY total_duration DESC
656 LIMIT $3
657 )
658 ",
659 );
660 }
661
662 let query = format!(
663 "
664 {with_clause}
665 SELECT count(user_durations.user_id)
666 FROM user_durations
667 WHERE user_durations.total_duration >= $3
668 "
669 );
670
671 let count: i64 = sqlx::query_scalar(&query)
672 .bind(time_period.start)
673 .bind(time_period.end)
674 .bind(min_duration.as_millis() as i64)
675 .fetch_one(&self.pool)
676 .await?;
677 Ok(count as usize)
678 }
679
680 async fn get_top_users_activity_summary(
681 &self,
682 time_period: Range<OffsetDateTime>,
683 max_user_count: usize,
684 ) -> Result<Vec<UserActivitySummary>> {
685 let query = "
686 WITH
687 project_durations AS (
688 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
689 FROM project_activity_periods
690 WHERE $1 < ended_at AND ended_at <= $2
691 GROUP BY user_id, project_id
692 ),
693 user_durations AS (
694 SELECT user_id, SUM(project_duration) as total_duration
695 FROM project_durations
696 GROUP BY user_id
697 ORDER BY total_duration DESC
698 LIMIT $3
699 ),
700 project_collaborators as (
701 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
702 FROM project_durations
703 GROUP BY project_id
704 )
705 SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
706 FROM user_durations, project_durations, project_collaborators, users
707 WHERE
708 user_durations.user_id = project_durations.user_id AND
709 user_durations.user_id = users.id AND
710 project_durations.project_id = project_collaborators.project_id
711 ORDER BY total_duration DESC, user_id ASC
712 ";
713
714 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
715 .bind(time_period.start)
716 .bind(time_period.end)
717 .bind(max_user_count as i32)
718 .fetch(&self.pool);
719
720 let mut result = Vec::<UserActivitySummary>::new();
721 while let Some(row) = rows.next().await {
722 let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
723 let project_id = project_id;
724 let duration = Duration::from_millis(duration_millis as u64);
725 let project_activity = ProjectActivitySummary {
726 id: project_id,
727 duration,
728 max_collaborators: project_collaborators as usize,
729 };
730 if let Some(last_summary) = result.last_mut() {
731 if last_summary.id == user_id {
732 last_summary.project_activity.push(project_activity);
733 continue;
734 }
735 }
736 result.push(UserActivitySummary {
737 id: user_id,
738 project_activity: vec![project_activity],
739 github_login,
740 });
741 }
742
743 Ok(result)
744 }
745
    /// Returns the user's activity within `time_period` (`start < ended_at <= end`)
    /// as periods sorted by start time across all projects.
    ///
    /// Activity records of the same project are merged into the previous period
    /// when they start no later than `COALESCE_THRESHOLD` after that period ends.
    /// Each period collects the project's per-extension file counts from
    /// `worktree_extensions`.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>> {
        // Largest gap between two records that still counts as one period.
        const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);

        // The LEFT OUTER JOIN yields one row per (activity record, extension); records
        // of projects with no extensions still appear, with NULL extension columns.
        // Ordering by record id keeps the join duplicates of a record adjacent.
        let query = "
            SELECT
                project_activity_periods.ended_at,
                project_activity_periods.duration_millis,
                project_activity_periods.project_id,
                worktree_extensions.extension,
                worktree_extensions.count
            FROM project_activity_periods
            LEFT OUTER JOIN
                worktree_extensions
            ON
                project_activity_periods.project_id = worktree_extensions.project_id
            WHERE
                project_activity_periods.user_id = $1 AND
                $2 < project_activity_periods.ended_at AND
                project_activity_periods.ended_at <= $3
            ORDER BY project_activity_periods.id ASC
        ";

        let mut rows = sqlx::query_as::<
            _,
            (
                PrimitiveDateTime,
                i32,
                ProjectId,
                Option<String>,
                Option<i32>,
            ),
        >(query)
        .bind(user_id)
        .bind(time_period.start)
        .bind(time_period.end)
        .fetch(&self.pool);

        // Periods are built per project while streaming, then flattened and sorted.
        let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
        while let Some(row) = rows.next().await {
            let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
            // `ended_at` is decoded without an offset and interpreted as UTC here.
            let ended_at = ended_at.assume_utc();
            let duration = Duration::from_millis(duration_millis as u64);
            let started_at = ended_at - duration;
            let project_time_periods = time_periods.entry(project_id).or_default();

            if let Some(prev_duration) = project_time_periods.last_mut() {
                // Extend the previous period when this record starts within the
                // threshold of its end and does not end before it started. Join
                // duplicates of the same record always satisfy this and merge here.
                if started_at <= prev_duration.end + COALESCE_THRESHOLD
                    && ended_at >= prev_duration.start
                {
                    prev_duration.end = cmp::max(prev_duration.end, ended_at);
                } else {
                    project_time_periods.push(UserActivityPeriod {
                        project_id,
                        start: started_at,
                        end: ended_at,
                        extensions: Default::default(),
                    });
                }
            } else {
                project_time_periods.push(UserActivityPeriod {
                    project_id,
                    start: started_at,
                    end: ended_at,
                    extensions: Default::default(),
                });
            }

            // Attach this row's extension count (if any) to the period it landed in;
            // the unwrap cannot fail because a period was just extended or pushed.
            if let Some((extension, extension_count)) = extension.zip(extension_count) {
                project_time_periods
                    .last_mut()
                    .unwrap()
                    .extensions
                    .insert(extension, extension_count as usize);
            }
        }

        let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
        durations.sort_unstable_by_key(|duration| duration.start);
        Ok(durations)
    }
830
831 // contacts
832
833 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
834 let query = "
835 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
836 FROM contacts
837 WHERE user_id_a = $1 OR user_id_b = $1;
838 ";
839
840 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
841 .bind(user_id)
842 .fetch(&self.pool);
843
844 let mut contacts = vec![Contact::Accepted {
845 user_id,
846 should_notify: false,
847 }];
848 while let Some(row) = rows.next().await {
849 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
850
851 if user_id_a == user_id {
852 if accepted {
853 contacts.push(Contact::Accepted {
854 user_id: user_id_b,
855 should_notify: should_notify && a_to_b,
856 });
857 } else if a_to_b {
858 contacts.push(Contact::Outgoing { user_id: user_id_b })
859 } else {
860 contacts.push(Contact::Incoming {
861 user_id: user_id_b,
862 should_notify,
863 });
864 }
865 } else {
866 if accepted {
867 contacts.push(Contact::Accepted {
868 user_id: user_id_a,
869 should_notify: should_notify && !a_to_b,
870 });
871 } else if a_to_b {
872 contacts.push(Contact::Incoming {
873 user_id: user_id_a,
874 should_notify,
875 });
876 } else {
877 contacts.push(Contact::Outgoing { user_id: user_id_a });
878 }
879 }
880 }
881
882 contacts.sort_unstable_by_key(|contact| contact.user_id());
883
884 Ok(contacts)
885 }
886
887 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
888 let (id_a, id_b) = if user_id_1 < user_id_2 {
889 (user_id_1, user_id_2)
890 } else {
891 (user_id_2, user_id_1)
892 };
893
894 let query = "
895 SELECT 1 FROM contacts
896 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
897 LIMIT 1
898 ";
899 Ok(sqlx::query_scalar::<_, i32>(query)
900 .bind(id_a.0)
901 .bind(id_b.0)
902 .fetch_optional(&self.pool)
903 .await?
904 .is_some())
905 }
906
907 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
908 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
909 (sender_id, receiver_id, true)
910 } else {
911 (receiver_id, sender_id, false)
912 };
913 let query = "
914 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
915 VALUES ($1, $2, $3, 'f', 't')
916 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
917 SET
918 accepted = 't',
919 should_notify = 'f'
920 WHERE
921 NOT contacts.accepted AND
922 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
923 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
924 ";
925 let result = sqlx::query(query)
926 .bind(id_a.0)
927 .bind(id_b.0)
928 .bind(a_to_b)
929 .execute(&self.pool)
930 .await?;
931
932 if result.rows_affected() == 1 {
933 Ok(())
934 } else {
935 Err(anyhow!("contact already requested"))?
936 }
937 }
938
939 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
940 let (id_a, id_b) = if responder_id < requester_id {
941 (responder_id, requester_id)
942 } else {
943 (requester_id, responder_id)
944 };
945 let query = "
946 DELETE FROM contacts
947 WHERE user_id_a = $1 AND user_id_b = $2;
948 ";
949 let result = sqlx::query(query)
950 .bind(id_a.0)
951 .bind(id_b.0)
952 .execute(&self.pool)
953 .await?;
954
955 if result.rows_affected() == 1 {
956 Ok(())
957 } else {
958 Err(anyhow!("no such contact"))?
959 }
960 }
961
962 async fn dismiss_contact_notification(
963 &self,
964 user_id: UserId,
965 contact_user_id: UserId,
966 ) -> Result<()> {
967 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
968 (user_id, contact_user_id, true)
969 } else {
970 (contact_user_id, user_id, false)
971 };
972
973 let query = "
974 UPDATE contacts
975 SET should_notify = 'f'
976 WHERE
977 user_id_a = $1 AND user_id_b = $2 AND
978 (
979 (a_to_b = $3 AND accepted) OR
980 (a_to_b != $3 AND NOT accepted)
981 );
982 ";
983
984 let result = sqlx::query(query)
985 .bind(id_a.0)
986 .bind(id_b.0)
987 .bind(a_to_b)
988 .execute(&self.pool)
989 .await?;
990
991 if result.rows_affected() == 0 {
992 Err(anyhow!("no such contact request"))?;
993 }
994
995 Ok(())
996 }
997
998 async fn respond_to_contact_request(
999 &self,
1000 responder_id: UserId,
1001 requester_id: UserId,
1002 accept: bool,
1003 ) -> Result<()> {
1004 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1005 (responder_id, requester_id, false)
1006 } else {
1007 (requester_id, responder_id, true)
1008 };
1009 let result = if accept {
1010 let query = "
1011 UPDATE contacts
1012 SET accepted = 't', should_notify = 't'
1013 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1014 ";
1015 sqlx::query(query)
1016 .bind(id_a.0)
1017 .bind(id_b.0)
1018 .bind(a_to_b)
1019 .execute(&self.pool)
1020 .await?
1021 } else {
1022 let query = "
1023 DELETE FROM contacts
1024 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1025 ";
1026 sqlx::query(query)
1027 .bind(id_a.0)
1028 .bind(id_b.0)
1029 .bind(a_to_b)
1030 .execute(&self.pool)
1031 .await?
1032 };
1033 if result.rows_affected() == 1 {
1034 Ok(())
1035 } else {
1036 Err(anyhow!("no such contact request"))?
1037 }
1038 }
1039
1040 // access tokens
1041
1042 async fn create_access_token_hash(
1043 &self,
1044 user_id: UserId,
1045 access_token_hash: &str,
1046 max_access_token_count: usize,
1047 ) -> Result<()> {
1048 let insert_query = "
1049 INSERT INTO access_tokens (user_id, hash)
1050 VALUES ($1, $2);
1051 ";
1052 let cleanup_query = "
1053 DELETE FROM access_tokens
1054 WHERE id IN (
1055 SELECT id from access_tokens
1056 WHERE user_id = $1
1057 ORDER BY id DESC
1058 OFFSET $3
1059 )
1060 ";
1061
1062 let mut tx = self.pool.begin().await?;
1063 sqlx::query(insert_query)
1064 .bind(user_id.0)
1065 .bind(access_token_hash)
1066 .execute(&mut tx)
1067 .await?;
1068 sqlx::query(cleanup_query)
1069 .bind(user_id.0)
1070 .bind(access_token_hash)
1071 .bind(max_access_token_count as i32)
1072 .execute(&mut tx)
1073 .await?;
1074 Ok(tx.commit().await?)
1075 }
1076
1077 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1078 let query = "
1079 SELECT hash
1080 FROM access_tokens
1081 WHERE user_id = $1
1082 ORDER BY id DESC
1083 ";
1084 Ok(sqlx::query_scalar(query)
1085 .bind(user_id.0)
1086 .fetch_all(&self.pool)
1087 .await?)
1088 }
1089
1090 // orgs
1091
1092 #[allow(unused)] // Help rust-analyzer
1093 #[cfg(any(test, feature = "seed-support"))]
1094 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1095 let query = "
1096 SELECT *
1097 FROM orgs
1098 WHERE slug = $1
1099 ";
1100 Ok(sqlx::query_as(query)
1101 .bind(slug)
1102 .fetch_optional(&self.pool)
1103 .await?)
1104 }
1105
1106 #[cfg(any(test, feature = "seed-support"))]
1107 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1108 let query = "
1109 INSERT INTO orgs (name, slug)
1110 VALUES ($1, $2)
1111 RETURNING id
1112 ";
1113 Ok(sqlx::query_scalar(query)
1114 .bind(name)
1115 .bind(slug)
1116 .fetch_one(&self.pool)
1117 .await
1118 .map(OrgId)?)
1119 }
1120
1121 #[cfg(any(test, feature = "seed-support"))]
1122 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1123 let query = "
1124 INSERT INTO org_memberships (org_id, user_id, admin)
1125 VALUES ($1, $2, $3)
1126 ON CONFLICT DO NOTHING
1127 ";
1128 Ok(sqlx::query(query)
1129 .bind(org_id.0)
1130 .bind(user_id.0)
1131 .bind(is_admin)
1132 .execute(&self.pool)
1133 .await
1134 .map(drop)?)
1135 }
1136
1137 // channels
1138
1139 #[cfg(any(test, feature = "seed-support"))]
1140 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1141 let query = "
1142 INSERT INTO channels (owner_id, owner_is_user, name)
1143 VALUES ($1, false, $2)
1144 RETURNING id
1145 ";
1146 Ok(sqlx::query_scalar(query)
1147 .bind(org_id.0)
1148 .bind(name)
1149 .fetch_one(&self.pool)
1150 .await
1151 .map(ChannelId)?)
1152 }
1153
1154 #[allow(unused)] // Help rust-analyzer
1155 #[cfg(any(test, feature = "seed-support"))]
1156 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1157 let query = "
1158 SELECT *
1159 FROM channels
1160 WHERE
1161 channels.owner_is_user = false AND
1162 channels.owner_id = $1
1163 ";
1164 Ok(sqlx::query_as(query)
1165 .bind(org_id.0)
1166 .fetch_all(&self.pool)
1167 .await?)
1168 }
1169
1170 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1171 let query = "
1172 SELECT
1173 channels.*
1174 FROM
1175 channel_memberships, channels
1176 WHERE
1177 channel_memberships.user_id = $1 AND
1178 channel_memberships.channel_id = channels.id
1179 ";
1180 Ok(sqlx::query_as(query)
1181 .bind(user_id.0)
1182 .fetch_all(&self.pool)
1183 .await?)
1184 }
1185
1186 async fn can_user_access_channel(
1187 &self,
1188 user_id: UserId,
1189 channel_id: ChannelId,
1190 ) -> Result<bool> {
1191 let query = "
1192 SELECT id
1193 FROM channel_memberships
1194 WHERE user_id = $1 AND channel_id = $2
1195 LIMIT 1
1196 ";
1197 Ok(sqlx::query_scalar::<_, i32>(query)
1198 .bind(user_id.0)
1199 .bind(channel_id.0)
1200 .fetch_optional(&self.pool)
1201 .await
1202 .map(|e| e.is_some())?)
1203 }
1204
1205 #[cfg(any(test, feature = "seed-support"))]
1206 async fn add_channel_member(
1207 &self,
1208 channel_id: ChannelId,
1209 user_id: UserId,
1210 is_admin: bool,
1211 ) -> Result<()> {
1212 let query = "
1213 INSERT INTO channel_memberships (channel_id, user_id, admin)
1214 VALUES ($1, $2, $3)
1215 ON CONFLICT DO NOTHING
1216 ";
1217 Ok(sqlx::query(query)
1218 .bind(channel_id.0)
1219 .bind(user_id.0)
1220 .bind(is_admin)
1221 .execute(&self.pool)
1222 .await
1223 .map(drop)?)
1224 }
1225
1226 // messages
1227
1228 async fn create_channel_message(
1229 &self,
1230 channel_id: ChannelId,
1231 sender_id: UserId,
1232 body: &str,
1233 timestamp: OffsetDateTime,
1234 nonce: u128,
1235 ) -> Result<MessageId> {
1236 let query = "
1237 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1238 VALUES ($1, $2, $3, $4, $5)
1239 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1240 RETURNING id
1241 ";
1242 Ok(sqlx::query_scalar(query)
1243 .bind(channel_id.0)
1244 .bind(sender_id.0)
1245 .bind(body)
1246 .bind(timestamp)
1247 .bind(Uuid::from_u128(nonce))
1248 .fetch_one(&self.pool)
1249 .await
1250 .map(MessageId)?)
1251 }
1252
1253 async fn get_channel_messages(
1254 &self,
1255 channel_id: ChannelId,
1256 count: usize,
1257 before_id: Option<MessageId>,
1258 ) -> Result<Vec<ChannelMessage>> {
1259 let query = r#"
1260 SELECT * FROM (
1261 SELECT
1262 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1263 FROM
1264 channel_messages
1265 WHERE
1266 channel_id = $1 AND
1267 id < $2
1268 ORDER BY id DESC
1269 LIMIT $3
1270 ) as recent_messages
1271 ORDER BY id ASC
1272 "#;
1273 Ok(sqlx::query_as(query)
1274 .bind(channel_id.0)
1275 .bind(before_id.unwrap_or(MessageId::MAX))
1276 .bind(count as i64)
1277 .fetch_all(&self.pool)
1278 .await?)
1279 }
1280
1281 #[cfg(test)]
1282 async fn teardown(&self, url: &str) {
1283 use util::ResultExt;
1284
1285 let query = "
1286 SELECT pg_terminate_backend(pg_stat_activity.pid)
1287 FROM pg_stat_activity
1288 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1289 ";
1290 sqlx::query(query).execute(&self.pool).await.log_err();
1291 self.pool.close().await;
1292 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1293 .await
1294 .log_err();
1295 }
1296
1297 #[cfg(test)]
1298 fn as_fake(&self) -> Option<&tests::FakeDb> {
1299 None
1300 }
1301}
1302
/// Declares `$name` as a newtype id wrapping an `i32`.
///
/// The generated type is `sqlx`- and `serde`-transparent, so it encodes exactly like the
/// inner `i32`. It displays as the inner number (honouring formatter flags). It also
/// provides `MAX` plus lossy `as` conversions to and from `u64` (`from_proto` / `to_proto`).
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash,
            sqlx::Type, Serialize, Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from a `u64`. Values outside the `i32` range are truncated by `as`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Widens the id to `u64`. A negative inner value sign-extends via `as`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let Self(raw) = *self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1345
id_type!(UserId);
/// A user record, decoded from a database row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub login the account was created with.
    pub github_login: String,
    pub email_address: Option<String>,
    /// Administrator flag (see `Db::set_user_is_admin`).
    pub admin: bool,
    /// The user's invite code, if one has been assigned.
    pub invite_code: Option<String>,
    /// Invite count for `invite_code`; presumably the number of redemptions left — confirm.
    pub invite_count: i32,
    /// Flag maintained by `Db::set_user_connected_once`.
    pub connected_once: bool,
}
1357
id_type!(ProjectId);
/// A project record, decoded from a database row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user who registered the project (see `Db::register_project`).
    pub host_user_id: UserId,
    /// Set once the project is unregistered; presumably the row is kept rather than
    /// deleted — confirm.
    pub unregistered: bool,
}
1365
/// A user's activity over a time range, broken down per project
/// (as returned by `get_top_users_activity_summary`).
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// One entry per project the user was active in during the range.
    pub project_activity: Vec<ProjectActivitySummary>,
}
1372
/// One user's activity in a single project over the summarized range.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    id: ProjectId,
    /// Total time the user was recorded as active in the project.
    duration: Duration,
    /// Largest number of users recorded in the project at the same time.
    max_collaborators: usize,
}
1379
/// A contiguous span of a user's activity in one project
/// (an entry of `get_user_activity_timeline`).
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    project_id: ProjectId,
    // Timestamps serialize as ISO 8601 strings.
    #[serde(with = "time::serde::iso8601")]
    start: OffsetDateTime,
    #[serde(with = "time::serde::iso8601")]
    end: OffsetDateTime,
    /// File-extension counts recorded for the project's worktrees (e.g. `"rs" -> 5`).
    extensions: HashMap<String, usize>,
}
1389
id_type!(OrgId);
/// An organization record, decoded from a database row via `FromRow`.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
1397
id_type!(ChannelId);
/// A chat channel, decoded from a `channels` row via `FromRow`.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Owner id. When `owner_is_user` is false this is an org id (see `create_org_channel`
    /// and `get_org_channels`); presumably it is a user id otherwise.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1406
id_type!(MessageId);
/// A message posted to a channel, decoded from a `channel_messages` row via `FromRow`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; `get_channel_messages` selects it `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Deduplication nonce passed to `create_channel_message` (a `u128` stored as a UUID).
    pub nonce: Uuid,
}
1417
/// One entry in a user's contact list, seen from that user's side.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// A mutual contact. `should_notify` stays true until the user dismisses the
    /// notification that `user_id` accepted their request.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A pending request the user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request received from `user_id`; `should_notify` stays true until the
    /// user dismisses it.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1432
1433impl Contact {
1434 pub fn user_id(&self) -> UserId {
1435 match self {
1436 Contact::Accepted { user_id, .. } => *user_id,
1437 Contact::Outgoing { user_id } => *user_id,
1438 Contact::Incoming { user_id, .. } => *user_id,
1439 }
1440 }
1441}
1442
/// A contact request received from `requester_id`.
///
/// The derived ordering compares `requester_id` first, then `should_notify`, so the
/// field order is significant.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1448
/// Builds a SQL `LIKE` pattern that matches any string containing the alphanumeric
/// characters of `string`, in order, with anything in between.
///
/// Non-alphanumeric characters (including the `LIKE` metacharacters `%` and `_`) are
/// dropped, so `"x y"` becomes `"%x%y%"`. An input with no alphanumerics yields `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|ch| ch.is_alphanumeric())
            .flat_map(|ch| ['%', ch]),
    );
    pattern.push('%');
    pattern
}
1460
1461#[cfg(test)]
1462pub mod tests {
1463 use super::*;
1464 use anyhow::anyhow;
1465 use collections::BTreeMap;
1466 use gpui::executor::{Background, Deterministic};
1467 use lazy_static::lazy_static;
1468 use parking_lot::Mutex;
1469 use rand::prelude::*;
1470 use sqlx::{
1471 migrate::{MigrateDatabase, Migrator},
1472 Postgres,
1473 };
1474 use std::{path::Path, sync::Arc};
1475 use util::post_inc;
1476
1477 #[tokio::test(flavor = "multi_thread")]
1478 async fn test_get_users_by_ids() {
1479 for test_db in [
1480 TestDb::postgres().await,
1481 TestDb::fake(build_background_executor()),
1482 ] {
1483 let db = test_db.db();
1484
1485 let user = db.create_user("user", None, false).await.unwrap();
1486 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1487 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1488 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1489
1490 assert_eq!(
1491 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1492 .await
1493 .unwrap(),
1494 vec![
1495 User {
1496 id: user,
1497 github_login: "user".to_string(),
1498 admin: false,
1499 ..Default::default()
1500 },
1501 User {
1502 id: friend1,
1503 github_login: "friend-1".to_string(),
1504 admin: false,
1505 ..Default::default()
1506 },
1507 User {
1508 id: friend2,
1509 github_login: "friend-2".to_string(),
1510 admin: false,
1511 ..Default::default()
1512 },
1513 User {
1514 id: friend3,
1515 github_login: "friend-3".to_string(),
1516 admin: false,
1517 ..Default::default()
1518 }
1519 ]
1520 );
1521 }
1522 }
1523
1524 #[tokio::test(flavor = "multi_thread")]
1525 async fn test_create_users() {
1526 let db = TestDb::postgres().await;
1527 let db = db.db();
1528
1529 // Create the first batch of users, ensuring invite counts are assigned
1530 // correctly and the respective invite codes are unique.
1531 let user_ids_batch_1 = db
1532 .create_users(vec![
1533 ("user1".to_string(), "hi@user1.com".to_string(), 5),
1534 ("user2".to_string(), "hi@user2.com".to_string(), 4),
1535 ("user3".to_string(), "hi@user3.com".to_string(), 3),
1536 ])
1537 .await
1538 .unwrap();
1539 assert_eq!(user_ids_batch_1.len(), 3);
1540
1541 let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1542 assert_eq!(users.len(), 3);
1543 assert_eq!(users[0].github_login, "user1");
1544 assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1545 assert_eq!(users[0].invite_count, 5);
1546 assert_eq!(users[1].github_login, "user2");
1547 assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1548 assert_eq!(users[1].invite_count, 4);
1549 assert_eq!(users[2].github_login, "user3");
1550 assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1551 assert_eq!(users[2].invite_count, 3);
1552
1553 let invite_code_1 = users[0].invite_code.clone().unwrap();
1554 let invite_code_2 = users[1].invite_code.clone().unwrap();
1555 let invite_code_3 = users[2].invite_code.clone().unwrap();
1556 assert_ne!(invite_code_1, invite_code_2);
1557 assert_ne!(invite_code_1, invite_code_3);
1558 assert_ne!(invite_code_2, invite_code_3);
1559
1560 // Create the second batch of users and include a user that is already in the database, ensuring
1561 // the invite count for the existing user is updated without changing their invite code.
1562 let user_ids_batch_2 = db
1563 .create_users(vec![
1564 ("user2".to_string(), "hi@user2.com".to_string(), 10),
1565 ("user4".to_string(), "hi@user4.com".to_string(), 2),
1566 ])
1567 .await
1568 .unwrap();
1569 assert_eq!(user_ids_batch_2.len(), 2);
1570 assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1571
1572 let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1573 assert_eq!(users.len(), 2);
1574 assert_eq!(users[0].github_login, "user2");
1575 assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1576 assert_eq!(users[0].invite_count, 10);
1577 assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1578 assert_eq!(users[1].github_login, "user4");
1579 assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1580 assert_eq!(users[1].invite_count, 2);
1581
1582 let invite_code_4 = users[1].invite_code.clone().unwrap();
1583 assert_ne!(invite_code_4, invite_code_1);
1584 assert_ne!(invite_code_4, invite_code_2);
1585 assert_ne!(invite_code_4, invite_code_3);
1586 }
1587
1588 #[tokio::test(flavor = "multi_thread")]
1589 async fn test_worktree_extensions() {
1590 let test_db = TestDb::postgres().await;
1591 let db = test_db.db();
1592
1593 let user = db.create_user("user_1", None, false).await.unwrap();
1594 let project = db.register_project(user).await.unwrap();
1595
1596 db.update_worktree_extensions(project, 100, Default::default())
1597 .await
1598 .unwrap();
1599 db.update_worktree_extensions(
1600 project,
1601 100,
1602 [("rs".to_string(), 5), ("md".to_string(), 3)]
1603 .into_iter()
1604 .collect(),
1605 )
1606 .await
1607 .unwrap();
1608 db.update_worktree_extensions(
1609 project,
1610 100,
1611 [("rs".to_string(), 6), ("md".to_string(), 5)]
1612 .into_iter()
1613 .collect(),
1614 )
1615 .await
1616 .unwrap();
1617 db.update_worktree_extensions(
1618 project,
1619 101,
1620 [("ts".to_string(), 2), ("md".to_string(), 1)]
1621 .into_iter()
1622 .collect(),
1623 )
1624 .await
1625 .unwrap();
1626
1627 assert_eq!(
1628 db.get_project_extensions(project).await.unwrap(),
1629 [
1630 (
1631 100,
1632 [("rs".into(), 6), ("md".into(), 5),]
1633 .into_iter()
1634 .collect::<HashMap<_, _>>()
1635 ),
1636 (
1637 101,
1638 [("ts".into(), 2), ("md".into(), 1),]
1639 .into_iter()
1640 .collect::<HashMap<_, _>>()
1641 )
1642 ]
1643 .into_iter()
1644 .collect()
1645 );
1646 }
1647
1648 #[tokio::test(flavor = "multi_thread")]
1649 async fn test_user_activity() {
1650 let test_db = TestDb::postgres().await;
1651 let db = test_db.db();
1652
1653 let user_1 = db.create_user("user_1", None, false).await.unwrap();
1654 let user_2 = db.create_user("user_2", None, false).await.unwrap();
1655 let user_3 = db.create_user("user_3", None, false).await.unwrap();
1656 let project_1 = db.register_project(user_1).await.unwrap();
1657 db.update_worktree_extensions(
1658 project_1,
1659 1,
1660 HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
1661 )
1662 .await
1663 .unwrap();
1664 let project_2 = db.register_project(user_2).await.unwrap();
1665 let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
1666
1667 // User 2 opens a project
1668 let t1 = t0 + Duration::from_secs(10);
1669 db.record_user_activity(t0..t1, &[(user_2, project_2)])
1670 .await
1671 .unwrap();
1672
1673 let t2 = t1 + Duration::from_secs(10);
1674 db.record_user_activity(t1..t2, &[(user_2, project_2)])
1675 .await
1676 .unwrap();
1677
1678 // User 1 joins the project
1679 let t3 = t2 + Duration::from_secs(10);
1680 db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
1681 .await
1682 .unwrap();
1683
1684 // User 1 opens another project
1685 let t4 = t3 + Duration::from_secs(10);
1686 db.record_user_activity(
1687 t3..t4,
1688 &[
1689 (user_2, project_2),
1690 (user_1, project_2),
1691 (user_1, project_1),
1692 ],
1693 )
1694 .await
1695 .unwrap();
1696
1697 // User 3 joins that project
1698 let t5 = t4 + Duration::from_secs(10);
1699 db.record_user_activity(
1700 t4..t5,
1701 &[
1702 (user_2, project_2),
1703 (user_1, project_2),
1704 (user_1, project_1),
1705 (user_3, project_1),
1706 ],
1707 )
1708 .await
1709 .unwrap();
1710
1711 // User 2 leaves
1712 let t6 = t5 + Duration::from_secs(5);
1713 db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
1714 .await
1715 .unwrap();
1716
1717 let t7 = t6 + Duration::from_secs(60);
1718 let t8 = t7 + Duration::from_secs(10);
1719 db.record_user_activity(t7..t8, &[(user_1, project_1)])
1720 .await
1721 .unwrap();
1722
1723 assert_eq!(
1724 db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
1725 &[
1726 UserActivitySummary {
1727 id: user_1,
1728 github_login: "user_1".to_string(),
1729 project_activity: vec![
1730 ProjectActivitySummary {
1731 id: project_1,
1732 duration: Duration::from_secs(25),
1733 max_collaborators: 2
1734 },
1735 ProjectActivitySummary {
1736 id: project_2,
1737 duration: Duration::from_secs(30),
1738 max_collaborators: 2
1739 }
1740 ]
1741 },
1742 UserActivitySummary {
1743 id: user_2,
1744 github_login: "user_2".to_string(),
1745 project_activity: vec![ProjectActivitySummary {
1746 id: project_2,
1747 duration: Duration::from_secs(50),
1748 max_collaborators: 2
1749 }]
1750 },
1751 UserActivitySummary {
1752 id: user_3,
1753 github_login: "user_3".to_string(),
1754 project_activity: vec![ProjectActivitySummary {
1755 id: project_1,
1756 duration: Duration::from_secs(15),
1757 max_collaborators: 2
1758 }]
1759 },
1760 ]
1761 );
1762
1763 assert_eq!(
1764 db.get_active_user_count(t0..t6, Duration::from_secs(56), false)
1765 .await
1766 .unwrap(),
1767 0
1768 );
1769 assert_eq!(
1770 db.get_active_user_count(t0..t6, Duration::from_secs(56), true)
1771 .await
1772 .unwrap(),
1773 0
1774 );
1775 assert_eq!(
1776 db.get_active_user_count(t0..t6, Duration::from_secs(54), false)
1777 .await
1778 .unwrap(),
1779 1
1780 );
1781 assert_eq!(
1782 db.get_active_user_count(t0..t6, Duration::from_secs(54), true)
1783 .await
1784 .unwrap(),
1785 1
1786 );
1787 assert_eq!(
1788 db.get_active_user_count(t0..t6, Duration::from_secs(30), false)
1789 .await
1790 .unwrap(),
1791 2
1792 );
1793 assert_eq!(
1794 db.get_active_user_count(t0..t6, Duration::from_secs(30), true)
1795 .await
1796 .unwrap(),
1797 2
1798 );
1799 assert_eq!(
1800 db.get_active_user_count(t0..t6, Duration::from_secs(10), false)
1801 .await
1802 .unwrap(),
1803 3
1804 );
1805 assert_eq!(
1806 db.get_active_user_count(t0..t6, Duration::from_secs(10), true)
1807 .await
1808 .unwrap(),
1809 3
1810 );
1811 assert_eq!(
1812 db.get_active_user_count(t0..t1, Duration::from_secs(5), false)
1813 .await
1814 .unwrap(),
1815 1
1816 );
1817 assert_eq!(
1818 db.get_active_user_count(t0..t1, Duration::from_secs(5), true)
1819 .await
1820 .unwrap(),
1821 0
1822 );
1823
1824 assert_eq!(
1825 db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
1826 &[
1827 UserActivityPeriod {
1828 project_id: project_1,
1829 start: t3,
1830 end: t6,
1831 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1832 },
1833 UserActivityPeriod {
1834 project_id: project_2,
1835 start: t3,
1836 end: t5,
1837 extensions: Default::default(),
1838 },
1839 ]
1840 );
1841 assert_eq!(
1842 db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
1843 &[
1844 UserActivityPeriod {
1845 project_id: project_2,
1846 start: t2,
1847 end: t5,
1848 extensions: Default::default(),
1849 },
1850 UserActivityPeriod {
1851 project_id: project_1,
1852 start: t3,
1853 end: t6,
1854 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1855 },
1856 UserActivityPeriod {
1857 project_id: project_1,
1858 start: t7,
1859 end: t8,
1860 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1861 },
1862 ]
1863 );
1864 }
1865
1866 #[tokio::test(flavor = "multi_thread")]
1867 async fn test_recent_channel_messages() {
1868 for test_db in [
1869 TestDb::postgres().await,
1870 TestDb::fake(build_background_executor()),
1871 ] {
1872 let db = test_db.db();
1873 let user = db.create_user("user", None, false).await.unwrap();
1874 let org = db.create_org("org", "org").await.unwrap();
1875 let channel = db.create_org_channel(org, "channel").await.unwrap();
1876 for i in 0..10 {
1877 db.create_channel_message(
1878 channel,
1879 user,
1880 &i.to_string(),
1881 OffsetDateTime::now_utc(),
1882 i,
1883 )
1884 .await
1885 .unwrap();
1886 }
1887
1888 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1889 assert_eq!(
1890 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1891 ["5", "6", "7", "8", "9"]
1892 );
1893
1894 let prev_messages = db
1895 .get_channel_messages(channel, 4, Some(messages[0].id))
1896 .await
1897 .unwrap();
1898 assert_eq!(
1899 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1900 ["1", "2", "3", "4"]
1901 );
1902 }
1903 }
1904
1905 #[tokio::test(flavor = "multi_thread")]
1906 async fn test_channel_message_nonces() {
1907 for test_db in [
1908 TestDb::postgres().await,
1909 TestDb::fake(build_background_executor()),
1910 ] {
1911 let db = test_db.db();
1912 let user = db.create_user("user", None, false).await.unwrap();
1913 let org = db.create_org("org", "org").await.unwrap();
1914 let channel = db.create_org_channel(org, "channel").await.unwrap();
1915
1916 let msg1_id = db
1917 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1918 .await
1919 .unwrap();
1920 let msg2_id = db
1921 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1922 .await
1923 .unwrap();
1924 let msg3_id = db
1925 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1926 .await
1927 .unwrap();
1928 let msg4_id = db
1929 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1930 .await
1931 .unwrap();
1932
1933 assert_ne!(msg1_id, msg2_id);
1934 assert_eq!(msg1_id, msg3_id);
1935 assert_eq!(msg2_id, msg4_id);
1936 }
1937 }
1938
1939 #[tokio::test(flavor = "multi_thread")]
1940 async fn test_create_access_tokens() {
1941 let test_db = TestDb::postgres().await;
1942 let db = test_db.db();
1943 let user = db.create_user("the-user", None, false).await.unwrap();
1944
1945 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1946 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1947 assert_eq!(
1948 db.get_access_token_hashes(user).await.unwrap(),
1949 &["h2".to_string(), "h1".to_string()]
1950 );
1951
1952 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1953 assert_eq!(
1954 db.get_access_token_hashes(user).await.unwrap(),
1955 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1956 );
1957
1958 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1959 assert_eq!(
1960 db.get_access_token_hashes(user).await.unwrap(),
1961 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1962 );
1963
1964 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1965 assert_eq!(
1966 db.get_access_token_hashes(user).await.unwrap(),
1967 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1968 );
1969 }
1970
1971 #[test]
1972 fn test_fuzzy_like_string() {
1973 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1974 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1975 assert_eq!(fuzzy_like_string(" z "), "%z%");
1976 }
1977
1978 #[tokio::test(flavor = "multi_thread")]
1979 async fn test_fuzzy_search_users() {
1980 let test_db = TestDb::postgres().await;
1981 let db = test_db.db();
1982 for github_login in [
1983 "California",
1984 "colorado",
1985 "oregon",
1986 "washington",
1987 "florida",
1988 "delaware",
1989 "rhode-island",
1990 ] {
1991 db.create_user(github_login, None, false).await.unwrap();
1992 }
1993
1994 assert_eq!(
1995 fuzzy_search_user_names(db, "clr").await,
1996 &["colorado", "California"]
1997 );
1998 assert_eq!(
1999 fuzzy_search_user_names(db, "ro").await,
2000 &["rhode-island", "colorado", "oregon"],
2001 );
2002
2003 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
2004 db.fuzzy_search_users(query, 10)
2005 .await
2006 .unwrap()
2007 .into_iter()
2008 .map(|user| user.github_login)
2009 .collect::<Vec<_>>()
2010 }
2011 }
2012
2013 #[tokio::test(flavor = "multi_thread")]
2014 async fn test_add_contacts() {
2015 for test_db in [
2016 TestDb::postgres().await,
2017 TestDb::fake(build_background_executor()),
2018 ] {
2019 let db = test_db.db();
2020
2021 let user_1 = db.create_user("user1", None, false).await.unwrap();
2022 let user_2 = db.create_user("user2", None, false).await.unwrap();
2023 let user_3 = db.create_user("user3", None, false).await.unwrap();
2024
2025 // User starts with no contacts
2026 assert_eq!(
2027 db.get_contacts(user_1).await.unwrap(),
2028 vec![Contact::Accepted {
2029 user_id: user_1,
2030 should_notify: false
2031 }],
2032 );
2033
2034 // User requests a contact. Both users see the pending request.
2035 db.send_contact_request(user_1, user_2).await.unwrap();
2036 assert!(!db.has_contact(user_1, user_2).await.unwrap());
2037 assert!(!db.has_contact(user_2, user_1).await.unwrap());
2038 assert_eq!(
2039 db.get_contacts(user_1).await.unwrap(),
2040 &[
2041 Contact::Accepted {
2042 user_id: user_1,
2043 should_notify: false
2044 },
2045 Contact::Outgoing { user_id: user_2 }
2046 ],
2047 );
2048 assert_eq!(
2049 db.get_contacts(user_2).await.unwrap(),
2050 &[
2051 Contact::Incoming {
2052 user_id: user_1,
2053 should_notify: true
2054 },
2055 Contact::Accepted {
2056 user_id: user_2,
2057 should_notify: false
2058 },
2059 ]
2060 );
2061
2062 // User 2 dismisses the contact request notification without accepting or rejecting.
2063 // We shouldn't notify them again.
2064 db.dismiss_contact_notification(user_1, user_2)
2065 .await
2066 .unwrap_err();
2067 db.dismiss_contact_notification(user_2, user_1)
2068 .await
2069 .unwrap();
2070 assert_eq!(
2071 db.get_contacts(user_2).await.unwrap(),
2072 &[
2073 Contact::Incoming {
2074 user_id: user_1,
2075 should_notify: false
2076 },
2077 Contact::Accepted {
2078 user_id: user_2,
2079 should_notify: false
2080 },
2081 ]
2082 );
2083
2084 // User can't accept their own contact request
2085 db.respond_to_contact_request(user_1, user_2, true)
2086 .await
2087 .unwrap_err();
2088
2089 // User accepts a contact request. Both users see the contact.
2090 db.respond_to_contact_request(user_2, user_1, true)
2091 .await
2092 .unwrap();
2093 assert_eq!(
2094 db.get_contacts(user_1).await.unwrap(),
2095 &[
2096 Contact::Accepted {
2097 user_id: user_1,
2098 should_notify: false
2099 },
2100 Contact::Accepted {
2101 user_id: user_2,
2102 should_notify: true
2103 }
2104 ],
2105 );
2106 assert!(db.has_contact(user_1, user_2).await.unwrap());
2107 assert!(db.has_contact(user_2, user_1).await.unwrap());
2108 assert_eq!(
2109 db.get_contacts(user_2).await.unwrap(),
2110 &[
2111 Contact::Accepted {
2112 user_id: user_1,
2113 should_notify: false,
2114 },
2115 Contact::Accepted {
2116 user_id: user_2,
2117 should_notify: false,
2118 },
2119 ]
2120 );
2121
2122 // Users cannot re-request existing contacts.
2123 db.send_contact_request(user_1, user_2).await.unwrap_err();
2124 db.send_contact_request(user_2, user_1).await.unwrap_err();
2125
2126 // Users can't dismiss notifications of them accepting other users' requests.
2127 db.dismiss_contact_notification(user_2, user_1)
2128 .await
2129 .unwrap_err();
2130 assert_eq!(
2131 db.get_contacts(user_1).await.unwrap(),
2132 &[
2133 Contact::Accepted {
2134 user_id: user_1,
2135 should_notify: false
2136 },
2137 Contact::Accepted {
2138 user_id: user_2,
2139 should_notify: true,
2140 },
2141 ]
2142 );
2143
2144 // Users can dismiss notifications of other users accepting their requests.
2145 db.dismiss_contact_notification(user_1, user_2)
2146 .await
2147 .unwrap();
2148 assert_eq!(
2149 db.get_contacts(user_1).await.unwrap(),
2150 &[
2151 Contact::Accepted {
2152 user_id: user_1,
2153 should_notify: false
2154 },
2155 Contact::Accepted {
2156 user_id: user_2,
2157 should_notify: false,
2158 },
2159 ]
2160 );
2161
2162 // Users send each other concurrent contact requests and
2163 // see that they are immediately accepted.
2164 db.send_contact_request(user_1, user_3).await.unwrap();
2165 db.send_contact_request(user_3, user_1).await.unwrap();
2166 assert_eq!(
2167 db.get_contacts(user_1).await.unwrap(),
2168 &[
2169 Contact::Accepted {
2170 user_id: user_1,
2171 should_notify: false
2172 },
2173 Contact::Accepted {
2174 user_id: user_2,
2175 should_notify: false,
2176 },
2177 Contact::Accepted {
2178 user_id: user_3,
2179 should_notify: false
2180 },
2181 ]
2182 );
2183 assert_eq!(
2184 db.get_contacts(user_3).await.unwrap(),
2185 &[
2186 Contact::Accepted {
2187 user_id: user_1,
2188 should_notify: false
2189 },
2190 Contact::Accepted {
2191 user_id: user_3,
2192 should_notify: false
2193 }
2194 ],
2195 );
2196
2197 // User declines a contact request. Both users see that it is gone.
2198 db.send_contact_request(user_2, user_3).await.unwrap();
2199 db.respond_to_contact_request(user_3, user_2, false)
2200 .await
2201 .unwrap();
2202 assert!(!db.has_contact(user_2, user_3).await.unwrap());
2203 assert!(!db.has_contact(user_3, user_2).await.unwrap());
2204 assert_eq!(
2205 db.get_contacts(user_2).await.unwrap(),
2206 &[
2207 Contact::Accepted {
2208 user_id: user_1,
2209 should_notify: false
2210 },
2211 Contact::Accepted {
2212 user_id: user_2,
2213 should_notify: false
2214 }
2215 ]
2216 );
2217 assert_eq!(
2218 db.get_contacts(user_3).await.unwrap(),
2219 &[
2220 Contact::Accepted {
2221 user_id: user_1,
2222 should_notify: false
2223 },
2224 Contact::Accepted {
2225 user_id: user_3,
2226 should_notify: false
2227 }
2228 ],
2229 );
2230 }
2231 }
2232
2233 #[tokio::test(flavor = "multi_thread")]
2234 async fn test_invite_codes() {
2235 let postgres = TestDb::postgres().await;
2236 let db = postgres.db();
2237 let user1 = db.create_user("user-1", None, false).await.unwrap();
2238
2239 // Initially, user 1 has no invite code
2240 assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);
2241
2242 // Setting invite count to 0 when no code is assigned does not assign a new code
2243 db.set_invite_count(user1, 0).await.unwrap();
2244 assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());
2245
2246 // User 1 creates an invite code that can be used twice.
2247 db.set_invite_count(user1, 2).await.unwrap();
2248 let (invite_code, invite_count) =
2249 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2250 assert_eq!(invite_count, 2);
2251
2252 // User 2 redeems the invite code and becomes a contact of user 1.
2253 let user2 = db
2254 .redeem_invite_code(&invite_code, "user-2", None)
2255 .await
2256 .unwrap();
2257 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2258 assert_eq!(invite_count, 1);
2259 assert_eq!(
2260 db.get_contacts(user1).await.unwrap(),
2261 [
2262 Contact::Accepted {
2263 user_id: user1,
2264 should_notify: false
2265 },
2266 Contact::Accepted {
2267 user_id: user2,
2268 should_notify: true
2269 }
2270 ]
2271 );
2272 assert_eq!(
2273 db.get_contacts(user2).await.unwrap(),
2274 [
2275 Contact::Accepted {
2276 user_id: user1,
2277 should_notify: false
2278 },
2279 Contact::Accepted {
2280 user_id: user2,
2281 should_notify: false
2282 }
2283 ]
2284 );
2285
2286 // User 3 redeems the invite code and becomes a contact of user 1.
2287 let user3 = db
2288 .redeem_invite_code(&invite_code, "user-3", None)
2289 .await
2290 .unwrap();
2291 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2292 assert_eq!(invite_count, 0);
2293 assert_eq!(
2294 db.get_contacts(user1).await.unwrap(),
2295 [
2296 Contact::Accepted {
2297 user_id: user1,
2298 should_notify: false
2299 },
2300 Contact::Accepted {
2301 user_id: user2,
2302 should_notify: true
2303 },
2304 Contact::Accepted {
2305 user_id: user3,
2306 should_notify: true
2307 }
2308 ]
2309 );
2310 assert_eq!(
2311 db.get_contacts(user3).await.unwrap(),
2312 [
2313 Contact::Accepted {
2314 user_id: user1,
2315 should_notify: false
2316 },
2317 Contact::Accepted {
2318 user_id: user3,
2319 should_notify: false
2320 },
2321 ]
2322 );
2323
2324 // Trying to reedem the code for the third time results in an error.
2325 db.redeem_invite_code(&invite_code, "user-4", None)
2326 .await
2327 .unwrap_err();
2328
2329 // Invite count can be updated after the code has been created.
2330 db.set_invite_count(user1, 2).await.unwrap();
2331 let (latest_code, invite_count) =
2332 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2333 assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
2334 assert_eq!(invite_count, 2);
2335
2336 // User 4 can now redeem the invite code and becomes a contact of user 1.
2337 let user4 = db
2338 .redeem_invite_code(&invite_code, "user-4", None)
2339 .await
2340 .unwrap();
2341 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2342 assert_eq!(invite_count, 1);
2343 assert_eq!(
2344 db.get_contacts(user1).await.unwrap(),
2345 [
2346 Contact::Accepted {
2347 user_id: user1,
2348 should_notify: false
2349 },
2350 Contact::Accepted {
2351 user_id: user2,
2352 should_notify: true
2353 },
2354 Contact::Accepted {
2355 user_id: user3,
2356 should_notify: true
2357 },
2358 Contact::Accepted {
2359 user_id: user4,
2360 should_notify: true
2361 }
2362 ]
2363 );
2364 assert_eq!(
2365 db.get_contacts(user4).await.unwrap(),
2366 [
2367 Contact::Accepted {
2368 user_id: user1,
2369 should_notify: false
2370 },
2371 Contact::Accepted {
2372 user_id: user4,
2373 should_notify: false
2374 },
2375 ]
2376 );
2377
2378 // An existing user cannot redeem invite codes.
2379 db.redeem_invite_code(&invite_code, "user-2", None)
2380 .await
2381 .unwrap_err();
2382 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2383 assert_eq!(invite_count, 1);
2384 }
2385
2386 pub struct TestDb {
2387 pub db: Option<Arc<dyn Db>>,
2388 pub url: String,
2389 }
2390
2391 impl TestDb {
2392 pub async fn postgres() -> Self {
2393 lazy_static! {
2394 static ref LOCK: Mutex<()> = Mutex::new(());
2395 }
2396
2397 let _guard = LOCK.lock();
2398 let mut rng = StdRng::from_entropy();
2399 let name = format!("zed-test-{}", rng.gen::<u128>());
2400 let url = format!("postgres://postgres@localhost/{}", name);
2401 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2402 Postgres::create_database(&url)
2403 .await
2404 .expect("failed to create test db");
2405 let db = PostgresDb::new(&url, 5).await.unwrap();
2406 let migrator = Migrator::new(migrations_path).await.unwrap();
2407 migrator.run(&db.pool).await.unwrap();
2408 Self {
2409 db: Some(Arc::new(db)),
2410 url,
2411 }
2412 }
2413
2414 pub fn fake(background: Arc<Background>) -> Self {
2415 Self {
2416 db: Some(Arc::new(FakeDb::new(background))),
2417 url: Default::default(),
2418 }
2419 }
2420
2421 pub fn db(&self) -> &Arc<dyn Db> {
2422 self.db.as_ref().unwrap()
2423 }
2424 }
2425
2426 impl Drop for TestDb {
2427 fn drop(&mut self) {
2428 if let Some(db) = self.db.take() {
2429 futures::executor::block_on(db.teardown(&self.url));
2430 }
2431 }
2432 }
2433
    /// In-memory implementation of `Db` for tests. Every table is a map or
    /// vector behind its own mutex, and ids come from per-table counters.
    pub struct FakeDb {
        /// Executor used to inject random delays into each operation.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        /// Keyed by (project id, worktree id, extension); the value is the count
        /// stored by `update_worktree_extensions`.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// Membership keyed by (org, user); the value is the `is_admin` flag.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// Membership keyed by (channel, user); the value is the `is_admin` flag.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Contact relationships and pending requests, one entry per user pair.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Id counters, advanced with `post_inc` when a row is inserted.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2451
    /// A contact relationship (or pending request) stored by `FakeDb`.
    #[derive(Debug)]
    pub struct FakeContact {
        /// User who sent the contact request.
        pub requester_id: UserId,
        /// User who received the contact request.
        pub responder_id: UserId,
        /// Whether the request has been accepted.
        pub accepted: bool,
        /// Whether a notification about this entry is still pending; which user
        /// it is surfaced to depends on `accepted` (see `get_contacts`).
        pub should_notify: bool,
    }
2459
2460 impl FakeDb {
2461 pub fn new(background: Arc<Background>) -> Self {
2462 Self {
2463 background,
2464 users: Default::default(),
2465 next_user_id: Mutex::new(0),
2466 projects: Default::default(),
2467 worktree_extensions: Default::default(),
2468 next_project_id: Mutex::new(1),
2469 orgs: Default::default(),
2470 next_org_id: Mutex::new(1),
2471 org_memberships: Default::default(),
2472 channels: Default::default(),
2473 next_channel_id: Mutex::new(1),
2474 channel_memberships: Default::default(),
2475 channel_messages: Default::default(),
2476 next_channel_message_id: Mutex::new(1),
2477 contacts: Default::default(),
2478 }
2479 }
2480 }
2481
2482 #[async_trait]
2483 impl Db for FakeDb {
2484 async fn create_user(
2485 &self,
2486 github_login: &str,
2487 email_address: Option<&str>,
2488 admin: bool,
2489 ) -> Result<UserId> {
2490 self.background.simulate_random_delay().await;
2491
2492 let mut users = self.users.lock();
2493 if let Some(user) = users
2494 .values()
2495 .find(|user| user.github_login == github_login)
2496 {
2497 Ok(user.id)
2498 } else {
2499 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2500 users.insert(
2501 user_id,
2502 User {
2503 id: user_id,
2504 github_login: github_login.to_string(),
2505 email_address: email_address.map(str::to_string),
2506 admin,
2507 invite_code: None,
2508 invite_count: 0,
2509 connected_once: false,
2510 },
2511 );
2512 Ok(user_id)
2513 }
2514 }
2515
2516 async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
2517 unimplemented!()
2518 }
2519
2520 async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
2521 unimplemented!()
2522 }
2523
2524 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
2525 unimplemented!()
2526 }
2527
2528 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2529 self.background.simulate_random_delay().await;
2530 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2531 }
2532
2533 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2534 self.background.simulate_random_delay().await;
2535 let users = self.users.lock();
2536 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2537 }
2538
2539 async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
2540 unimplemented!()
2541 }
2542
2543 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2544 self.background.simulate_random_delay().await;
2545 Ok(self
2546 .users
2547 .lock()
2548 .values()
2549 .find(|user| user.github_login == github_login)
2550 .cloned())
2551 }
2552
2553 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
2554 unimplemented!()
2555 }
2556
2557 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2558 self.background.simulate_random_delay().await;
2559 let mut users = self.users.lock();
2560 let mut user = users
2561 .get_mut(&id)
2562 .ok_or_else(|| anyhow!("user not found"))?;
2563 user.connected_once = connected_once;
2564 Ok(())
2565 }
2566
2567 async fn destroy_user(&self, _id: UserId) -> Result<()> {
2568 unimplemented!()
2569 }
2570
2571 // invite codes
2572
2573 async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
2574 unimplemented!()
2575 }
2576
2577 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
2578 self.background.simulate_random_delay().await;
2579 Ok(None)
2580 }
2581
2582 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
2583 unimplemented!()
2584 }
2585
2586 async fn redeem_invite_code(
2587 &self,
2588 _code: &str,
2589 _login: &str,
2590 _email_address: Option<&str>,
2591 ) -> Result<UserId> {
2592 unimplemented!()
2593 }
2594
2595 // projects
2596
2597 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2598 self.background.simulate_random_delay().await;
2599 if !self.users.lock().contains_key(&host_user_id) {
2600 Err(anyhow!("no such user"))?;
2601 }
2602
2603 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2604 self.projects.lock().insert(
2605 project_id,
2606 Project {
2607 id: project_id,
2608 host_user_id,
2609 unregistered: false,
2610 },
2611 );
2612 Ok(project_id)
2613 }
2614
2615 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2616 self.background.simulate_random_delay().await;
2617 self.projects
2618 .lock()
2619 .get_mut(&project_id)
2620 .ok_or_else(|| anyhow!("no such project"))?
2621 .unregistered = true;
2622 Ok(())
2623 }
2624
2625 async fn update_worktree_extensions(
2626 &self,
2627 project_id: ProjectId,
2628 worktree_id: u64,
2629 extensions: HashMap<String, u32>,
2630 ) -> Result<()> {
2631 self.background.simulate_random_delay().await;
2632 if !self.projects.lock().contains_key(&project_id) {
2633 Err(anyhow!("no such project"))?;
2634 }
2635
2636 for (extension, count) in extensions {
2637 self.worktree_extensions
2638 .lock()
2639 .insert((project_id, worktree_id, extension), count);
2640 }
2641
2642 Ok(())
2643 }
2644
2645 async fn get_project_extensions(
2646 &self,
2647 _project_id: ProjectId,
2648 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
2649 unimplemented!()
2650 }
2651
2652 async fn record_user_activity(
2653 &self,
2654 _time_period: Range<OffsetDateTime>,
2655 _active_projects: &[(UserId, ProjectId)],
2656 ) -> Result<()> {
2657 unimplemented!()
2658 }
2659
2660 async fn get_active_user_count(
2661 &self,
2662 _time_period: Range<OffsetDateTime>,
2663 _min_duration: Duration,
2664 _only_collaborative: bool,
2665 ) -> Result<usize> {
2666 unimplemented!()
2667 }
2668
2669 async fn get_top_users_activity_summary(
2670 &self,
2671 _time_period: Range<OffsetDateTime>,
2672 _limit: usize,
2673 ) -> Result<Vec<UserActivitySummary>> {
2674 unimplemented!()
2675 }
2676
2677 async fn get_user_activity_timeline(
2678 &self,
2679 _time_period: Range<OffsetDateTime>,
2680 _user_id: UserId,
2681 ) -> Result<Vec<UserActivityPeriod>> {
2682 unimplemented!()
2683 }
2684
2685 // contacts
2686
2687 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2688 self.background.simulate_random_delay().await;
2689 let mut contacts = vec![Contact::Accepted {
2690 user_id: id,
2691 should_notify: false,
2692 }];
2693
2694 for contact in self.contacts.lock().iter() {
2695 if contact.requester_id == id {
2696 if contact.accepted {
2697 contacts.push(Contact::Accepted {
2698 user_id: contact.responder_id,
2699 should_notify: contact.should_notify,
2700 });
2701 } else {
2702 contacts.push(Contact::Outgoing {
2703 user_id: contact.responder_id,
2704 });
2705 }
2706 } else if contact.responder_id == id {
2707 if contact.accepted {
2708 contacts.push(Contact::Accepted {
2709 user_id: contact.requester_id,
2710 should_notify: false,
2711 });
2712 } else {
2713 contacts.push(Contact::Incoming {
2714 user_id: contact.requester_id,
2715 should_notify: contact.should_notify,
2716 });
2717 }
2718 }
2719 }
2720
2721 contacts.sort_unstable_by_key(|contact| contact.user_id());
2722 Ok(contacts)
2723 }
2724
2725 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2726 self.background.simulate_random_delay().await;
2727 Ok(self.contacts.lock().iter().any(|contact| {
2728 contact.accepted
2729 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2730 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2731 }))
2732 }
2733
2734 async fn send_contact_request(
2735 &self,
2736 requester_id: UserId,
2737 responder_id: UserId,
2738 ) -> Result<()> {
2739 self.background.simulate_random_delay().await;
2740 let mut contacts = self.contacts.lock();
2741 for contact in contacts.iter_mut() {
2742 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2743 if contact.accepted {
2744 Err(anyhow!("contact already exists"))?;
2745 } else {
2746 Err(anyhow!("contact already requested"))?;
2747 }
2748 }
2749 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2750 if contact.accepted {
2751 Err(anyhow!("contact already exists"))?;
2752 } else {
2753 contact.accepted = true;
2754 contact.should_notify = false;
2755 return Ok(());
2756 }
2757 }
2758 }
2759 contacts.push(FakeContact {
2760 requester_id,
2761 responder_id,
2762 accepted: false,
2763 should_notify: true,
2764 });
2765 Ok(())
2766 }
2767
2768 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2769 self.background.simulate_random_delay().await;
2770 self.contacts.lock().retain(|contact| {
2771 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2772 });
2773 Ok(())
2774 }
2775
2776 async fn dismiss_contact_notification(
2777 &self,
2778 user_id: UserId,
2779 contact_user_id: UserId,
2780 ) -> Result<()> {
2781 self.background.simulate_random_delay().await;
2782 let mut contacts = self.contacts.lock();
2783 for contact in contacts.iter_mut() {
2784 if contact.requester_id == contact_user_id
2785 && contact.responder_id == user_id
2786 && !contact.accepted
2787 {
2788 contact.should_notify = false;
2789 return Ok(());
2790 }
2791 if contact.requester_id == user_id
2792 && contact.responder_id == contact_user_id
2793 && contact.accepted
2794 {
2795 contact.should_notify = false;
2796 return Ok(());
2797 }
2798 }
2799 Err(anyhow!("no such notification"))?
2800 }
2801
2802 async fn respond_to_contact_request(
2803 &self,
2804 responder_id: UserId,
2805 requester_id: UserId,
2806 accept: bool,
2807 ) -> Result<()> {
2808 self.background.simulate_random_delay().await;
2809 let mut contacts = self.contacts.lock();
2810 for (ix, contact) in contacts.iter_mut().enumerate() {
2811 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2812 if contact.accepted {
2813 Err(anyhow!("contact already confirmed"))?;
2814 }
2815 if accept {
2816 contact.accepted = true;
2817 contact.should_notify = true;
2818 } else {
2819 contacts.remove(ix);
2820 }
2821 return Ok(());
2822 }
2823 }
2824 Err(anyhow!("no such contact request"))?
2825 }
2826
2827 async fn create_access_token_hash(
2828 &self,
2829 _user_id: UserId,
2830 _access_token_hash: &str,
2831 _max_access_token_count: usize,
2832 ) -> Result<()> {
2833 unimplemented!()
2834 }
2835
2836 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
2837 unimplemented!()
2838 }
2839
2840 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2841 unimplemented!()
2842 }
2843
2844 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2845 self.background.simulate_random_delay().await;
2846 let mut orgs = self.orgs.lock();
2847 if orgs.values().any(|org| org.slug == slug) {
2848 Err(anyhow!("org already exists"))?
2849 } else {
2850 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2851 orgs.insert(
2852 org_id,
2853 Org {
2854 id: org_id,
2855 name: name.to_string(),
2856 slug: slug.to_string(),
2857 },
2858 );
2859 Ok(org_id)
2860 }
2861 }
2862
2863 async fn add_org_member(
2864 &self,
2865 org_id: OrgId,
2866 user_id: UserId,
2867 is_admin: bool,
2868 ) -> Result<()> {
2869 self.background.simulate_random_delay().await;
2870 if !self.orgs.lock().contains_key(&org_id) {
2871 Err(anyhow!("org does not exist"))?;
2872 }
2873 if !self.users.lock().contains_key(&user_id) {
2874 Err(anyhow!("user does not exist"))?;
2875 }
2876
2877 self.org_memberships
2878 .lock()
2879 .entry((org_id, user_id))
2880 .or_insert(is_admin);
2881 Ok(())
2882 }
2883
2884 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2885 self.background.simulate_random_delay().await;
2886 if !self.orgs.lock().contains_key(&org_id) {
2887 Err(anyhow!("org does not exist"))?;
2888 }
2889
2890 let mut channels = self.channels.lock();
2891 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2892 channels.insert(
2893 channel_id,
2894 Channel {
2895 id: channel_id,
2896 name: name.to_string(),
2897 owner_id: org_id.0,
2898 owner_is_user: false,
2899 },
2900 );
2901 Ok(channel_id)
2902 }
2903
2904 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2905 self.background.simulate_random_delay().await;
2906 Ok(self
2907 .channels
2908 .lock()
2909 .values()
2910 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2911 .cloned()
2912 .collect())
2913 }
2914
2915 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2916 self.background.simulate_random_delay().await;
2917 let channels = self.channels.lock();
2918 let memberships = self.channel_memberships.lock();
2919 Ok(channels
2920 .values()
2921 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2922 .cloned()
2923 .collect())
2924 }
2925
2926 async fn can_user_access_channel(
2927 &self,
2928 user_id: UserId,
2929 channel_id: ChannelId,
2930 ) -> Result<bool> {
2931 self.background.simulate_random_delay().await;
2932 Ok(self
2933 .channel_memberships
2934 .lock()
2935 .contains_key(&(channel_id, user_id)))
2936 }
2937
2938 async fn add_channel_member(
2939 &self,
2940 channel_id: ChannelId,
2941 user_id: UserId,
2942 is_admin: bool,
2943 ) -> Result<()> {
2944 self.background.simulate_random_delay().await;
2945 if !self.channels.lock().contains_key(&channel_id) {
2946 Err(anyhow!("channel does not exist"))?;
2947 }
2948 if !self.users.lock().contains_key(&user_id) {
2949 Err(anyhow!("user does not exist"))?;
2950 }
2951
2952 self.channel_memberships
2953 .lock()
2954 .entry((channel_id, user_id))
2955 .or_insert(is_admin);
2956 Ok(())
2957 }
2958
2959 async fn create_channel_message(
2960 &self,
2961 channel_id: ChannelId,
2962 sender_id: UserId,
2963 body: &str,
2964 timestamp: OffsetDateTime,
2965 nonce: u128,
2966 ) -> Result<MessageId> {
2967 self.background.simulate_random_delay().await;
2968 if !self.channels.lock().contains_key(&channel_id) {
2969 Err(anyhow!("channel does not exist"))?;
2970 }
2971 if !self.users.lock().contains_key(&sender_id) {
2972 Err(anyhow!("user does not exist"))?;
2973 }
2974
2975 let mut messages = self.channel_messages.lock();
2976 if let Some(message) = messages
2977 .values()
2978 .find(|message| message.nonce.as_u128() == nonce)
2979 {
2980 Ok(message.id)
2981 } else {
2982 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2983 messages.insert(
2984 message_id,
2985 ChannelMessage {
2986 id: message_id,
2987 channel_id,
2988 sender_id,
2989 body: body.to_string(),
2990 sent_at: timestamp,
2991 nonce: Uuid::from_u128(nonce),
2992 },
2993 );
2994 Ok(message_id)
2995 }
2996 }
2997
2998 async fn get_channel_messages(
2999 &self,
3000 channel_id: ChannelId,
3001 count: usize,
3002 before_id: Option<MessageId>,
3003 ) -> Result<Vec<ChannelMessage>> {
3004 self.background.simulate_random_delay().await;
3005 let mut messages = self
3006 .channel_messages
3007 .lock()
3008 .values()
3009 .rev()
3010 .filter(|message| {
3011 message.channel_id == channel_id
3012 && message.id < before_id.unwrap_or(MessageId::MAX)
3013 })
3014 .take(count)
3015 .cloned()
3016 .collect::<Vec<_>>();
3017 messages.sort_unstable_by_key(|message| message.id);
3018 Ok(messages)
3019 }
3020
3021 async fn teardown(&self, _: &str) {}
3022
3023 #[cfg(test)]
3024 fn as_fake(&self) -> Option<&FakeDb> {
3025 Some(self)
3026 }
3027 }
3028
3029 fn build_background_executor() -> Arc<Background> {
3030 Deterministic::new(0).build_background()
3031 }
3032}