1use std::{cmp, ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use serde::{Deserialize, Serialize};
10pub use sqlx::postgres::PgPoolOptions as DbOptions;
11use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
12use time::{OffsetDateTime, PrimitiveDateTime};
13
/// Persistence interface used by the collaboration server.
///
/// The per-method descriptions follow the [`PostgresDb`] implementation; under
/// `cfg(test)` the trait additionally exposes hooks for the test-only `tests::FakeDb`.
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user with the given GitHub login and returns its id. If the login is
    /// already registered, the existing user's id is returned and the row is left as is.
    async fn create_user(
        &self,
        github_login: &str,
        email_address: Option<&str>,
        admin: bool,
    ) -> Result<UserId>;
    /// Returns one page of users ordered by GitHub login (`page` is zero-based, `limit`
    /// is the page size).
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Bulk-creates users from `(github_login, email_address, invite_count)` tuples and
    /// returns their ids. Existing logins get their invite count overwritten and keep any
    /// invite code they already had.
    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
    /// Returns up to `limit` users whose GitHub login loosely matches `query`, closest
    /// matches first.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Returns the user with the given id, if one exists.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose ids appear in `ids`; unknown ids are skipped and the
    /// result order is unspecified.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns users with no invites left, restricted to those who were
    /// (`invited_by_another_user == true`) or were not invited by another user.
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    /// Returns the user with the given GitHub login, if one exists.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets the user's admin flag.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets the user's `connected_once` flag.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes the user together with their access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Sets how many invites the user may hand out, assigning an invite code first if
    /// `count` is positive and the user has none.
    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, or `None` if the user
    /// has no invite code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user who owns `code`, failing with HTTP 404 if no user does.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Redeems `code`: spends one of its owner's invites, creates the invitee and makes
    /// the two users accepted contacts. Fails with HTTP 401 when the code has no invites
    /// left and HTTP 404 when the code does not exist.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, u32>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_user_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the number of users who have been active in the given
    /// time period for at least the given time duration.
    async fn get_active_user_count(
        &self,
        time_period: Range<OffsetDateTime>,
        min_duration: Duration,
        only_collaborative: bool,
    ) -> Result<usize>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn get_top_users_activity_summary(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Get the project activity for the given user and time period.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>>;

    /// Returns the user's contacts (including the user themself as an accepted
    /// contact), sorted by user id.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Whether the two users are accepted contacts; argument order does not matter.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Sends a contact request. If the responder already has a pending request to the
    /// requester, that request is accepted instead. Errors if the request already exists.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Deletes the contact relationship (pending or accepted) between the two users,
    /// erroring if there is none.
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Clears the first user's pending notification about the second user: either an
    /// acceptance of a request they sent or a request they received. Errors if no such
    /// notification row exists.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    /// Accepts (flagging the requester for notification) or declines (deleting) the
    /// request `requester_id` sent to `responder_id`. Errors if no matching request exists.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    /// Stores a new access-token hash for the user and prunes older tokens so that at
    /// most `max_access_token_count` remain.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the user's access-token hashes in descending id order (newest first).
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    /// Looks up an organization by slug (test / `seed-support` builds only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an organization and returns its id (test / `seed-support` builds only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds a user to an organization; an existing membership is left unchanged.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Creates a channel owned by the organization and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Returns all channels owned by the organization (test / `seed-support` builds only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns every channel the user is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Whether the user is a member of the channel.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds a user to a channel; an existing membership is left unchanged.
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a channel message and returns its id. Re-sending a message with the same
    /// `nonce` returns the id of the already stored message.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` messages in the channel older than `before_id` (or the most
    /// recent ones when `None`), ordered oldest first.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only cleanup hook. NOTE(review): `url` presumably identifies the database to
    /// tear down — confirm against the implementations.
    #[cfg(test)]
    async fn teardown(&self, url: &str);
    /// Test-only downcast to the in-memory fake; presumably `None` for other
    /// implementations.
    #[cfg(test)]
    fn as_fake(&self) -> Option<&tests::FakeDb>;
}
159
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    /// Shared sqlx pool that every query in the `Db` impl runs against.
    pool: sqlx::PgPool,
}
163
164impl PostgresDb {
165 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
166 let pool = DbOptions::new()
167 .max_connections(max_connections)
168 .connect(url)
169 .await
170 .context("failed to connect to postgres database")?;
171 Ok(Self { pool })
172 }
173}
174
175#[async_trait]
176impl Db for PostgresDb {
177 // users
178
179 async fn create_user(
180 &self,
181 github_login: &str,
182 email_address: Option<&str>,
183 admin: bool,
184 ) -> Result<UserId> {
185 let query = "
186 INSERT INTO users (github_login, email_address, admin)
187 VALUES ($1, $2, $3)
188 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
189 RETURNING id
190 ";
191 Ok(sqlx::query_scalar(query)
192 .bind(github_login)
193 .bind(email_address)
194 .bind(admin)
195 .fetch_one(&self.pool)
196 .await
197 .map(UserId)?)
198 }
199
200 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
201 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
202 Ok(sqlx::query_as(query)
203 .bind(limit as i32)
204 .bind((page * limit) as i32)
205 .fetch_all(&self.pool)
206 .await?)
207 }
208
209 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
210 let mut query = QueryBuilder::new(
211 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
212 );
213 query.push_values(
214 users,
215 |mut query, (github_login, email_address, invite_count)| {
216 query
217 .push_bind(github_login)
218 .push_bind(email_address)
219 .push_bind(false)
220 .push_bind(random_invite_code())
221 .push_bind(invite_count as i32);
222 },
223 );
224 query.push(
225 "
226 ON CONFLICT (github_login) DO UPDATE SET
227 github_login = excluded.github_login,
228 invite_count = excluded.invite_count,
229 invite_code = CASE WHEN users.invite_code IS NULL
230 THEN excluded.invite_code
231 ELSE users.invite_code
232 END
233 RETURNING id
234 ",
235 );
236
237 let rows = query.build().fetch_all(&self.pool).await?;
238 Ok(rows
239 .into_iter()
240 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
241 .collect())
242 }
243
244 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
245 let like_string = fuzzy_like_string(name_query);
246 let query = "
247 SELECT users.*
248 FROM users
249 WHERE github_login ILIKE $1
250 ORDER BY github_login <-> $2
251 LIMIT $3
252 ";
253 Ok(sqlx::query_as(query)
254 .bind(like_string)
255 .bind(name_query)
256 .bind(limit as i32)
257 .fetch_all(&self.pool)
258 .await?)
259 }
260
261 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
262 let users = self.get_users_by_ids(vec![id]).await?;
263 Ok(users.into_iter().next())
264 }
265
266 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
267 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
268 let query = "
269 SELECT users.*
270 FROM users
271 WHERE users.id = ANY ($1)
272 ";
273 Ok(sqlx::query_as(query)
274 .bind(&ids)
275 .fetch_all(&self.pool)
276 .await?)
277 }
278
279 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
280 let query = format!(
281 "
282 SELECT users.*
283 FROM users
284 WHERE invite_count = 0
285 AND inviter_id IS{} NULL
286 ",
287 if invited_by_another_user { " NOT" } else { "" }
288 );
289
290 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
291 }
292
293 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
294 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
295 Ok(sqlx::query_as(query)
296 .bind(github_login)
297 .fetch_optional(&self.pool)
298 .await?)
299 }
300
301 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
302 let query = "UPDATE users SET admin = $1 WHERE id = $2";
303 Ok(sqlx::query(query)
304 .bind(is_admin)
305 .bind(id.0)
306 .execute(&self.pool)
307 .await
308 .map(drop)?)
309 }
310
311 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
312 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
313 Ok(sqlx::query(query)
314 .bind(connected_once)
315 .bind(id.0)
316 .execute(&self.pool)
317 .await
318 .map(drop)?)
319 }
320
321 async fn destroy_user(&self, id: UserId) -> Result<()> {
322 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
323 sqlx::query(query)
324 .bind(id.0)
325 .execute(&self.pool)
326 .await
327 .map(drop)?;
328 let query = "DELETE FROM users WHERE id = $1;";
329 Ok(sqlx::query(query)
330 .bind(id.0)
331 .execute(&self.pool)
332 .await
333 .map(drop)?)
334 }
335
336 // invite codes
337
338 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
339 let mut tx = self.pool.begin().await?;
340 if count > 0 {
341 sqlx::query(
342 "
343 UPDATE users
344 SET invite_code = $1
345 WHERE id = $2 AND invite_code IS NULL
346 ",
347 )
348 .bind(random_invite_code())
349 .bind(id)
350 .execute(&mut tx)
351 .await?;
352 }
353
354 sqlx::query(
355 "
356 UPDATE users
357 SET invite_count = $1
358 WHERE id = $2
359 ",
360 )
361 .bind(count as i32)
362 .bind(id)
363 .execute(&mut tx)
364 .await?;
365 tx.commit().await?;
366 Ok(())
367 }
368
369 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
370 let result: Option<(String, i32)> = sqlx::query_as(
371 "
372 SELECT invite_code, invite_count
373 FROM users
374 WHERE id = $1 AND invite_code IS NOT NULL
375 ",
376 )
377 .bind(id)
378 .fetch_optional(&self.pool)
379 .await?;
380 if let Some((code, count)) = result {
381 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
382 } else {
383 Ok(None)
384 }
385 }
386
387 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
388 sqlx::query_as(
389 "
390 SELECT *
391 FROM users
392 WHERE invite_code = $1
393 ",
394 )
395 .bind(code)
396 .fetch_optional(&self.pool)
397 .await?
398 .ok_or_else(|| {
399 Error::Http(
400 StatusCode::NOT_FOUND,
401 "that invite code does not exist".to_string(),
402 )
403 })
404 }
405
406 async fn redeem_invite_code(
407 &self,
408 code: &str,
409 login: &str,
410 email_address: Option<&str>,
411 ) -> Result<UserId> {
412 let mut tx = self.pool.begin().await?;
413
414 let inviter_id: Option<UserId> = sqlx::query_scalar(
415 "
416 UPDATE users
417 SET invite_count = invite_count - 1
418 WHERE
419 invite_code = $1 AND
420 invite_count > 0
421 RETURNING id
422 ",
423 )
424 .bind(code)
425 .fetch_optional(&mut tx)
426 .await?;
427
428 let inviter_id = match inviter_id {
429 Some(inviter_id) => inviter_id,
430 None => {
431 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
432 .bind(code)
433 .fetch_optional(&mut tx)
434 .await?
435 .is_some()
436 {
437 Err(Error::Http(
438 StatusCode::UNAUTHORIZED,
439 "no invites remaining".to_string(),
440 ))?
441 } else {
442 Err(Error::Http(
443 StatusCode::NOT_FOUND,
444 "invite code not found".to_string(),
445 ))?
446 }
447 }
448 };
449
450 let invitee_id = sqlx::query_scalar(
451 "
452 INSERT INTO users
453 (github_login, email_address, admin, inviter_id, invite_code, invite_count)
454 VALUES
455 ($1, $2, 'f', $3, $4, $5)
456 RETURNING id
457 ",
458 )
459 .bind(login)
460 .bind(email_address)
461 .bind(inviter_id)
462 .bind(random_invite_code())
463 .bind(5)
464 .fetch_one(&mut tx)
465 .await
466 .map(UserId)?;
467
468 sqlx::query(
469 "
470 INSERT INTO contacts
471 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
472 VALUES
473 ($1, $2, 't', 't', 't')
474 ",
475 )
476 .bind(inviter_id)
477 .bind(invitee_id)
478 .execute(&mut tx)
479 .await?;
480
481 tx.commit().await?;
482 Ok(invitee_id)
483 }
484
485 // projects
486
487 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
488 Ok(sqlx::query_scalar(
489 "
490 INSERT INTO projects(host_user_id)
491 VALUES ($1)
492 RETURNING id
493 ",
494 )
495 .bind(host_user_id)
496 .fetch_one(&self.pool)
497 .await
498 .map(ProjectId)?)
499 }
500
501 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
502 sqlx::query(
503 "
504 UPDATE projects
505 SET unregistered = 't'
506 WHERE id = $1
507 ",
508 )
509 .bind(project_id)
510 .execute(&self.pool)
511 .await?;
512 Ok(())
513 }
514
515 async fn update_worktree_extensions(
516 &self,
517 project_id: ProjectId,
518 worktree_id: u64,
519 extensions: HashMap<String, u32>,
520 ) -> Result<()> {
521 if extensions.is_empty() {
522 return Ok(());
523 }
524
525 let mut query = QueryBuilder::new(
526 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
527 );
528 query.push_values(extensions, |mut query, (extension, count)| {
529 query
530 .push_bind(project_id)
531 .push_bind(worktree_id as i32)
532 .push_bind(extension)
533 .push_bind(count as i32);
534 });
535 query.push(
536 "
537 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
538 count = excluded.count
539 ",
540 );
541 query.build().execute(&self.pool).await?;
542
543 Ok(())
544 }
545
546 async fn get_project_extensions(
547 &self,
548 project_id: ProjectId,
549 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
550 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
551 struct WorktreeExtension {
552 worktree_id: i32,
553 extension: String,
554 count: i32,
555 }
556
557 let query = "
558 SELECT worktree_id, extension, count
559 FROM worktree_extensions
560 WHERE project_id = $1
561 ";
562 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
563 .bind(&project_id)
564 .fetch_all(&self.pool)
565 .await?;
566
567 let mut extension_counts = HashMap::default();
568 for count in counts {
569 extension_counts
570 .entry(count.worktree_id as u64)
571 .or_insert_with(HashMap::default)
572 .insert(count.extension, count.count as usize);
573 }
574 Ok(extension_counts)
575 }
576
577 async fn record_user_activity(
578 &self,
579 time_period: Range<OffsetDateTime>,
580 projects: &[(UserId, ProjectId)],
581 ) -> Result<()> {
582 let query = "
583 INSERT INTO project_activity_periods
584 (ended_at, duration_millis, user_id, project_id)
585 VALUES
586 ($1, $2, $3, $4);
587 ";
588
589 let mut tx = self.pool.begin().await?;
590 let duration_millis =
591 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
592 for (user_id, project_id) in projects {
593 sqlx::query(query)
594 .bind(time_period.end)
595 .bind(duration_millis)
596 .bind(user_id)
597 .bind(project_id)
598 .execute(&mut tx)
599 .await?;
600 }
601 tx.commit().await?;
602
603 Ok(())
604 }
605
606 async fn get_active_user_count(
607 &self,
608 time_period: Range<OffsetDateTime>,
609 min_duration: Duration,
610 only_collaborative: bool,
611 ) -> Result<usize> {
612 let mut with_clause = String::new();
613 with_clause.push_str("WITH\n");
614 with_clause.push_str(
615 "
616 project_durations AS (
617 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
618 FROM project_activity_periods
619 WHERE $1 < ended_at AND ended_at <= $2
620 GROUP BY user_id, project_id
621 ),
622 ",
623 );
624 with_clause.push_str(
625 "
626 project_collaborators as (
627 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
628 FROM project_durations
629 GROUP BY project_id
630 ),
631 ",
632 );
633
634 if only_collaborative {
635 with_clause.push_str(
636 "
637 user_durations AS (
638 SELECT user_id, SUM(project_duration) as total_duration
639 FROM project_durations, project_collaborators
640 WHERE
641 project_durations.project_id = project_collaborators.project_id AND
642 max_collaborators > 1
643 GROUP BY user_id
644 ORDER BY total_duration DESC
645 LIMIT $3
646 )
647 ",
648 );
649 } else {
650 with_clause.push_str(
651 "
652 user_durations AS (
653 SELECT user_id, SUM(project_duration) as total_duration
654 FROM project_durations
655 GROUP BY user_id
656 ORDER BY total_duration DESC
657 LIMIT $3
658 )
659 ",
660 );
661 }
662
663 let query = format!(
664 "
665 {with_clause}
666 SELECT count(user_durations.user_id)
667 FROM user_durations
668 WHERE user_durations.total_duration >= $3
669 "
670 );
671
672 let count: i64 = sqlx::query_scalar(&query)
673 .bind(time_period.start)
674 .bind(time_period.end)
675 .bind(min_duration.as_millis() as i64)
676 .fetch_one(&self.pool)
677 .await?;
678 Ok(count as usize)
679 }
680
681 async fn get_top_users_activity_summary(
682 &self,
683 time_period: Range<OffsetDateTime>,
684 max_user_count: usize,
685 ) -> Result<Vec<UserActivitySummary>> {
686 let query = "
687 WITH
688 project_durations AS (
689 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
690 FROM project_activity_periods
691 WHERE $1 < ended_at AND ended_at <= $2
692 GROUP BY user_id, project_id
693 ),
694 user_durations AS (
695 SELECT user_id, SUM(project_duration) as total_duration
696 FROM project_durations
697 GROUP BY user_id
698 ORDER BY total_duration DESC
699 LIMIT $3
700 ),
701 project_collaborators as (
702 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
703 FROM project_durations
704 GROUP BY project_id
705 )
706 SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
707 FROM user_durations, project_durations, project_collaborators, users
708 WHERE
709 user_durations.user_id = project_durations.user_id AND
710 user_durations.user_id = users.id AND
711 project_durations.project_id = project_collaborators.project_id
712 ORDER BY total_duration DESC, user_id ASC, project_id ASC
713 ";
714
715 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
716 .bind(time_period.start)
717 .bind(time_period.end)
718 .bind(max_user_count as i32)
719 .fetch(&self.pool);
720
721 let mut result = Vec::<UserActivitySummary>::new();
722 while let Some(row) = rows.next().await {
723 let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
724 let project_id = project_id;
725 let duration = Duration::from_millis(duration_millis as u64);
726 let project_activity = ProjectActivitySummary {
727 id: project_id,
728 duration,
729 max_collaborators: project_collaborators as usize,
730 };
731 if let Some(last_summary) = result.last_mut() {
732 if last_summary.id == user_id {
733 last_summary.project_activity.push(project_activity);
734 continue;
735 }
736 }
737 result.push(UserActivitySummary {
738 id: user_id,
739 project_activity: vec![project_activity],
740 github_login,
741 });
742 }
743
744 Ok(result)
745 }
746
747 async fn get_user_activity_timeline(
748 &self,
749 time_period: Range<OffsetDateTime>,
750 user_id: UserId,
751 ) -> Result<Vec<UserActivityPeriod>> {
752 const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
753
754 let query = "
755 SELECT
756 project_activity_periods.ended_at,
757 project_activity_periods.duration_millis,
758 project_activity_periods.project_id,
759 worktree_extensions.extension,
760 worktree_extensions.count
761 FROM project_activity_periods
762 LEFT OUTER JOIN
763 worktree_extensions
764 ON
765 project_activity_periods.project_id = worktree_extensions.project_id
766 WHERE
767 project_activity_periods.user_id = $1 AND
768 $2 < project_activity_periods.ended_at AND
769 project_activity_periods.ended_at <= $3
770 ORDER BY project_activity_periods.id ASC
771 ";
772
773 let mut rows = sqlx::query_as::<
774 _,
775 (
776 PrimitiveDateTime,
777 i32,
778 ProjectId,
779 Option<String>,
780 Option<i32>,
781 ),
782 >(query)
783 .bind(user_id)
784 .bind(time_period.start)
785 .bind(time_period.end)
786 .fetch(&self.pool);
787
788 let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
789 while let Some(row) = rows.next().await {
790 let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
791 let ended_at = ended_at.assume_utc();
792 let duration = Duration::from_millis(duration_millis as u64);
793 let started_at = ended_at - duration;
794 let project_time_periods = time_periods.entry(project_id).or_default();
795
796 if let Some(prev_duration) = project_time_periods.last_mut() {
797 if started_at <= prev_duration.end + COALESCE_THRESHOLD
798 && ended_at >= prev_duration.start
799 {
800 prev_duration.end = cmp::max(prev_duration.end, ended_at);
801 } else {
802 project_time_periods.push(UserActivityPeriod {
803 project_id,
804 start: started_at,
805 end: ended_at,
806 extensions: Default::default(),
807 });
808 }
809 } else {
810 project_time_periods.push(UserActivityPeriod {
811 project_id,
812 start: started_at,
813 end: ended_at,
814 extensions: Default::default(),
815 });
816 }
817
818 if let Some((extension, extension_count)) = extension.zip(extension_count) {
819 project_time_periods
820 .last_mut()
821 .unwrap()
822 .extensions
823 .insert(extension, extension_count as usize);
824 }
825 }
826
827 let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
828 durations.sort_unstable_by_key(|duration| duration.start);
829 Ok(durations)
830 }
831
832 // contacts
833
834 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
835 let query = "
836 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
837 FROM contacts
838 WHERE user_id_a = $1 OR user_id_b = $1;
839 ";
840
841 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
842 .bind(user_id)
843 .fetch(&self.pool);
844
845 let mut contacts = vec![Contact::Accepted {
846 user_id,
847 should_notify: false,
848 }];
849 while let Some(row) = rows.next().await {
850 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
851
852 if user_id_a == user_id {
853 if accepted {
854 contacts.push(Contact::Accepted {
855 user_id: user_id_b,
856 should_notify: should_notify && a_to_b,
857 });
858 } else if a_to_b {
859 contacts.push(Contact::Outgoing { user_id: user_id_b })
860 } else {
861 contacts.push(Contact::Incoming {
862 user_id: user_id_b,
863 should_notify,
864 });
865 }
866 } else if accepted {
867 contacts.push(Contact::Accepted {
868 user_id: user_id_a,
869 should_notify: should_notify && !a_to_b,
870 });
871 } else if a_to_b {
872 contacts.push(Contact::Incoming {
873 user_id: user_id_a,
874 should_notify,
875 });
876 } else {
877 contacts.push(Contact::Outgoing { user_id: user_id_a });
878 }
879 }
880
881 contacts.sort_unstable_by_key(|contact| contact.user_id());
882
883 Ok(contacts)
884 }
885
886 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
887 let (id_a, id_b) = if user_id_1 < user_id_2 {
888 (user_id_1, user_id_2)
889 } else {
890 (user_id_2, user_id_1)
891 };
892
893 let query = "
894 SELECT 1 FROM contacts
895 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
896 LIMIT 1
897 ";
898 Ok(sqlx::query_scalar::<_, i32>(query)
899 .bind(id_a.0)
900 .bind(id_b.0)
901 .fetch_optional(&self.pool)
902 .await?
903 .is_some())
904 }
905
906 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
907 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
908 (sender_id, receiver_id, true)
909 } else {
910 (receiver_id, sender_id, false)
911 };
912 let query = "
913 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
914 VALUES ($1, $2, $3, 'f', 't')
915 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
916 SET
917 accepted = 't',
918 should_notify = 'f'
919 WHERE
920 NOT contacts.accepted AND
921 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
922 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
923 ";
924 let result = sqlx::query(query)
925 .bind(id_a.0)
926 .bind(id_b.0)
927 .bind(a_to_b)
928 .execute(&self.pool)
929 .await?;
930
931 if result.rows_affected() == 1 {
932 Ok(())
933 } else {
934 Err(anyhow!("contact already requested"))?
935 }
936 }
937
938 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
939 let (id_a, id_b) = if responder_id < requester_id {
940 (responder_id, requester_id)
941 } else {
942 (requester_id, responder_id)
943 };
944 let query = "
945 DELETE FROM contacts
946 WHERE user_id_a = $1 AND user_id_b = $2;
947 ";
948 let result = sqlx::query(query)
949 .bind(id_a.0)
950 .bind(id_b.0)
951 .execute(&self.pool)
952 .await?;
953
954 if result.rows_affected() == 1 {
955 Ok(())
956 } else {
957 Err(anyhow!("no such contact"))?
958 }
959 }
960
961 async fn dismiss_contact_notification(
962 &self,
963 user_id: UserId,
964 contact_user_id: UserId,
965 ) -> Result<()> {
966 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
967 (user_id, contact_user_id, true)
968 } else {
969 (contact_user_id, user_id, false)
970 };
971
972 let query = "
973 UPDATE contacts
974 SET should_notify = 'f'
975 WHERE
976 user_id_a = $1 AND user_id_b = $2 AND
977 (
978 (a_to_b = $3 AND accepted) OR
979 (a_to_b != $3 AND NOT accepted)
980 );
981 ";
982
983 let result = sqlx::query(query)
984 .bind(id_a.0)
985 .bind(id_b.0)
986 .bind(a_to_b)
987 .execute(&self.pool)
988 .await?;
989
990 if result.rows_affected() == 0 {
991 Err(anyhow!("no such contact request"))?;
992 }
993
994 Ok(())
995 }
996
997 async fn respond_to_contact_request(
998 &self,
999 responder_id: UserId,
1000 requester_id: UserId,
1001 accept: bool,
1002 ) -> Result<()> {
1003 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1004 (responder_id, requester_id, false)
1005 } else {
1006 (requester_id, responder_id, true)
1007 };
1008 let result = if accept {
1009 let query = "
1010 UPDATE contacts
1011 SET accepted = 't', should_notify = 't'
1012 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1013 ";
1014 sqlx::query(query)
1015 .bind(id_a.0)
1016 .bind(id_b.0)
1017 .bind(a_to_b)
1018 .execute(&self.pool)
1019 .await?
1020 } else {
1021 let query = "
1022 DELETE FROM contacts
1023 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1024 ";
1025 sqlx::query(query)
1026 .bind(id_a.0)
1027 .bind(id_b.0)
1028 .bind(a_to_b)
1029 .execute(&self.pool)
1030 .await?
1031 };
1032 if result.rows_affected() == 1 {
1033 Ok(())
1034 } else {
1035 Err(anyhow!("no such contact request"))?
1036 }
1037 }
1038
1039 // access tokens
1040
1041 async fn create_access_token_hash(
1042 &self,
1043 user_id: UserId,
1044 access_token_hash: &str,
1045 max_access_token_count: usize,
1046 ) -> Result<()> {
1047 let insert_query = "
1048 INSERT INTO access_tokens (user_id, hash)
1049 VALUES ($1, $2);
1050 ";
1051 let cleanup_query = "
1052 DELETE FROM access_tokens
1053 WHERE id IN (
1054 SELECT id from access_tokens
1055 WHERE user_id = $1
1056 ORDER BY id DESC
1057 OFFSET $3
1058 )
1059 ";
1060
1061 let mut tx = self.pool.begin().await?;
1062 sqlx::query(insert_query)
1063 .bind(user_id.0)
1064 .bind(access_token_hash)
1065 .execute(&mut tx)
1066 .await?;
1067 sqlx::query(cleanup_query)
1068 .bind(user_id.0)
1069 .bind(access_token_hash)
1070 .bind(max_access_token_count as i32)
1071 .execute(&mut tx)
1072 .await?;
1073 Ok(tx.commit().await?)
1074 }
1075
1076 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1077 let query = "
1078 SELECT hash
1079 FROM access_tokens
1080 WHERE user_id = $1
1081 ORDER BY id DESC
1082 ";
1083 Ok(sqlx::query_scalar(query)
1084 .bind(user_id.0)
1085 .fetch_all(&self.pool)
1086 .await?)
1087 }
1088
1089 // orgs
1090
1091 #[allow(unused)] // Help rust-analyzer
1092 #[cfg(any(test, feature = "seed-support"))]
1093 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1094 let query = "
1095 SELECT *
1096 FROM orgs
1097 WHERE slug = $1
1098 ";
1099 Ok(sqlx::query_as(query)
1100 .bind(slug)
1101 .fetch_optional(&self.pool)
1102 .await?)
1103 }
1104
1105 #[cfg(any(test, feature = "seed-support"))]
1106 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1107 let query = "
1108 INSERT INTO orgs (name, slug)
1109 VALUES ($1, $2)
1110 RETURNING id
1111 ";
1112 Ok(sqlx::query_scalar(query)
1113 .bind(name)
1114 .bind(slug)
1115 .fetch_one(&self.pool)
1116 .await
1117 .map(OrgId)?)
1118 }
1119
1120 #[cfg(any(test, feature = "seed-support"))]
1121 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1122 let query = "
1123 INSERT INTO org_memberships (org_id, user_id, admin)
1124 VALUES ($1, $2, $3)
1125 ON CONFLICT DO NOTHING
1126 ";
1127 Ok(sqlx::query(query)
1128 .bind(org_id.0)
1129 .bind(user_id.0)
1130 .bind(is_admin)
1131 .execute(&self.pool)
1132 .await
1133 .map(drop)?)
1134 }
1135
1136 // channels
1137
1138 #[cfg(any(test, feature = "seed-support"))]
1139 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1140 let query = "
1141 INSERT INTO channels (owner_id, owner_is_user, name)
1142 VALUES ($1, false, $2)
1143 RETURNING id
1144 ";
1145 Ok(sqlx::query_scalar(query)
1146 .bind(org_id.0)
1147 .bind(name)
1148 .fetch_one(&self.pool)
1149 .await
1150 .map(ChannelId)?)
1151 }
1152
1153 #[allow(unused)] // Help rust-analyzer
1154 #[cfg(any(test, feature = "seed-support"))]
1155 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1156 let query = "
1157 SELECT *
1158 FROM channels
1159 WHERE
1160 channels.owner_is_user = false AND
1161 channels.owner_id = $1
1162 ";
1163 Ok(sqlx::query_as(query)
1164 .bind(org_id.0)
1165 .fetch_all(&self.pool)
1166 .await?)
1167 }
1168
1169 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1170 let query = "
1171 SELECT
1172 channels.*
1173 FROM
1174 channel_memberships, channels
1175 WHERE
1176 channel_memberships.user_id = $1 AND
1177 channel_memberships.channel_id = channels.id
1178 ";
1179 Ok(sqlx::query_as(query)
1180 .bind(user_id.0)
1181 .fetch_all(&self.pool)
1182 .await?)
1183 }
1184
1185 async fn can_user_access_channel(
1186 &self,
1187 user_id: UserId,
1188 channel_id: ChannelId,
1189 ) -> Result<bool> {
1190 let query = "
1191 SELECT id
1192 FROM channel_memberships
1193 WHERE user_id = $1 AND channel_id = $2
1194 LIMIT 1
1195 ";
1196 Ok(sqlx::query_scalar::<_, i32>(query)
1197 .bind(user_id.0)
1198 .bind(channel_id.0)
1199 .fetch_optional(&self.pool)
1200 .await
1201 .map(|e| e.is_some())?)
1202 }
1203
1204 #[cfg(any(test, feature = "seed-support"))]
1205 async fn add_channel_member(
1206 &self,
1207 channel_id: ChannelId,
1208 user_id: UserId,
1209 is_admin: bool,
1210 ) -> Result<()> {
1211 let query = "
1212 INSERT INTO channel_memberships (channel_id, user_id, admin)
1213 VALUES ($1, $2, $3)
1214 ON CONFLICT DO NOTHING
1215 ";
1216 Ok(sqlx::query(query)
1217 .bind(channel_id.0)
1218 .bind(user_id.0)
1219 .bind(is_admin)
1220 .execute(&self.pool)
1221 .await
1222 .map(drop)?)
1223 }
1224
1225 // messages
1226
1227 async fn create_channel_message(
1228 &self,
1229 channel_id: ChannelId,
1230 sender_id: UserId,
1231 body: &str,
1232 timestamp: OffsetDateTime,
1233 nonce: u128,
1234 ) -> Result<MessageId> {
1235 let query = "
1236 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1237 VALUES ($1, $2, $3, $4, $5)
1238 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1239 RETURNING id
1240 ";
1241 Ok(sqlx::query_scalar(query)
1242 .bind(channel_id.0)
1243 .bind(sender_id.0)
1244 .bind(body)
1245 .bind(timestamp)
1246 .bind(Uuid::from_u128(nonce))
1247 .fetch_one(&self.pool)
1248 .await
1249 .map(MessageId)?)
1250 }
1251
1252 async fn get_channel_messages(
1253 &self,
1254 channel_id: ChannelId,
1255 count: usize,
1256 before_id: Option<MessageId>,
1257 ) -> Result<Vec<ChannelMessage>> {
1258 let query = r#"
1259 SELECT * FROM (
1260 SELECT
1261 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1262 FROM
1263 channel_messages
1264 WHERE
1265 channel_id = $1 AND
1266 id < $2
1267 ORDER BY id DESC
1268 LIMIT $3
1269 ) as recent_messages
1270 ORDER BY id ASC
1271 "#;
1272 Ok(sqlx::query_as(query)
1273 .bind(channel_id.0)
1274 .bind(before_id.unwrap_or(MessageId::MAX))
1275 .bind(count as i64)
1276 .fetch_all(&self.pool)
1277 .await?)
1278 }
1279
    /// Test-only cleanup that destroys the Postgres database at `url`.
    ///
    /// It is invoked from `TestDb`'s `Drop` impl via `futures::executor::block_on`,
    /// so it returns no error. Each fallible step is passed to `log_err` from
    /// `util::ResultExt`, which presumably logs the error and discards it.
    #[cfg(test)]
    async fn teardown(&self, url: &str) {
        use util::ResultExt;

        // Terminate every other backend connected to the current database, because
        // Postgres will not drop a database that still has open sessions. Our own
        // backend (`pg_backend_pid()`) is excluded so this statement can complete.
        let query = "
            SELECT pg_terminate_backend(pg_stat_activity.pid)
            FROM pg_stat_activity
            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
        ";
        sqlx::query(query).execute(&self.pool).await.log_err();
        // Release this pool's own connections before dropping the database.
        self.pool.close().await;
        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
            .await
            .log_err();
    }
1295
1296 #[cfg(test)]
1297 fn as_fake(&self) -> Option<&tests::FakeDb> {
1298 None
1299 }
1300}
1301
// Defines a database-id newtype `$name` that wraps an `i32`.
//
// The generated type is `Copy`, totally ordered and hashable. Its `Default` is id 0.
// Both `#[sqlx(transparent)]` and `#[serde(transparent)]` are applied, so it is
// bound/decoded by sqlx and serialized by serde exactly like the bare `i32`.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id. `get_channel_messages` binds `MessageId::MAX`
            // as an exclusive upper bound when no `before_id` cursor is given.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Converts a protocol-level (presumably RPC message) id into this type.
            // NOTE(review): `as i32` keeps only the low 32 bits, so values above
            // `i32::MAX` wrap silently instead of being rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Converts this id to its protocol-level `u64` representation.
            // NOTE(review): a negative inner value would sign-extend to a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        // Formats the id as its plain integer value.
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1344
id_type!(UserId);
/// A user record as loaded from the database. Columns map to fields by name via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    pub admin: bool,
    /// The user's invite code, or `None` when no code has been assigned yet.
    pub invite_code: Option<String>,
    /// Presumably the number of remaining redemptions of `invite_code`. TODO confirm.
    pub invite_count: i32,
    /// Presumably maintained via `Db::set_user_connected_once`. TODO confirm.
    pub connected_once: bool,
}
1356
id_type!(ProjectId);
/// A project registered by a host user (see `Db::register_project`).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user who registered the project.
    pub host_user_id: UserId,
    /// Presumably set by `Db::unregister_project` instead of deleting the row. TODO confirm.
    pub unregistered: bool,
}
1364
/// Aggregated project activity for one user over a queried time range
/// (returned by `get_top_users_activity_summary`).
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// One entry per project the user was active in during the range.
    pub project_activity: Vec<ProjectActivitySummary>,
}
1371
/// One project's entry within a [`UserActivitySummary`].
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    id: ProjectId,
    /// Presumably the total time the user was recorded active in this project.
    duration: Duration,
    /// Presumably the largest number of users simultaneously active in the project.
    max_collaborators: usize,
}
1378
/// A span of a user's activity in a single project (returned by
/// `get_user_activity_timeline`). Timestamps serialize as ISO 8601 strings.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    project_id: ProjectId,
    #[serde(with = "time::serde::iso8601")]
    start: OffsetDateTime,
    #[serde(with = "time::serde::iso8601")]
    end: OffsetDateTime,
    /// Presumably file-extension counts for the project's worktrees (as written by
    /// `update_worktree_extensions`). TODO confirm.
    extensions: HashMap<String, usize>,
}
1388
id_type!(OrgId);
/// An organization record as loaded from the database.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
1396
id_type!(ChannelId);
/// A chat channel record.
///
/// Ownership depends on `owner_is_user`. When it is `false`, `owner_id` is an
/// organization id (see `create_org_channel`). Presumably `owner_id` is a user id
/// when it is `true`; confirm against the schema.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1405
id_type!(MessageId);
/// A message posted in a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// `get_channel_messages` reads this column through `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce. `create_channel_message` uses it to deduplicate inserts.
    pub nonce: Uuid,
}
1416
1417#[derive(Clone, Debug, PartialEq, Eq)]
1418pub enum Contact {
1419 Accepted {
1420 user_id: UserId,
1421 should_notify: bool,
1422 },
1423 Outgoing {
1424 user_id: UserId,
1425 },
1426 Incoming {
1427 user_id: UserId,
1428 should_notify: bool,
1429 },
1430}
1431
1432impl Contact {
1433 pub fn user_id(&self) -> UserId {
1434 match self {
1435 Contact::Accepted { user_id, .. } => *user_id,
1436 Contact::Outgoing { user_id } => *user_id,
1437 Contact::Incoming { user_id, .. } => *user_id,
1438 }
1439 }
1440}
1441
/// A pending contact request addressed to some user.
///
/// The derived ordering compares `requester_id` first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// The user who sent the request.
    pub requester_id: UserId,
    /// Presumably whether the recipient still needs to be notified. TODO confirm.
    pub should_notify: bool,
}
1447
/// Builds a SQL `LIKE` pattern that matches any text containing the alphanumeric
/// characters of `string` in order, with anything in between.
///
/// Non-alphanumeric characters are dropped. That includes whitespace and the
/// `LIKE` metacharacters `%` and `_`, so the result never needs escaping.
///
/// Examples:
/// - `"x y"` becomes `"%x%y%"`.
/// - An input with no alphanumeric characters becomes `"%"`, which matches everything.
fn fuzzy_like_string(string: &str) -> String {
    // Each kept char adds at most 1 + char-byte-length bytes, plus one trailing
    // '%'. Every char is at least one byte, so 2 * len + 1 always suffices.
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|ch| ch.is_alphanumeric())
            .flat_map(|ch| ['%', ch]),
    );
    pattern.push('%');
    pattern
}
1459
1460fn random_invite_code() -> String {
1461 nanoid::nanoid!(16)
1462}
1463
1464#[cfg(test)]
1465pub mod tests {
1466 use super::*;
1467 use anyhow::anyhow;
1468 use collections::BTreeMap;
1469 use gpui::executor::{Background, Deterministic};
1470 use lazy_static::lazy_static;
1471 use parking_lot::Mutex;
1472 use rand::prelude::*;
1473 use sqlx::{
1474 migrate::{MigrateDatabase, Migrator},
1475 Postgres,
1476 };
1477 use std::{path::Path, sync::Arc};
1478 use util::post_inc;
1479
1480 #[tokio::test(flavor = "multi_thread")]
1481 async fn test_get_users_by_ids() {
1482 for test_db in [
1483 TestDb::postgres().await,
1484 TestDb::fake(build_background_executor()),
1485 ] {
1486 let db = test_db.db();
1487
1488 let user = db.create_user("user", None, false).await.unwrap();
1489 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1490 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1491 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1492
1493 assert_eq!(
1494 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1495 .await
1496 .unwrap(),
1497 vec![
1498 User {
1499 id: user,
1500 github_login: "user".to_string(),
1501 admin: false,
1502 ..Default::default()
1503 },
1504 User {
1505 id: friend1,
1506 github_login: "friend-1".to_string(),
1507 admin: false,
1508 ..Default::default()
1509 },
1510 User {
1511 id: friend2,
1512 github_login: "friend-2".to_string(),
1513 admin: false,
1514 ..Default::default()
1515 },
1516 User {
1517 id: friend3,
1518 github_login: "friend-3".to_string(),
1519 admin: false,
1520 ..Default::default()
1521 }
1522 ]
1523 );
1524 }
1525 }
1526
1527 #[tokio::test(flavor = "multi_thread")]
1528 async fn test_create_users() {
1529 let db = TestDb::postgres().await;
1530 let db = db.db();
1531
1532 // Create the first batch of users, ensuring invite counts are assigned
1533 // correctly and the respective invite codes are unique.
1534 let user_ids_batch_1 = db
1535 .create_users(vec![
1536 ("user1".to_string(), "hi@user1.com".to_string(), 5),
1537 ("user2".to_string(), "hi@user2.com".to_string(), 4),
1538 ("user3".to_string(), "hi@user3.com".to_string(), 3),
1539 ])
1540 .await
1541 .unwrap();
1542 assert_eq!(user_ids_batch_1.len(), 3);
1543
1544 let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1545 assert_eq!(users.len(), 3);
1546 assert_eq!(users[0].github_login, "user1");
1547 assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1548 assert_eq!(users[0].invite_count, 5);
1549 assert_eq!(users[1].github_login, "user2");
1550 assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1551 assert_eq!(users[1].invite_count, 4);
1552 assert_eq!(users[2].github_login, "user3");
1553 assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1554 assert_eq!(users[2].invite_count, 3);
1555
1556 let invite_code_1 = users[0].invite_code.clone().unwrap();
1557 let invite_code_2 = users[1].invite_code.clone().unwrap();
1558 let invite_code_3 = users[2].invite_code.clone().unwrap();
1559 assert_ne!(invite_code_1, invite_code_2);
1560 assert_ne!(invite_code_1, invite_code_3);
1561 assert_ne!(invite_code_2, invite_code_3);
1562
1563 // Create the second batch of users and include a user that is already in the database, ensuring
1564 // the invite count for the existing user is updated without changing their invite code.
1565 let user_ids_batch_2 = db
1566 .create_users(vec![
1567 ("user2".to_string(), "hi@user2.com".to_string(), 10),
1568 ("user4".to_string(), "hi@user4.com".to_string(), 2),
1569 ])
1570 .await
1571 .unwrap();
1572 assert_eq!(user_ids_batch_2.len(), 2);
1573 assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1574
1575 let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1576 assert_eq!(users.len(), 2);
1577 assert_eq!(users[0].github_login, "user2");
1578 assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1579 assert_eq!(users[0].invite_count, 10);
1580 assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1581 assert_eq!(users[1].github_login, "user4");
1582 assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1583 assert_eq!(users[1].invite_count, 2);
1584
1585 let invite_code_4 = users[1].invite_code.clone().unwrap();
1586 assert_ne!(invite_code_4, invite_code_1);
1587 assert_ne!(invite_code_4, invite_code_2);
1588 assert_ne!(invite_code_4, invite_code_3);
1589 }
1590
1591 #[tokio::test(flavor = "multi_thread")]
1592 async fn test_worktree_extensions() {
1593 let test_db = TestDb::postgres().await;
1594 let db = test_db.db();
1595
1596 let user = db.create_user("user_1", None, false).await.unwrap();
1597 let project = db.register_project(user).await.unwrap();
1598
1599 db.update_worktree_extensions(project, 100, Default::default())
1600 .await
1601 .unwrap();
1602 db.update_worktree_extensions(
1603 project,
1604 100,
1605 [("rs".to_string(), 5), ("md".to_string(), 3)]
1606 .into_iter()
1607 .collect(),
1608 )
1609 .await
1610 .unwrap();
1611 db.update_worktree_extensions(
1612 project,
1613 100,
1614 [("rs".to_string(), 6), ("md".to_string(), 5)]
1615 .into_iter()
1616 .collect(),
1617 )
1618 .await
1619 .unwrap();
1620 db.update_worktree_extensions(
1621 project,
1622 101,
1623 [("ts".to_string(), 2), ("md".to_string(), 1)]
1624 .into_iter()
1625 .collect(),
1626 )
1627 .await
1628 .unwrap();
1629
1630 assert_eq!(
1631 db.get_project_extensions(project).await.unwrap(),
1632 [
1633 (
1634 100,
1635 [("rs".into(), 6), ("md".into(), 5),]
1636 .into_iter()
1637 .collect::<HashMap<_, _>>()
1638 ),
1639 (
1640 101,
1641 [("ts".into(), 2), ("md".into(), 1),]
1642 .into_iter()
1643 .collect::<HashMap<_, _>>()
1644 )
1645 ]
1646 .into_iter()
1647 .collect()
1648 );
1649 }
1650
1651 #[tokio::test(flavor = "multi_thread")]
1652 async fn test_user_activity() {
1653 let test_db = TestDb::postgres().await;
1654 let db = test_db.db();
1655
1656 let user_1 = db.create_user("user_1", None, false).await.unwrap();
1657 let user_2 = db.create_user("user_2", None, false).await.unwrap();
1658 let user_3 = db.create_user("user_3", None, false).await.unwrap();
1659 let project_1 = db.register_project(user_1).await.unwrap();
1660 db.update_worktree_extensions(
1661 project_1,
1662 1,
1663 HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
1664 )
1665 .await
1666 .unwrap();
1667 let project_2 = db.register_project(user_2).await.unwrap();
1668 let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
1669
1670 // User 2 opens a project
1671 let t1 = t0 + Duration::from_secs(10);
1672 db.record_user_activity(t0..t1, &[(user_2, project_2)])
1673 .await
1674 .unwrap();
1675
1676 let t2 = t1 + Duration::from_secs(10);
1677 db.record_user_activity(t1..t2, &[(user_2, project_2)])
1678 .await
1679 .unwrap();
1680
1681 // User 1 joins the project
1682 let t3 = t2 + Duration::from_secs(10);
1683 db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
1684 .await
1685 .unwrap();
1686
1687 // User 1 opens another project
1688 let t4 = t3 + Duration::from_secs(10);
1689 db.record_user_activity(
1690 t3..t4,
1691 &[
1692 (user_2, project_2),
1693 (user_1, project_2),
1694 (user_1, project_1),
1695 ],
1696 )
1697 .await
1698 .unwrap();
1699
1700 // User 3 joins that project
1701 let t5 = t4 + Duration::from_secs(10);
1702 db.record_user_activity(
1703 t4..t5,
1704 &[
1705 (user_2, project_2),
1706 (user_1, project_2),
1707 (user_1, project_1),
1708 (user_3, project_1),
1709 ],
1710 )
1711 .await
1712 .unwrap();
1713
1714 // User 2 leaves
1715 let t6 = t5 + Duration::from_secs(5);
1716 db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
1717 .await
1718 .unwrap();
1719
1720 let t7 = t6 + Duration::from_secs(60);
1721 let t8 = t7 + Duration::from_secs(10);
1722 db.record_user_activity(t7..t8, &[(user_1, project_1)])
1723 .await
1724 .unwrap();
1725
1726 assert_eq!(
1727 db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
1728 &[
1729 UserActivitySummary {
1730 id: user_1,
1731 github_login: "user_1".to_string(),
1732 project_activity: vec![
1733 ProjectActivitySummary {
1734 id: project_1,
1735 duration: Duration::from_secs(25),
1736 max_collaborators: 2
1737 },
1738 ProjectActivitySummary {
1739 id: project_2,
1740 duration: Duration::from_secs(30),
1741 max_collaborators: 2
1742 }
1743 ]
1744 },
1745 UserActivitySummary {
1746 id: user_2,
1747 github_login: "user_2".to_string(),
1748 project_activity: vec![ProjectActivitySummary {
1749 id: project_2,
1750 duration: Duration::from_secs(50),
1751 max_collaborators: 2
1752 }]
1753 },
1754 UserActivitySummary {
1755 id: user_3,
1756 github_login: "user_3".to_string(),
1757 project_activity: vec![ProjectActivitySummary {
1758 id: project_1,
1759 duration: Duration::from_secs(15),
1760 max_collaborators: 2
1761 }]
1762 },
1763 ]
1764 );
1765
1766 assert_eq!(
1767 db.get_active_user_count(t0..t6, Duration::from_secs(56), false)
1768 .await
1769 .unwrap(),
1770 0
1771 );
1772 assert_eq!(
1773 db.get_active_user_count(t0..t6, Duration::from_secs(56), true)
1774 .await
1775 .unwrap(),
1776 0
1777 );
1778 assert_eq!(
1779 db.get_active_user_count(t0..t6, Duration::from_secs(54), false)
1780 .await
1781 .unwrap(),
1782 1
1783 );
1784 assert_eq!(
1785 db.get_active_user_count(t0..t6, Duration::from_secs(54), true)
1786 .await
1787 .unwrap(),
1788 1
1789 );
1790 assert_eq!(
1791 db.get_active_user_count(t0..t6, Duration::from_secs(30), false)
1792 .await
1793 .unwrap(),
1794 2
1795 );
1796 assert_eq!(
1797 db.get_active_user_count(t0..t6, Duration::from_secs(30), true)
1798 .await
1799 .unwrap(),
1800 2
1801 );
1802 assert_eq!(
1803 db.get_active_user_count(t0..t6, Duration::from_secs(10), false)
1804 .await
1805 .unwrap(),
1806 3
1807 );
1808 assert_eq!(
1809 db.get_active_user_count(t0..t6, Duration::from_secs(10), true)
1810 .await
1811 .unwrap(),
1812 3
1813 );
1814 assert_eq!(
1815 db.get_active_user_count(t0..t1, Duration::from_secs(5), false)
1816 .await
1817 .unwrap(),
1818 1
1819 );
1820 assert_eq!(
1821 db.get_active_user_count(t0..t1, Duration::from_secs(5), true)
1822 .await
1823 .unwrap(),
1824 0
1825 );
1826
1827 assert_eq!(
1828 db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
1829 &[
1830 UserActivityPeriod {
1831 project_id: project_1,
1832 start: t3,
1833 end: t6,
1834 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1835 },
1836 UserActivityPeriod {
1837 project_id: project_2,
1838 start: t3,
1839 end: t5,
1840 extensions: Default::default(),
1841 },
1842 ]
1843 );
1844 assert_eq!(
1845 db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
1846 &[
1847 UserActivityPeriod {
1848 project_id: project_2,
1849 start: t2,
1850 end: t5,
1851 extensions: Default::default(),
1852 },
1853 UserActivityPeriod {
1854 project_id: project_1,
1855 start: t3,
1856 end: t6,
1857 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1858 },
1859 UserActivityPeriod {
1860 project_id: project_1,
1861 start: t7,
1862 end: t8,
1863 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1864 },
1865 ]
1866 );
1867 }
1868
1869 #[tokio::test(flavor = "multi_thread")]
1870 async fn test_recent_channel_messages() {
1871 for test_db in [
1872 TestDb::postgres().await,
1873 TestDb::fake(build_background_executor()),
1874 ] {
1875 let db = test_db.db();
1876 let user = db.create_user("user", None, false).await.unwrap();
1877 let org = db.create_org("org", "org").await.unwrap();
1878 let channel = db.create_org_channel(org, "channel").await.unwrap();
1879 for i in 0..10 {
1880 db.create_channel_message(
1881 channel,
1882 user,
1883 &i.to_string(),
1884 OffsetDateTime::now_utc(),
1885 i,
1886 )
1887 .await
1888 .unwrap();
1889 }
1890
1891 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1892 assert_eq!(
1893 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1894 ["5", "6", "7", "8", "9"]
1895 );
1896
1897 let prev_messages = db
1898 .get_channel_messages(channel, 4, Some(messages[0].id))
1899 .await
1900 .unwrap();
1901 assert_eq!(
1902 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1903 ["1", "2", "3", "4"]
1904 );
1905 }
1906 }
1907
1908 #[tokio::test(flavor = "multi_thread")]
1909 async fn test_channel_message_nonces() {
1910 for test_db in [
1911 TestDb::postgres().await,
1912 TestDb::fake(build_background_executor()),
1913 ] {
1914 let db = test_db.db();
1915 let user = db.create_user("user", None, false).await.unwrap();
1916 let org = db.create_org("org", "org").await.unwrap();
1917 let channel = db.create_org_channel(org, "channel").await.unwrap();
1918
1919 let msg1_id = db
1920 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1921 .await
1922 .unwrap();
1923 let msg2_id = db
1924 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1925 .await
1926 .unwrap();
1927 let msg3_id = db
1928 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1929 .await
1930 .unwrap();
1931 let msg4_id = db
1932 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1933 .await
1934 .unwrap();
1935
1936 assert_ne!(msg1_id, msg2_id);
1937 assert_eq!(msg1_id, msg3_id);
1938 assert_eq!(msg2_id, msg4_id);
1939 }
1940 }
1941
1942 #[tokio::test(flavor = "multi_thread")]
1943 async fn test_create_access_tokens() {
1944 let test_db = TestDb::postgres().await;
1945 let db = test_db.db();
1946 let user = db.create_user("the-user", None, false).await.unwrap();
1947
1948 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1949 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1950 assert_eq!(
1951 db.get_access_token_hashes(user).await.unwrap(),
1952 &["h2".to_string(), "h1".to_string()]
1953 );
1954
1955 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1956 assert_eq!(
1957 db.get_access_token_hashes(user).await.unwrap(),
1958 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1959 );
1960
1961 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1962 assert_eq!(
1963 db.get_access_token_hashes(user).await.unwrap(),
1964 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1965 );
1966
1967 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1968 assert_eq!(
1969 db.get_access_token_hashes(user).await.unwrap(),
1970 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1971 );
1972 }
1973
1974 #[test]
1975 fn test_fuzzy_like_string() {
1976 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1977 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1978 assert_eq!(fuzzy_like_string(" z "), "%z%");
1979 }
1980
1981 #[tokio::test(flavor = "multi_thread")]
1982 async fn test_fuzzy_search_users() {
1983 let test_db = TestDb::postgres().await;
1984 let db = test_db.db();
1985 for github_login in [
1986 "California",
1987 "colorado",
1988 "oregon",
1989 "washington",
1990 "florida",
1991 "delaware",
1992 "rhode-island",
1993 ] {
1994 db.create_user(github_login, None, false).await.unwrap();
1995 }
1996
1997 assert_eq!(
1998 fuzzy_search_user_names(db, "clr").await,
1999 &["colorado", "California"]
2000 );
2001 assert_eq!(
2002 fuzzy_search_user_names(db, "ro").await,
2003 &["rhode-island", "colorado", "oregon"],
2004 );
2005
2006 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
2007 db.fuzzy_search_users(query, 10)
2008 .await
2009 .unwrap()
2010 .into_iter()
2011 .map(|user| user.github_login)
2012 .collect::<Vec<_>>()
2013 }
2014 }
2015
2016 #[tokio::test(flavor = "multi_thread")]
2017 async fn test_add_contacts() {
2018 for test_db in [
2019 TestDb::postgres().await,
2020 TestDb::fake(build_background_executor()),
2021 ] {
2022 let db = test_db.db();
2023
2024 let user_1 = db.create_user("user1", None, false).await.unwrap();
2025 let user_2 = db.create_user("user2", None, false).await.unwrap();
2026 let user_3 = db.create_user("user3", None, false).await.unwrap();
2027
2028 // User starts with no contacts
2029 assert_eq!(
2030 db.get_contacts(user_1).await.unwrap(),
2031 vec![Contact::Accepted {
2032 user_id: user_1,
2033 should_notify: false
2034 }],
2035 );
2036
2037 // User requests a contact. Both users see the pending request.
2038 db.send_contact_request(user_1, user_2).await.unwrap();
2039 assert!(!db.has_contact(user_1, user_2).await.unwrap());
2040 assert!(!db.has_contact(user_2, user_1).await.unwrap());
2041 assert_eq!(
2042 db.get_contacts(user_1).await.unwrap(),
2043 &[
2044 Contact::Accepted {
2045 user_id: user_1,
2046 should_notify: false
2047 },
2048 Contact::Outgoing { user_id: user_2 }
2049 ],
2050 );
2051 assert_eq!(
2052 db.get_contacts(user_2).await.unwrap(),
2053 &[
2054 Contact::Incoming {
2055 user_id: user_1,
2056 should_notify: true
2057 },
2058 Contact::Accepted {
2059 user_id: user_2,
2060 should_notify: false
2061 },
2062 ]
2063 );
2064
2065 // User 2 dismisses the contact request notification without accepting or rejecting.
2066 // We shouldn't notify them again.
2067 db.dismiss_contact_notification(user_1, user_2)
2068 .await
2069 .unwrap_err();
2070 db.dismiss_contact_notification(user_2, user_1)
2071 .await
2072 .unwrap();
2073 assert_eq!(
2074 db.get_contacts(user_2).await.unwrap(),
2075 &[
2076 Contact::Incoming {
2077 user_id: user_1,
2078 should_notify: false
2079 },
2080 Contact::Accepted {
2081 user_id: user_2,
2082 should_notify: false
2083 },
2084 ]
2085 );
2086
2087 // User can't accept their own contact request
2088 db.respond_to_contact_request(user_1, user_2, true)
2089 .await
2090 .unwrap_err();
2091
2092 // User accepts a contact request. Both users see the contact.
2093 db.respond_to_contact_request(user_2, user_1, true)
2094 .await
2095 .unwrap();
2096 assert_eq!(
2097 db.get_contacts(user_1).await.unwrap(),
2098 &[
2099 Contact::Accepted {
2100 user_id: user_1,
2101 should_notify: false
2102 },
2103 Contact::Accepted {
2104 user_id: user_2,
2105 should_notify: true
2106 }
2107 ],
2108 );
2109 assert!(db.has_contact(user_1, user_2).await.unwrap());
2110 assert!(db.has_contact(user_2, user_1).await.unwrap());
2111 assert_eq!(
2112 db.get_contacts(user_2).await.unwrap(),
2113 &[
2114 Contact::Accepted {
2115 user_id: user_1,
2116 should_notify: false,
2117 },
2118 Contact::Accepted {
2119 user_id: user_2,
2120 should_notify: false,
2121 },
2122 ]
2123 );
2124
2125 // Users cannot re-request existing contacts.
2126 db.send_contact_request(user_1, user_2).await.unwrap_err();
2127 db.send_contact_request(user_2, user_1).await.unwrap_err();
2128
2129 // Users can't dismiss notifications of them accepting other users' requests.
2130 db.dismiss_contact_notification(user_2, user_1)
2131 .await
2132 .unwrap_err();
2133 assert_eq!(
2134 db.get_contacts(user_1).await.unwrap(),
2135 &[
2136 Contact::Accepted {
2137 user_id: user_1,
2138 should_notify: false
2139 },
2140 Contact::Accepted {
2141 user_id: user_2,
2142 should_notify: true,
2143 },
2144 ]
2145 );
2146
2147 // Users can dismiss notifications of other users accepting their requests.
2148 db.dismiss_contact_notification(user_1, user_2)
2149 .await
2150 .unwrap();
2151 assert_eq!(
2152 db.get_contacts(user_1).await.unwrap(),
2153 &[
2154 Contact::Accepted {
2155 user_id: user_1,
2156 should_notify: false
2157 },
2158 Contact::Accepted {
2159 user_id: user_2,
2160 should_notify: false,
2161 },
2162 ]
2163 );
2164
2165 // Users send each other concurrent contact requests and
2166 // see that they are immediately accepted.
2167 db.send_contact_request(user_1, user_3).await.unwrap();
2168 db.send_contact_request(user_3, user_1).await.unwrap();
2169 assert_eq!(
2170 db.get_contacts(user_1).await.unwrap(),
2171 &[
2172 Contact::Accepted {
2173 user_id: user_1,
2174 should_notify: false
2175 },
2176 Contact::Accepted {
2177 user_id: user_2,
2178 should_notify: false,
2179 },
2180 Contact::Accepted {
2181 user_id: user_3,
2182 should_notify: false
2183 },
2184 ]
2185 );
2186 assert_eq!(
2187 db.get_contacts(user_3).await.unwrap(),
2188 &[
2189 Contact::Accepted {
2190 user_id: user_1,
2191 should_notify: false
2192 },
2193 Contact::Accepted {
2194 user_id: user_3,
2195 should_notify: false
2196 }
2197 ],
2198 );
2199
2200 // User declines a contact request. Both users see that it is gone.
2201 db.send_contact_request(user_2, user_3).await.unwrap();
2202 db.respond_to_contact_request(user_3, user_2, false)
2203 .await
2204 .unwrap();
2205 assert!(!db.has_contact(user_2, user_3).await.unwrap());
2206 assert!(!db.has_contact(user_3, user_2).await.unwrap());
2207 assert_eq!(
2208 db.get_contacts(user_2).await.unwrap(),
2209 &[
2210 Contact::Accepted {
2211 user_id: user_1,
2212 should_notify: false
2213 },
2214 Contact::Accepted {
2215 user_id: user_2,
2216 should_notify: false
2217 }
2218 ]
2219 );
2220 assert_eq!(
2221 db.get_contacts(user_3).await.unwrap(),
2222 &[
2223 Contact::Accepted {
2224 user_id: user_1,
2225 should_notify: false
2226 },
2227 Contact::Accepted {
2228 user_id: user_3,
2229 should_notify: false
2230 }
2231 ],
2232 );
2233 }
2234 }
2235
2236 #[tokio::test(flavor = "multi_thread")]
2237 async fn test_invite_codes() {
2238 let postgres = TestDb::postgres().await;
2239 let db = postgres.db();
2240 let user1 = db.create_user("user-1", None, false).await.unwrap();
2241
2242 // Initially, user 1 has no invite code
2243 assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);
2244
2245 // Setting invite count to 0 when no code is assigned does not assign a new code
2246 db.set_invite_count(user1, 0).await.unwrap();
2247 assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());
2248
2249 // User 1 creates an invite code that can be used twice.
2250 db.set_invite_count(user1, 2).await.unwrap();
2251 let (invite_code, invite_count) =
2252 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2253 assert_eq!(invite_count, 2);
2254
2255 // User 2 redeems the invite code and becomes a contact of user 1.
2256 let user2 = db
2257 .redeem_invite_code(&invite_code, "user-2", None)
2258 .await
2259 .unwrap();
2260 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2261 assert_eq!(invite_count, 1);
2262 assert_eq!(
2263 db.get_contacts(user1).await.unwrap(),
2264 [
2265 Contact::Accepted {
2266 user_id: user1,
2267 should_notify: false
2268 },
2269 Contact::Accepted {
2270 user_id: user2,
2271 should_notify: true
2272 }
2273 ]
2274 );
2275 assert_eq!(
2276 db.get_contacts(user2).await.unwrap(),
2277 [
2278 Contact::Accepted {
2279 user_id: user1,
2280 should_notify: false
2281 },
2282 Contact::Accepted {
2283 user_id: user2,
2284 should_notify: false
2285 }
2286 ]
2287 );
2288
2289 // User 3 redeems the invite code and becomes a contact of user 1.
2290 let user3 = db
2291 .redeem_invite_code(&invite_code, "user-3", None)
2292 .await
2293 .unwrap();
2294 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2295 assert_eq!(invite_count, 0);
2296 assert_eq!(
2297 db.get_contacts(user1).await.unwrap(),
2298 [
2299 Contact::Accepted {
2300 user_id: user1,
2301 should_notify: false
2302 },
2303 Contact::Accepted {
2304 user_id: user2,
2305 should_notify: true
2306 },
2307 Contact::Accepted {
2308 user_id: user3,
2309 should_notify: true
2310 }
2311 ]
2312 );
2313 assert_eq!(
2314 db.get_contacts(user3).await.unwrap(),
2315 [
2316 Contact::Accepted {
2317 user_id: user1,
2318 should_notify: false
2319 },
2320 Contact::Accepted {
2321 user_id: user3,
2322 should_notify: false
2323 },
2324 ]
2325 );
2326
2327 // Trying to reedem the code for the third time results in an error.
2328 db.redeem_invite_code(&invite_code, "user-4", None)
2329 .await
2330 .unwrap_err();
2331
2332 // Invite count can be updated after the code has been created.
2333 db.set_invite_count(user1, 2).await.unwrap();
2334 let (latest_code, invite_count) =
2335 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2336 assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
2337 assert_eq!(invite_count, 2);
2338
2339 // User 4 can now redeem the invite code and becomes a contact of user 1.
2340 let user4 = db
2341 .redeem_invite_code(&invite_code, "user-4", None)
2342 .await
2343 .unwrap();
2344 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2345 assert_eq!(invite_count, 1);
2346 assert_eq!(
2347 db.get_contacts(user1).await.unwrap(),
2348 [
2349 Contact::Accepted {
2350 user_id: user1,
2351 should_notify: false
2352 },
2353 Contact::Accepted {
2354 user_id: user2,
2355 should_notify: true
2356 },
2357 Contact::Accepted {
2358 user_id: user3,
2359 should_notify: true
2360 },
2361 Contact::Accepted {
2362 user_id: user4,
2363 should_notify: true
2364 }
2365 ]
2366 );
2367 assert_eq!(
2368 db.get_contacts(user4).await.unwrap(),
2369 [
2370 Contact::Accepted {
2371 user_id: user1,
2372 should_notify: false
2373 },
2374 Contact::Accepted {
2375 user_id: user4,
2376 should_notify: false
2377 },
2378 ]
2379 );
2380
2381 // An existing user cannot redeem invite codes.
2382 db.redeem_invite_code(&invite_code, "user-2", None)
2383 .await
2384 .unwrap_err();
2385 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2386 assert_eq!(invite_count, 1);
2387
2388 // Ensure invited users get invite codes too.
2389 assert_eq!(
2390 db.get_invite_code_for_user(user2).await.unwrap().unwrap().1,
2391 5
2392 );
2393 assert_eq!(
2394 db.get_invite_code_for_user(user3).await.unwrap().unwrap().1,
2395 5
2396 );
2397 assert_eq!(
2398 db.get_invite_code_for_user(user4).await.unwrap().unwrap().1,
2399 5
2400 );
2401 }
2402
2403 pub struct TestDb {
2404 pub db: Option<Arc<dyn Db>>,
2405 pub url: String,
2406 }
2407
2408 impl TestDb {
2409 #[allow(clippy::await_holding_lock)]
2410 pub async fn postgres() -> Self {
2411 lazy_static! {
2412 static ref LOCK: Mutex<()> = Mutex::new(());
2413 }
2414
2415 let _guard = LOCK.lock();
2416 let mut rng = StdRng::from_entropy();
2417 let name = format!("zed-test-{}", rng.gen::<u128>());
2418 let url = format!("postgres://postgres@localhost/{}", name);
2419 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2420 Postgres::create_database(&url)
2421 .await
2422 .expect("failed to create test db");
2423 let db = PostgresDb::new(&url, 5).await.unwrap();
2424 let migrator = Migrator::new(migrations_path).await.unwrap();
2425 migrator.run(&db.pool).await.unwrap();
2426 Self {
2427 db: Some(Arc::new(db)),
2428 url,
2429 }
2430 }
2431
2432 pub fn fake(background: Arc<Background>) -> Self {
2433 Self {
2434 db: Some(Arc::new(FakeDb::new(background))),
2435 url: Default::default(),
2436 }
2437 }
2438
2439 pub fn db(&self) -> &Arc<dyn Db> {
2440 self.db.as_ref().unwrap()
2441 }
2442 }
2443
2444 impl Drop for TestDb {
2445 fn drop(&mut self) {
2446 if let Some(db) = self.db.take() {
2447 futures::executor::block_on(db.teardown(&self.url));
2448 }
2449 }
2450 }
2451
    /// In-memory stand-in for the real database.
    ///
    /// Each table lives behind its own mutex (a `BTreeMap`, or a `Vec` for contacts),
    /// and new ids are taken from the `next_*` counters.
    pub struct FakeDb {
        // Source of the random delays every simulated query awaits first.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        // Keyed by (project, worktree id, extension). The value is the count stored by
        // `update_worktree_extensions`.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // The value is the `is_admin` flag passed to `add_org_member`.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        // The value is the `is_admin` flag passed to `add_channel_member`.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        // Contact requests, both pending and accepted.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Counters feeding the ids assigned to newly created rows.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2469
    /// A contact request between two users, as stored by `FakeDb`.
    #[derive(Debug)]
    pub struct FakeContact {
        /// User who sent the request.
        pub requester_id: UserId,
        /// User the request was sent to.
        pub responder_id: UserId,
        /// `true` once the request has been accepted; `false` while it is pending.
        pub accepted: bool,
        /// Whether there is an undismissed notification. While the request is pending,
        /// the responder sees it; once accepted, the requester sees it.
        pub should_notify: bool,
    }
2477
2478 impl FakeDb {
2479 pub fn new(background: Arc<Background>) -> Self {
2480 Self {
2481 background,
2482 users: Default::default(),
2483 next_user_id: Mutex::new(0),
2484 projects: Default::default(),
2485 worktree_extensions: Default::default(),
2486 next_project_id: Mutex::new(1),
2487 orgs: Default::default(),
2488 next_org_id: Mutex::new(1),
2489 org_memberships: Default::default(),
2490 channels: Default::default(),
2491 next_channel_id: Mutex::new(1),
2492 channel_memberships: Default::default(),
2493 channel_messages: Default::default(),
2494 next_channel_message_id: Mutex::new(1),
2495 contacts: Default::default(),
2496 }
2497 }
2498 }
2499
2500 #[async_trait]
2501 impl Db for FakeDb {
2502 async fn create_user(
2503 &self,
2504 github_login: &str,
2505 email_address: Option<&str>,
2506 admin: bool,
2507 ) -> Result<UserId> {
2508 self.background.simulate_random_delay().await;
2509
2510 let mut users = self.users.lock();
2511 if let Some(user) = users
2512 .values()
2513 .find(|user| user.github_login == github_login)
2514 {
2515 Ok(user.id)
2516 } else {
2517 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2518 users.insert(
2519 user_id,
2520 User {
2521 id: user_id,
2522 github_login: github_login.to_string(),
2523 email_address: email_address.map(str::to_string),
2524 admin,
2525 invite_code: None,
2526 invite_count: 0,
2527 connected_once: false,
2528 },
2529 );
2530 Ok(user_id)
2531 }
2532 }
2533
        // The fake does not support the three queries below; each panics if called.

        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }

        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
            unimplemented!()
        }

        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2545
2546 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2547 self.background.simulate_random_delay().await;
2548 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2549 }
2550
2551 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2552 self.background.simulate_random_delay().await;
2553 let users = self.users.lock();
2554 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2555 }
2556
        // Not supported by the fake; panics if called.
        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }
2560
2561 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2562 self.background.simulate_random_delay().await;
2563 Ok(self
2564 .users
2565 .lock()
2566 .values()
2567 .find(|user| user.github_login == github_login)
2568 .cloned())
2569 }
2570
        // Not supported by the fake; panics if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
2574
2575 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2576 self.background.simulate_random_delay().await;
2577 let mut users = self.users.lock();
2578 let mut user = users
2579 .get_mut(&id)
2580 .ok_or_else(|| anyhow!("user not found"))?;
2581 user.connected_once = connected_once;
2582 Ok(())
2583 }
2584
        // Not supported by the fake; panics if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        // invite codes

        // Not supported by the fake; panics if called.
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
2594
2595 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
2596 self.background.simulate_random_delay().await;
2597 Ok(None)
2598 }
2599
        // Not supported by the fake; panics if called.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }

        // Not supported by the fake; panics if called.
        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
2612
2613 // projects
2614
2615 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2616 self.background.simulate_random_delay().await;
2617 if !self.users.lock().contains_key(&host_user_id) {
2618 Err(anyhow!("no such user"))?;
2619 }
2620
2621 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2622 self.projects.lock().insert(
2623 project_id,
2624 Project {
2625 id: project_id,
2626 host_user_id,
2627 unregistered: false,
2628 },
2629 );
2630 Ok(project_id)
2631 }
2632
2633 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2634 self.background.simulate_random_delay().await;
2635 self.projects
2636 .lock()
2637 .get_mut(&project_id)
2638 .ok_or_else(|| anyhow!("no such project"))?
2639 .unregistered = true;
2640 Ok(())
2641 }
2642
2643 async fn update_worktree_extensions(
2644 &self,
2645 project_id: ProjectId,
2646 worktree_id: u64,
2647 extensions: HashMap<String, u32>,
2648 ) -> Result<()> {
2649 self.background.simulate_random_delay().await;
2650 if !self.projects.lock().contains_key(&project_id) {
2651 Err(anyhow!("no such project"))?;
2652 }
2653
2654 for (extension, count) in extensions {
2655 self.worktree_extensions
2656 .lock()
2657 .insert((project_id, worktree_id, extension), count);
2658 }
2659
2660 Ok(())
2661 }
2662
        // The fake does not support the extension and activity queries below; each
        // panics if called.

        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }

        async fn record_user_activity(
            &self,
            _time_period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }

        async fn get_active_user_count(
            &self,
            _time_period: Range<OffsetDateTime>,
            _min_duration: Duration,
            _only_collaborative: bool,
        ) -> Result<usize> {
            unimplemented!()
        }

        async fn get_top_users_activity_summary(
            &self,
            _time_period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }

        async fn get_user_activity_timeline(
            &self,
            _time_period: Range<OffsetDateTime>,
            _user_id: UserId,
        ) -> Result<Vec<UserActivityPeriod>> {
            unimplemented!()
        }
2702
2703 // contacts
2704
2705 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2706 self.background.simulate_random_delay().await;
2707 let mut contacts = vec![Contact::Accepted {
2708 user_id: id,
2709 should_notify: false,
2710 }];
2711
2712 for contact in self.contacts.lock().iter() {
2713 if contact.requester_id == id {
2714 if contact.accepted {
2715 contacts.push(Contact::Accepted {
2716 user_id: contact.responder_id,
2717 should_notify: contact.should_notify,
2718 });
2719 } else {
2720 contacts.push(Contact::Outgoing {
2721 user_id: contact.responder_id,
2722 });
2723 }
2724 } else if contact.responder_id == id {
2725 if contact.accepted {
2726 contacts.push(Contact::Accepted {
2727 user_id: contact.requester_id,
2728 should_notify: false,
2729 });
2730 } else {
2731 contacts.push(Contact::Incoming {
2732 user_id: contact.requester_id,
2733 should_notify: contact.should_notify,
2734 });
2735 }
2736 }
2737 }
2738
2739 contacts.sort_unstable_by_key(|contact| contact.user_id());
2740 Ok(contacts)
2741 }
2742
2743 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2744 self.background.simulate_random_delay().await;
2745 Ok(self.contacts.lock().iter().any(|contact| {
2746 contact.accepted
2747 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2748 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2749 }))
2750 }
2751
2752 async fn send_contact_request(
2753 &self,
2754 requester_id: UserId,
2755 responder_id: UserId,
2756 ) -> Result<()> {
2757 self.background.simulate_random_delay().await;
2758 let mut contacts = self.contacts.lock();
2759 for contact in contacts.iter_mut() {
2760 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2761 if contact.accepted {
2762 Err(anyhow!("contact already exists"))?;
2763 } else {
2764 Err(anyhow!("contact already requested"))?;
2765 }
2766 }
2767 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2768 if contact.accepted {
2769 Err(anyhow!("contact already exists"))?;
2770 } else {
2771 contact.accepted = true;
2772 contact.should_notify = false;
2773 return Ok(());
2774 }
2775 }
2776 }
2777 contacts.push(FakeContact {
2778 requester_id,
2779 responder_id,
2780 accepted: false,
2781 should_notify: true,
2782 });
2783 Ok(())
2784 }
2785
2786 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2787 self.background.simulate_random_delay().await;
2788 self.contacts.lock().retain(|contact| {
2789 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2790 });
2791 Ok(())
2792 }
2793
2794 async fn dismiss_contact_notification(
2795 &self,
2796 user_id: UserId,
2797 contact_user_id: UserId,
2798 ) -> Result<()> {
2799 self.background.simulate_random_delay().await;
2800 let mut contacts = self.contacts.lock();
2801 for contact in contacts.iter_mut() {
2802 if contact.requester_id == contact_user_id
2803 && contact.responder_id == user_id
2804 && !contact.accepted
2805 {
2806 contact.should_notify = false;
2807 return Ok(());
2808 }
2809 if contact.requester_id == user_id
2810 && contact.responder_id == contact_user_id
2811 && contact.accepted
2812 {
2813 contact.should_notify = false;
2814 return Ok(());
2815 }
2816 }
2817 Err(anyhow!("no such notification"))?
2818 }
2819
2820 async fn respond_to_contact_request(
2821 &self,
2822 responder_id: UserId,
2823 requester_id: UserId,
2824 accept: bool,
2825 ) -> Result<()> {
2826 self.background.simulate_random_delay().await;
2827 let mut contacts = self.contacts.lock();
2828 for (ix, contact) in contacts.iter_mut().enumerate() {
2829 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2830 if contact.accepted {
2831 Err(anyhow!("contact already confirmed"))?;
2832 }
2833 if accept {
2834 contact.accepted = true;
2835 contact.should_notify = true;
2836 } else {
2837 contacts.remove(ix);
2838 }
2839 return Ok(());
2840 }
2841 }
2842 Err(anyhow!("no such contact request"))?
2843 }
2844
        // The fake does not support access tokens or slug lookup; each of the queries
        // below panics if called.

        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }

        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
2861
2862 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2863 self.background.simulate_random_delay().await;
2864 let mut orgs = self.orgs.lock();
2865 if orgs.values().any(|org| org.slug == slug) {
2866 Err(anyhow!("org already exists"))?
2867 } else {
2868 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2869 orgs.insert(
2870 org_id,
2871 Org {
2872 id: org_id,
2873 name: name.to_string(),
2874 slug: slug.to_string(),
2875 },
2876 );
2877 Ok(org_id)
2878 }
2879 }
2880
2881 async fn add_org_member(
2882 &self,
2883 org_id: OrgId,
2884 user_id: UserId,
2885 is_admin: bool,
2886 ) -> Result<()> {
2887 self.background.simulate_random_delay().await;
2888 if !self.orgs.lock().contains_key(&org_id) {
2889 Err(anyhow!("org does not exist"))?;
2890 }
2891 if !self.users.lock().contains_key(&user_id) {
2892 Err(anyhow!("user does not exist"))?;
2893 }
2894
2895 self.org_memberships
2896 .lock()
2897 .entry((org_id, user_id))
2898 .or_insert(is_admin);
2899 Ok(())
2900 }
2901
2902 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2903 self.background.simulate_random_delay().await;
2904 if !self.orgs.lock().contains_key(&org_id) {
2905 Err(anyhow!("org does not exist"))?;
2906 }
2907
2908 let mut channels = self.channels.lock();
2909 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2910 channels.insert(
2911 channel_id,
2912 Channel {
2913 id: channel_id,
2914 name: name.to_string(),
2915 owner_id: org_id.0,
2916 owner_is_user: false,
2917 },
2918 );
2919 Ok(channel_id)
2920 }
2921
2922 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2923 self.background.simulate_random_delay().await;
2924 Ok(self
2925 .channels
2926 .lock()
2927 .values()
2928 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2929 .cloned()
2930 .collect())
2931 }
2932
2933 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2934 self.background.simulate_random_delay().await;
2935 let channels = self.channels.lock();
2936 let memberships = self.channel_memberships.lock();
2937 Ok(channels
2938 .values()
2939 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2940 .cloned()
2941 .collect())
2942 }
2943
2944 async fn can_user_access_channel(
2945 &self,
2946 user_id: UserId,
2947 channel_id: ChannelId,
2948 ) -> Result<bool> {
2949 self.background.simulate_random_delay().await;
2950 Ok(self
2951 .channel_memberships
2952 .lock()
2953 .contains_key(&(channel_id, user_id)))
2954 }
2955
2956 async fn add_channel_member(
2957 &self,
2958 channel_id: ChannelId,
2959 user_id: UserId,
2960 is_admin: bool,
2961 ) -> Result<()> {
2962 self.background.simulate_random_delay().await;
2963 if !self.channels.lock().contains_key(&channel_id) {
2964 Err(anyhow!("channel does not exist"))?;
2965 }
2966 if !self.users.lock().contains_key(&user_id) {
2967 Err(anyhow!("user does not exist"))?;
2968 }
2969
2970 self.channel_memberships
2971 .lock()
2972 .entry((channel_id, user_id))
2973 .or_insert(is_admin);
2974 Ok(())
2975 }
2976
2977 async fn create_channel_message(
2978 &self,
2979 channel_id: ChannelId,
2980 sender_id: UserId,
2981 body: &str,
2982 timestamp: OffsetDateTime,
2983 nonce: u128,
2984 ) -> Result<MessageId> {
2985 self.background.simulate_random_delay().await;
2986 if !self.channels.lock().contains_key(&channel_id) {
2987 Err(anyhow!("channel does not exist"))?;
2988 }
2989 if !self.users.lock().contains_key(&sender_id) {
2990 Err(anyhow!("user does not exist"))?;
2991 }
2992
2993 let mut messages = self.channel_messages.lock();
2994 if let Some(message) = messages
2995 .values()
2996 .find(|message| message.nonce.as_u128() == nonce)
2997 {
2998 Ok(message.id)
2999 } else {
3000 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
3001 messages.insert(
3002 message_id,
3003 ChannelMessage {
3004 id: message_id,
3005 channel_id,
3006 sender_id,
3007 body: body.to_string(),
3008 sent_at: timestamp,
3009 nonce: Uuid::from_u128(nonce),
3010 },
3011 );
3012 Ok(message_id)
3013 }
3014 }
3015
3016 async fn get_channel_messages(
3017 &self,
3018 channel_id: ChannelId,
3019 count: usize,
3020 before_id: Option<MessageId>,
3021 ) -> Result<Vec<ChannelMessage>> {
3022 self.background.simulate_random_delay().await;
3023 let mut messages = self
3024 .channel_messages
3025 .lock()
3026 .values()
3027 .rev()
3028 .filter(|message| {
3029 message.channel_id == channel_id
3030 && message.id < before_id.unwrap_or(MessageId::MAX)
3031 })
3032 .take(count)
3033 .cloned()
3034 .collect::<Vec<_>>();
3035 messages.sort_unstable_by_key(|message| message.id);
3036 Ok(messages)
3037 }
3038
        /// No-op: the fake keeps all of its state in memory, so there is nothing to
        /// tear down.
        async fn teardown(&self, _: &str) {}

        /// Returns `Some(self)`, because this implementation is the fake.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
3045 }
3046
3047 fn build_background_executor() -> Arc<Background> {
3048 Deterministic::new(0).build_background()
3049 }
3050}