1use std::{cmp, ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use serde::{Deserialize, Serialize};
10pub use sqlx::postgres::PgPoolOptions as DbOptions;
11use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
12use time::{OffsetDateTime, PrimitiveDateTime};
13
14#[async_trait]
15pub trait Db: Send + Sync {
16 async fn create_user(
17 &self,
18 github_login: &str,
19 email_address: Option<&str>,
20 admin: bool,
21 ) -> Result<UserId>;
22 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
23 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
24 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
25 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
26 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
27 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
28 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
29 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
30 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
31 async fn destroy_user(&self, id: UserId) -> Result<()>;
32
33 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
34 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
35 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
36 async fn redeem_invite_code(
37 &self,
38 code: &str,
39 login: &str,
40 email_address: Option<&str>,
41 ) -> Result<UserId>;
42
43 /// Registers a new project for the given user.
44 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
45
46 /// Unregisters a project for the given project id.
47 async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
48
49 /// Update file counts by extension for the given project and worktree.
50 async fn update_worktree_extensions(
51 &self,
52 project_id: ProjectId,
53 worktree_id: u64,
54 extensions: HashMap<String, u32>,
55 ) -> Result<()>;
56
57 /// Get the file counts on the given project keyed by their worktree and extension.
58 async fn get_project_extensions(
59 &self,
60 project_id: ProjectId,
61 ) -> Result<HashMap<u64, HashMap<String, usize>>>;
62
63 /// Record which users have been active in which projects during
64 /// a given period of time.
65 async fn record_user_activity(
66 &self,
67 time_period: Range<OffsetDateTime>,
68 active_projects: &[(UserId, ProjectId)],
69 ) -> Result<()>;
70
71 /// Get the number of users who have been active in the given
72 /// time period for at least the given time duration.
73 async fn get_active_user_count(
74 &self,
75 time_period: Range<OffsetDateTime>,
76 min_duration: Duration,
77 only_collaborative: bool,
78 ) -> Result<usize>;
79
80 /// Get the users that have been most active during the given time period,
81 /// along with the amount of time they have been active in each project.
82 async fn get_top_users_activity_summary(
83 &self,
84 time_period: Range<OffsetDateTime>,
85 max_user_count: usize,
86 ) -> Result<Vec<UserActivitySummary>>;
87
88 /// Get the project activity for the given user and time period.
89 async fn get_user_activity_timeline(
90 &self,
91 time_period: Range<OffsetDateTime>,
92 user_id: UserId,
93 ) -> Result<Vec<UserActivityPeriod>>;
94
95 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
96 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
97 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
98 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
99 async fn dismiss_contact_notification(
100 &self,
101 responder_id: UserId,
102 requester_id: UserId,
103 ) -> Result<()>;
104 async fn respond_to_contact_request(
105 &self,
106 responder_id: UserId,
107 requester_id: UserId,
108 accept: bool,
109 ) -> Result<()>;
110
111 async fn create_access_token_hash(
112 &self,
113 user_id: UserId,
114 access_token_hash: &str,
115 max_access_token_count: usize,
116 ) -> Result<()>;
117 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
118 #[cfg(any(test, feature = "seed-support"))]
119
120 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
121 #[cfg(any(test, feature = "seed-support"))]
122 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
123 #[cfg(any(test, feature = "seed-support"))]
124 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
125 #[cfg(any(test, feature = "seed-support"))]
126 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
127 #[cfg(any(test, feature = "seed-support"))]
128
129 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
130 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
131 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
132 -> Result<bool>;
133 #[cfg(any(test, feature = "seed-support"))]
134 async fn add_channel_member(
135 &self,
136 channel_id: ChannelId,
137 user_id: UserId,
138 is_admin: bool,
139 ) -> Result<()>;
140 async fn create_channel_message(
141 &self,
142 channel_id: ChannelId,
143 sender_id: UserId,
144 body: &str,
145 timestamp: OffsetDateTime,
146 nonce: u128,
147 ) -> Result<MessageId>;
148 async fn get_channel_messages(
149 &self,
150 channel_id: ChannelId,
151 count: usize,
152 before_id: Option<MessageId>,
153 ) -> Result<Vec<ChannelMessage>>;
154 #[cfg(test)]
155 async fn teardown(&self, url: &str);
156 #[cfg(test)]
157 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
158}
159
160pub struct PostgresDb {
161 pool: sqlx::PgPool,
162}
163
164impl PostgresDb {
165 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
166 let pool = DbOptions::new()
167 .max_connections(max_connections)
168 .connect(&url)
169 .await
170 .context("failed to connect to postgres database")?;
171 Ok(Self { pool })
172 }
173}
174
175#[async_trait]
176impl Db for PostgresDb {
177 // users
178
179 async fn create_user(
180 &self,
181 github_login: &str,
182 email_address: Option<&str>,
183 admin: bool,
184 ) -> Result<UserId> {
185 let query = "
186 INSERT INTO users (github_login, email_address, admin)
187 VALUES ($1, $2, $3)
188 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
189 RETURNING id
190 ";
191 Ok(sqlx::query_scalar(query)
192 .bind(github_login)
193 .bind(email_address)
194 .bind(admin)
195 .fetch_one(&self.pool)
196 .await
197 .map(UserId)?)
198 }
199
200 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
201 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
202 Ok(sqlx::query_as(query)
203 .bind(limit as i32)
204 .bind((page * limit) as i32)
205 .fetch_all(&self.pool)
206 .await?)
207 }
208
209 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
210 let mut query = QueryBuilder::new(
211 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
212 );
213 query.push_values(
214 users,
215 |mut query, (github_login, email_address, invite_count)| {
216 query
217 .push_bind(github_login)
218 .push_bind(email_address)
219 .push_bind(false)
220 .push_bind(random_invite_code())
221 .push_bind(invite_count as i32);
222 },
223 );
224 query.push(
225 "
226 ON CONFLICT (github_login) DO UPDATE SET
227 github_login = excluded.github_login,
228 invite_count = excluded.invite_count,
229 invite_code = CASE WHEN users.invite_code IS NULL
230 THEN excluded.invite_code
231 ELSE users.invite_code
232 END
233 RETURNING id
234 ",
235 );
236
237 let rows = query.build().fetch_all(&self.pool).await?;
238 Ok(rows
239 .into_iter()
240 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
241 .collect())
242 }
243
244 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
245 let like_string = fuzzy_like_string(name_query);
246 let query = "
247 SELECT users.*
248 FROM users
249 WHERE github_login ILIKE $1
250 ORDER BY github_login <-> $2
251 LIMIT $3
252 ";
253 Ok(sqlx::query_as(query)
254 .bind(like_string)
255 .bind(name_query)
256 .bind(limit as i32)
257 .fetch_all(&self.pool)
258 .await?)
259 }
260
261 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
262 let users = self.get_users_by_ids(vec![id]).await?;
263 Ok(users.into_iter().next())
264 }
265
266 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
267 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
268 let query = "
269 SELECT users.*
270 FROM users
271 WHERE users.id = ANY ($1)
272 ";
273 Ok(sqlx::query_as(query)
274 .bind(&ids)
275 .fetch_all(&self.pool)
276 .await?)
277 }
278
279 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
280 let query = format!(
281 "
282 SELECT users.*
283 FROM users
284 WHERE invite_count = 0
285 AND inviter_id IS{} NULL
286 ",
287 if invited_by_another_user { " NOT" } else { "" }
288 );
289
290 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
291 }
292
293 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
294 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
295 Ok(sqlx::query_as(query)
296 .bind(github_login)
297 .fetch_optional(&self.pool)
298 .await?)
299 }
300
301 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
302 let query = "UPDATE users SET admin = $1 WHERE id = $2";
303 Ok(sqlx::query(query)
304 .bind(is_admin)
305 .bind(id.0)
306 .execute(&self.pool)
307 .await
308 .map(drop)?)
309 }
310
311 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
312 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
313 Ok(sqlx::query(query)
314 .bind(connected_once)
315 .bind(id.0)
316 .execute(&self.pool)
317 .await
318 .map(drop)?)
319 }
320
321 async fn destroy_user(&self, id: UserId) -> Result<()> {
322 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
323 sqlx::query(query)
324 .bind(id.0)
325 .execute(&self.pool)
326 .await
327 .map(drop)?;
328 let query = "DELETE FROM users WHERE id = $1;";
329 Ok(sqlx::query(query)
330 .bind(id.0)
331 .execute(&self.pool)
332 .await
333 .map(drop)?)
334 }
335
336 // invite codes
337
338 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
339 let mut tx = self.pool.begin().await?;
340 if count > 0 {
341 sqlx::query(
342 "
343 UPDATE users
344 SET invite_code = $1
345 WHERE id = $2 AND invite_code IS NULL
346 ",
347 )
348 .bind(random_invite_code())
349 .bind(id)
350 .execute(&mut tx)
351 .await?;
352 }
353
354 sqlx::query(
355 "
356 UPDATE users
357 SET invite_count = $1
358 WHERE id = $2
359 ",
360 )
361 .bind(count as i32)
362 .bind(id)
363 .execute(&mut tx)
364 .await?;
365 tx.commit().await?;
366 Ok(())
367 }
368
369 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
370 let result: Option<(String, i32)> = sqlx::query_as(
371 "
372 SELECT invite_code, invite_count
373 FROM users
374 WHERE id = $1 AND invite_code IS NOT NULL
375 ",
376 )
377 .bind(id)
378 .fetch_optional(&self.pool)
379 .await?;
380 if let Some((code, count)) = result {
381 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
382 } else {
383 Ok(None)
384 }
385 }
386
387 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
388 sqlx::query_as(
389 "
390 SELECT *
391 FROM users
392 WHERE invite_code = $1
393 ",
394 )
395 .bind(code)
396 .fetch_optional(&self.pool)
397 .await?
398 .ok_or_else(|| {
399 Error::Http(
400 StatusCode::NOT_FOUND,
401 "that invite code does not exist".to_string(),
402 )
403 })
404 }
405
406 async fn redeem_invite_code(
407 &self,
408 code: &str,
409 login: &str,
410 email_address: Option<&str>,
411 ) -> Result<UserId> {
412 let mut tx = self.pool.begin().await?;
413
414 let inviter_id: Option<UserId> = sqlx::query_scalar(
415 "
416 UPDATE users
417 SET invite_count = invite_count - 1
418 WHERE
419 invite_code = $1 AND
420 invite_count > 0
421 RETURNING id
422 ",
423 )
424 .bind(code)
425 .fetch_optional(&mut tx)
426 .await?;
427
428 let inviter_id = match inviter_id {
429 Some(inviter_id) => inviter_id,
430 None => {
431 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
432 .bind(code)
433 .fetch_optional(&mut tx)
434 .await?
435 .is_some()
436 {
437 Err(Error::Http(
438 StatusCode::UNAUTHORIZED,
439 "no invites remaining".to_string(),
440 ))?
441 } else {
442 Err(Error::Http(
443 StatusCode::NOT_FOUND,
444 "invite code not found".to_string(),
445 ))?
446 }
447 }
448 };
449
450 let invitee_id = sqlx::query_scalar(
451 "
452 INSERT INTO users
453 (github_login, email_address, admin, inviter_id, invite_code, invite_count)
454 VALUES
455 ($1, $2, 'f', $3, $4, $5)
456 RETURNING id
457 ",
458 )
459 .bind(login)
460 .bind(email_address)
461 .bind(inviter_id)
462 .bind(random_invite_code())
463 .bind(5)
464 .fetch_one(&mut tx)
465 .await
466 .map(UserId)?;
467
468 sqlx::query(
469 "
470 INSERT INTO contacts
471 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
472 VALUES
473 ($1, $2, 't', 't', 't')
474 ",
475 )
476 .bind(inviter_id)
477 .bind(invitee_id)
478 .execute(&mut tx)
479 .await?;
480
481 tx.commit().await?;
482 Ok(invitee_id)
483 }
484
485 // projects
486
487 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
488 Ok(sqlx::query_scalar(
489 "
490 INSERT INTO projects(host_user_id)
491 VALUES ($1)
492 RETURNING id
493 ",
494 )
495 .bind(host_user_id)
496 .fetch_one(&self.pool)
497 .await
498 .map(ProjectId)?)
499 }
500
501 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
502 sqlx::query(
503 "
504 UPDATE projects
505 SET unregistered = 't'
506 WHERE id = $1
507 ",
508 )
509 .bind(project_id)
510 .execute(&self.pool)
511 .await?;
512 Ok(())
513 }
514
515 async fn update_worktree_extensions(
516 &self,
517 project_id: ProjectId,
518 worktree_id: u64,
519 extensions: HashMap<String, u32>,
520 ) -> Result<()> {
521 if extensions.is_empty() {
522 return Ok(());
523 }
524
525 let mut query = QueryBuilder::new(
526 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
527 );
528 query.push_values(extensions, |mut query, (extension, count)| {
529 query
530 .push_bind(project_id)
531 .push_bind(worktree_id as i32)
532 .push_bind(extension)
533 .push_bind(count as i32);
534 });
535 query.push(
536 "
537 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
538 count = excluded.count
539 ",
540 );
541 query.build().execute(&self.pool).await?;
542
543 Ok(())
544 }
545
546 async fn get_project_extensions(
547 &self,
548 project_id: ProjectId,
549 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
550 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
551 struct WorktreeExtension {
552 worktree_id: i32,
553 extension: String,
554 count: i32,
555 }
556
557 let query = "
558 SELECT worktree_id, extension, count
559 FROM worktree_extensions
560 WHERE project_id = $1
561 ";
562 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
563 .bind(&project_id)
564 .fetch_all(&self.pool)
565 .await?;
566
567 let mut extension_counts = HashMap::default();
568 for count in counts {
569 extension_counts
570 .entry(count.worktree_id as u64)
571 .or_insert(HashMap::default())
572 .insert(count.extension, count.count as usize);
573 }
574 Ok(extension_counts)
575 }
576
577 async fn record_user_activity(
578 &self,
579 time_period: Range<OffsetDateTime>,
580 projects: &[(UserId, ProjectId)],
581 ) -> Result<()> {
582 let query = "
583 INSERT INTO project_activity_periods
584 (ended_at, duration_millis, user_id, project_id)
585 VALUES
586 ($1, $2, $3, $4);
587 ";
588
589 let mut tx = self.pool.begin().await?;
590 let duration_millis =
591 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
592 for (user_id, project_id) in projects {
593 sqlx::query(query)
594 .bind(time_period.end)
595 .bind(duration_millis)
596 .bind(user_id)
597 .bind(project_id)
598 .execute(&mut tx)
599 .await?;
600 }
601 tx.commit().await?;
602
603 Ok(())
604 }
605
606 async fn get_active_user_count(
607 &self,
608 time_period: Range<OffsetDateTime>,
609 min_duration: Duration,
610 only_collaborative: bool,
611 ) -> Result<usize> {
612 let mut with_clause = String::new();
613 with_clause.push_str("WITH\n");
614 with_clause.push_str(
615 "
616 project_durations AS (
617 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
618 FROM project_activity_periods
619 WHERE $1 < ended_at AND ended_at <= $2
620 GROUP BY user_id, project_id
621 ),
622 ",
623 );
624 with_clause.push_str(
625 "
626 project_collaborators as (
627 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
628 FROM project_durations
629 GROUP BY project_id
630 ),
631 ",
632 );
633
634 if only_collaborative {
635 with_clause.push_str(
636 "
637 user_durations AS (
638 SELECT user_id, SUM(project_duration) as total_duration
639 FROM project_durations, project_collaborators
640 WHERE
641 project_durations.project_id = project_collaborators.project_id AND
642 max_collaborators > 1
643 GROUP BY user_id
644 ORDER BY total_duration DESC
645 LIMIT $3
646 )
647 ",
648 );
649 } else {
650 with_clause.push_str(
651 "
652 user_durations AS (
653 SELECT user_id, SUM(project_duration) as total_duration
654 FROM project_durations
655 GROUP BY user_id
656 ORDER BY total_duration DESC
657 LIMIT $3
658 )
659 ",
660 );
661 }
662
663 let query = format!(
664 "
665 {with_clause}
666 SELECT count(user_durations.user_id)
667 FROM user_durations
668 WHERE user_durations.total_duration >= $3
669 "
670 );
671
672 let count: i64 = sqlx::query_scalar(&query)
673 .bind(time_period.start)
674 .bind(time_period.end)
675 .bind(min_duration.as_millis() as i64)
676 .fetch_one(&self.pool)
677 .await?;
678 Ok(count as usize)
679 }
680
681 async fn get_top_users_activity_summary(
682 &self,
683 time_period: Range<OffsetDateTime>,
684 max_user_count: usize,
685 ) -> Result<Vec<UserActivitySummary>> {
686 let query = "
687 WITH
688 project_durations AS (
689 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
690 FROM project_activity_periods
691 WHERE $1 < ended_at AND ended_at <= $2
692 GROUP BY user_id, project_id
693 ),
694 user_durations AS (
695 SELECT user_id, SUM(project_duration) as total_duration
696 FROM project_durations
697 GROUP BY user_id
698 ORDER BY total_duration DESC
699 LIMIT $3
700 ),
701 project_collaborators as (
702 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
703 FROM project_durations
704 GROUP BY project_id
705 )
706 SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
707 FROM user_durations, project_durations, project_collaborators, users
708 WHERE
709 user_durations.user_id = project_durations.user_id AND
710 user_durations.user_id = users.id AND
711 project_durations.project_id = project_collaborators.project_id
712 ORDER BY total_duration DESC, user_id ASC
713 ";
714
715 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
716 .bind(time_period.start)
717 .bind(time_period.end)
718 .bind(max_user_count as i32)
719 .fetch(&self.pool);
720
721 let mut result = Vec::<UserActivitySummary>::new();
722 while let Some(row) = rows.next().await {
723 let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
724 let project_id = project_id;
725 let duration = Duration::from_millis(duration_millis as u64);
726 let project_activity = ProjectActivitySummary {
727 id: project_id,
728 duration,
729 max_collaborators: project_collaborators as usize,
730 };
731 if let Some(last_summary) = result.last_mut() {
732 if last_summary.id == user_id {
733 last_summary.project_activity.push(project_activity);
734 continue;
735 }
736 }
737 result.push(UserActivitySummary {
738 id: user_id,
739 project_activity: vec![project_activity],
740 github_login,
741 });
742 }
743
744 Ok(result)
745 }
746
747 async fn get_user_activity_timeline(
748 &self,
749 time_period: Range<OffsetDateTime>,
750 user_id: UserId,
751 ) -> Result<Vec<UserActivityPeriod>> {
752 const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
753
754 let query = "
755 SELECT
756 project_activity_periods.ended_at,
757 project_activity_periods.duration_millis,
758 project_activity_periods.project_id,
759 worktree_extensions.extension,
760 worktree_extensions.count
761 FROM project_activity_periods
762 LEFT OUTER JOIN
763 worktree_extensions
764 ON
765 project_activity_periods.project_id = worktree_extensions.project_id
766 WHERE
767 project_activity_periods.user_id = $1 AND
768 $2 < project_activity_periods.ended_at AND
769 project_activity_periods.ended_at <= $3
770 ORDER BY project_activity_periods.id ASC
771 ";
772
773 let mut rows = sqlx::query_as::<
774 _,
775 (
776 PrimitiveDateTime,
777 i32,
778 ProjectId,
779 Option<String>,
780 Option<i32>,
781 ),
782 >(query)
783 .bind(user_id)
784 .bind(time_period.start)
785 .bind(time_period.end)
786 .fetch(&self.pool);
787
788 let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
789 while let Some(row) = rows.next().await {
790 let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
791 let ended_at = ended_at.assume_utc();
792 let duration = Duration::from_millis(duration_millis as u64);
793 let started_at = ended_at - duration;
794 let project_time_periods = time_periods.entry(project_id).or_default();
795
796 if let Some(prev_duration) = project_time_periods.last_mut() {
797 if started_at <= prev_duration.end + COALESCE_THRESHOLD
798 && ended_at >= prev_duration.start
799 {
800 prev_duration.end = cmp::max(prev_duration.end, ended_at);
801 } else {
802 project_time_periods.push(UserActivityPeriod {
803 project_id,
804 start: started_at,
805 end: ended_at,
806 extensions: Default::default(),
807 });
808 }
809 } else {
810 project_time_periods.push(UserActivityPeriod {
811 project_id,
812 start: started_at,
813 end: ended_at,
814 extensions: Default::default(),
815 });
816 }
817
818 if let Some((extension, extension_count)) = extension.zip(extension_count) {
819 project_time_periods
820 .last_mut()
821 .unwrap()
822 .extensions
823 .insert(extension, extension_count as usize);
824 }
825 }
826
827 let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
828 durations.sort_unstable_by_key(|duration| duration.start);
829 Ok(durations)
830 }
831
832 // contacts
833
834 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
835 let query = "
836 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
837 FROM contacts
838 WHERE user_id_a = $1 OR user_id_b = $1;
839 ";
840
841 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
842 .bind(user_id)
843 .fetch(&self.pool);
844
845 let mut contacts = vec![Contact::Accepted {
846 user_id,
847 should_notify: false,
848 }];
849 while let Some(row) = rows.next().await {
850 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
851
852 if user_id_a == user_id {
853 if accepted {
854 contacts.push(Contact::Accepted {
855 user_id: user_id_b,
856 should_notify: should_notify && a_to_b,
857 });
858 } else if a_to_b {
859 contacts.push(Contact::Outgoing { user_id: user_id_b })
860 } else {
861 contacts.push(Contact::Incoming {
862 user_id: user_id_b,
863 should_notify,
864 });
865 }
866 } else {
867 if accepted {
868 contacts.push(Contact::Accepted {
869 user_id: user_id_a,
870 should_notify: should_notify && !a_to_b,
871 });
872 } else if a_to_b {
873 contacts.push(Contact::Incoming {
874 user_id: user_id_a,
875 should_notify,
876 });
877 } else {
878 contacts.push(Contact::Outgoing { user_id: user_id_a });
879 }
880 }
881 }
882
883 contacts.sort_unstable_by_key(|contact| contact.user_id());
884
885 Ok(contacts)
886 }
887
888 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
889 let (id_a, id_b) = if user_id_1 < user_id_2 {
890 (user_id_1, user_id_2)
891 } else {
892 (user_id_2, user_id_1)
893 };
894
895 let query = "
896 SELECT 1 FROM contacts
897 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
898 LIMIT 1
899 ";
900 Ok(sqlx::query_scalar::<_, i32>(query)
901 .bind(id_a.0)
902 .bind(id_b.0)
903 .fetch_optional(&self.pool)
904 .await?
905 .is_some())
906 }
907
908 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
909 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
910 (sender_id, receiver_id, true)
911 } else {
912 (receiver_id, sender_id, false)
913 };
914 let query = "
915 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
916 VALUES ($1, $2, $3, 'f', 't')
917 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
918 SET
919 accepted = 't',
920 should_notify = 'f'
921 WHERE
922 NOT contacts.accepted AND
923 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
924 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
925 ";
926 let result = sqlx::query(query)
927 .bind(id_a.0)
928 .bind(id_b.0)
929 .bind(a_to_b)
930 .execute(&self.pool)
931 .await?;
932
933 if result.rows_affected() == 1 {
934 Ok(())
935 } else {
936 Err(anyhow!("contact already requested"))?
937 }
938 }
939
940 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
941 let (id_a, id_b) = if responder_id < requester_id {
942 (responder_id, requester_id)
943 } else {
944 (requester_id, responder_id)
945 };
946 let query = "
947 DELETE FROM contacts
948 WHERE user_id_a = $1 AND user_id_b = $2;
949 ";
950 let result = sqlx::query(query)
951 .bind(id_a.0)
952 .bind(id_b.0)
953 .execute(&self.pool)
954 .await?;
955
956 if result.rows_affected() == 1 {
957 Ok(())
958 } else {
959 Err(anyhow!("no such contact"))?
960 }
961 }
962
963 async fn dismiss_contact_notification(
964 &self,
965 user_id: UserId,
966 contact_user_id: UserId,
967 ) -> Result<()> {
968 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
969 (user_id, contact_user_id, true)
970 } else {
971 (contact_user_id, user_id, false)
972 };
973
974 let query = "
975 UPDATE contacts
976 SET should_notify = 'f'
977 WHERE
978 user_id_a = $1 AND user_id_b = $2 AND
979 (
980 (a_to_b = $3 AND accepted) OR
981 (a_to_b != $3 AND NOT accepted)
982 );
983 ";
984
985 let result = sqlx::query(query)
986 .bind(id_a.0)
987 .bind(id_b.0)
988 .bind(a_to_b)
989 .execute(&self.pool)
990 .await?;
991
992 if result.rows_affected() == 0 {
993 Err(anyhow!("no such contact request"))?;
994 }
995
996 Ok(())
997 }
998
999 async fn respond_to_contact_request(
1000 &self,
1001 responder_id: UserId,
1002 requester_id: UserId,
1003 accept: bool,
1004 ) -> Result<()> {
1005 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1006 (responder_id, requester_id, false)
1007 } else {
1008 (requester_id, responder_id, true)
1009 };
1010 let result = if accept {
1011 let query = "
1012 UPDATE contacts
1013 SET accepted = 't', should_notify = 't'
1014 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1015 ";
1016 sqlx::query(query)
1017 .bind(id_a.0)
1018 .bind(id_b.0)
1019 .bind(a_to_b)
1020 .execute(&self.pool)
1021 .await?
1022 } else {
1023 let query = "
1024 DELETE FROM contacts
1025 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1026 ";
1027 sqlx::query(query)
1028 .bind(id_a.0)
1029 .bind(id_b.0)
1030 .bind(a_to_b)
1031 .execute(&self.pool)
1032 .await?
1033 };
1034 if result.rows_affected() == 1 {
1035 Ok(())
1036 } else {
1037 Err(anyhow!("no such contact request"))?
1038 }
1039 }
1040
1041 // access tokens
1042
1043 async fn create_access_token_hash(
1044 &self,
1045 user_id: UserId,
1046 access_token_hash: &str,
1047 max_access_token_count: usize,
1048 ) -> Result<()> {
1049 let insert_query = "
1050 INSERT INTO access_tokens (user_id, hash)
1051 VALUES ($1, $2);
1052 ";
1053 let cleanup_query = "
1054 DELETE FROM access_tokens
1055 WHERE id IN (
1056 SELECT id from access_tokens
1057 WHERE user_id = $1
1058 ORDER BY id DESC
1059 OFFSET $3
1060 )
1061 ";
1062
1063 let mut tx = self.pool.begin().await?;
1064 sqlx::query(insert_query)
1065 .bind(user_id.0)
1066 .bind(access_token_hash)
1067 .execute(&mut tx)
1068 .await?;
1069 sqlx::query(cleanup_query)
1070 .bind(user_id.0)
1071 .bind(access_token_hash)
1072 .bind(max_access_token_count as i32)
1073 .execute(&mut tx)
1074 .await?;
1075 Ok(tx.commit().await?)
1076 }
1077
1078 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1079 let query = "
1080 SELECT hash
1081 FROM access_tokens
1082 WHERE user_id = $1
1083 ORDER BY id DESC
1084 ";
1085 Ok(sqlx::query_scalar(query)
1086 .bind(user_id.0)
1087 .fetch_all(&self.pool)
1088 .await?)
1089 }
1090
1091 // orgs
1092
1093 #[allow(unused)] // Help rust-analyzer
1094 #[cfg(any(test, feature = "seed-support"))]
1095 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1096 let query = "
1097 SELECT *
1098 FROM orgs
1099 WHERE slug = $1
1100 ";
1101 Ok(sqlx::query_as(query)
1102 .bind(slug)
1103 .fetch_optional(&self.pool)
1104 .await?)
1105 }
1106
1107 #[cfg(any(test, feature = "seed-support"))]
1108 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1109 let query = "
1110 INSERT INTO orgs (name, slug)
1111 VALUES ($1, $2)
1112 RETURNING id
1113 ";
1114 Ok(sqlx::query_scalar(query)
1115 .bind(name)
1116 .bind(slug)
1117 .fetch_one(&self.pool)
1118 .await
1119 .map(OrgId)?)
1120 }
1121
1122 #[cfg(any(test, feature = "seed-support"))]
1123 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1124 let query = "
1125 INSERT INTO org_memberships (org_id, user_id, admin)
1126 VALUES ($1, $2, $3)
1127 ON CONFLICT DO NOTHING
1128 ";
1129 Ok(sqlx::query(query)
1130 .bind(org_id.0)
1131 .bind(user_id.0)
1132 .bind(is_admin)
1133 .execute(&self.pool)
1134 .await
1135 .map(drop)?)
1136 }
1137
1138 // channels
1139
1140 #[cfg(any(test, feature = "seed-support"))]
1141 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1142 let query = "
1143 INSERT INTO channels (owner_id, owner_is_user, name)
1144 VALUES ($1, false, $2)
1145 RETURNING id
1146 ";
1147 Ok(sqlx::query_scalar(query)
1148 .bind(org_id.0)
1149 .bind(name)
1150 .fetch_one(&self.pool)
1151 .await
1152 .map(ChannelId)?)
1153 }
1154
1155 #[allow(unused)] // Help rust-analyzer
1156 #[cfg(any(test, feature = "seed-support"))]
1157 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1158 let query = "
1159 SELECT *
1160 FROM channels
1161 WHERE
1162 channels.owner_is_user = false AND
1163 channels.owner_id = $1
1164 ";
1165 Ok(sqlx::query_as(query)
1166 .bind(org_id.0)
1167 .fetch_all(&self.pool)
1168 .await?)
1169 }
1170
1171 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1172 let query = "
1173 SELECT
1174 channels.*
1175 FROM
1176 channel_memberships, channels
1177 WHERE
1178 channel_memberships.user_id = $1 AND
1179 channel_memberships.channel_id = channels.id
1180 ";
1181 Ok(sqlx::query_as(query)
1182 .bind(user_id.0)
1183 .fetch_all(&self.pool)
1184 .await?)
1185 }
1186
1187 async fn can_user_access_channel(
1188 &self,
1189 user_id: UserId,
1190 channel_id: ChannelId,
1191 ) -> Result<bool> {
1192 let query = "
1193 SELECT id
1194 FROM channel_memberships
1195 WHERE user_id = $1 AND channel_id = $2
1196 LIMIT 1
1197 ";
1198 Ok(sqlx::query_scalar::<_, i32>(query)
1199 .bind(user_id.0)
1200 .bind(channel_id.0)
1201 .fetch_optional(&self.pool)
1202 .await
1203 .map(|e| e.is_some())?)
1204 }
1205
1206 #[cfg(any(test, feature = "seed-support"))]
1207 async fn add_channel_member(
1208 &self,
1209 channel_id: ChannelId,
1210 user_id: UserId,
1211 is_admin: bool,
1212 ) -> Result<()> {
1213 let query = "
1214 INSERT INTO channel_memberships (channel_id, user_id, admin)
1215 VALUES ($1, $2, $3)
1216 ON CONFLICT DO NOTHING
1217 ";
1218 Ok(sqlx::query(query)
1219 .bind(channel_id.0)
1220 .bind(user_id.0)
1221 .bind(is_admin)
1222 .execute(&self.pool)
1223 .await
1224 .map(drop)?)
1225 }
1226
1227 // messages
1228
1229 async fn create_channel_message(
1230 &self,
1231 channel_id: ChannelId,
1232 sender_id: UserId,
1233 body: &str,
1234 timestamp: OffsetDateTime,
1235 nonce: u128,
1236 ) -> Result<MessageId> {
1237 let query = "
1238 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1239 VALUES ($1, $2, $3, $4, $5)
1240 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1241 RETURNING id
1242 ";
1243 Ok(sqlx::query_scalar(query)
1244 .bind(channel_id.0)
1245 .bind(sender_id.0)
1246 .bind(body)
1247 .bind(timestamp)
1248 .bind(Uuid::from_u128(nonce))
1249 .fetch_one(&self.pool)
1250 .await
1251 .map(MessageId)?)
1252 }
1253
1254 async fn get_channel_messages(
1255 &self,
1256 channel_id: ChannelId,
1257 count: usize,
1258 before_id: Option<MessageId>,
1259 ) -> Result<Vec<ChannelMessage>> {
1260 let query = r#"
1261 SELECT * FROM (
1262 SELECT
1263 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1264 FROM
1265 channel_messages
1266 WHERE
1267 channel_id = $1 AND
1268 id < $2
1269 ORDER BY id DESC
1270 LIMIT $3
1271 ) as recent_messages
1272 ORDER BY id ASC
1273 "#;
1274 Ok(sqlx::query_as(query)
1275 .bind(channel_id.0)
1276 .bind(before_id.unwrap_or(MessageId::MAX))
1277 .bind(count as i64)
1278 .fetch_all(&self.pool)
1279 .await?)
1280 }
1281
1282 #[cfg(test)]
1283 async fn teardown(&self, url: &str) {
1284 use util::ResultExt;
1285
1286 let query = "
1287 SELECT pg_terminate_backend(pg_stat_activity.pid)
1288 FROM pg_stat_activity
1289 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1290 ";
1291 sqlx::query(query).execute(&self.pool).await.log_err();
1292 self.pool.close().await;
1293 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1294 .await
1295 .log_err();
1296 }
1297
1298 #[cfg(test)]
1299 fn as_fake(&self) -> Option<&tests::FakeDb> {
1300 None
1301 }
1302}
1303
/// Declares `$name` as a `Copy` newtype wrapper around an `i32` database id.
///
/// The generated type is `#[sqlx(transparent)]` and `#[serde(transparent)]`, so it
/// binds to SQL and (de)serializes as its inner `i32`. It also gets a `MAX`
/// constant, conversions to and from the `u64` proto representation, and a
/// `Display` impl that prints the raw number.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id. For example, `get_channel_messages` uses
            // `MessageId::MAX` as its exclusive upper bound when no `before_id` is given.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Builds an id from its proto (`u64`) representation.
            // NOTE(review): `as i32` truncates, so values above `i32::MAX` silently
            // wrap instead of being rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Converts the id to its proto (`u64`) representation.
            // NOTE(review): a negative id would sign-extend to a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        // Displays the id as its bare integer value.
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1346
// Newtype id for users (generated by `id_type!`).
id_type!(UserId);
/// A user account row, as returned by lookups such as `Db::get_user_by_id`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub login; `Db::get_user_by_github_login` looks users up by this value.
    pub github_login: String,
    pub email_address: Option<String>,
    /// Admin flag, changed via `Db::set_user_is_admin`.
    pub admin: bool,
    /// The user's invite code, or `None` if none has been assigned yet.
    pub invite_code: Option<String>,
    /// Remaining number of times `invite_code` may be redeemed.
    pub invite_count: i32,
    /// Changed via `Db::set_user_connected_once`. Presumably records whether the
    /// user has ever connected (TODO confirm).
    pub connected_once: bool,
}
1358
// Newtype id for projects (generated by `id_type!`).
id_type!(ProjectId);
/// A project registered by a host user via `Db::register_project`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user who registered the project.
    pub host_user_id: UserId,
    /// Presumably set once `Db::unregister_project` has been called for this
    /// project (TODO confirm).
    pub unregistered: bool,
}
1366
/// Aggregated activity for one user over a queried time range, as returned by
/// `get_top_users_activity_summary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// One entry per project the user was active in during the range.
    pub project_activity: Vec<ProjectActivitySummary>,
}
1373
/// A single user's activity in one project within a queried time range.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    id: ProjectId,
    /// Total time the user was recorded as active in the project.
    duration: Duration,
    /// Presumably the peak number of users concurrently active in the project,
    /// including this user (TODO confirm).
    max_collaborators: usize,
}
1380
/// A contiguous period during which a user was active in one project, as
/// returned by `get_user_activity_timeline`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    project_id: ProjectId,
    /// Start of the period, serialized as an ISO 8601 timestamp.
    #[serde(with = "time::serde::iso8601")]
    start: OffsetDateTime,
    /// End of the period, serialized as an ISO 8601 timestamp.
    #[serde(with = "time::serde::iso8601")]
    end: OffsetDateTime,
    /// File-extension counts recorded for the project (e.g. `"rs" -> 5`).
    /// Presumably sourced from `update_worktree_extensions` (TODO confirm).
    extensions: HashMap<String, usize>,
}
1390
// Newtype id for organizations (generated by `id_type!`).
id_type!(OrgId);
/// An organization. Orgs can own channels (see `create_org_channel`).
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Short identifier for the org. Presumably URL-friendly (TODO confirm).
    pub slug: String,
}
1398
// Newtype id for channels (generated by `id_type!`).
id_type!(ChannelId);
/// A chat channel owned either by an org or, presumably, by a user.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owner. It is an org id when `owner_is_user` is false, as written by
    /// `create_org_channel`, and presumably a user id otherwise.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1407
// Newtype id for channel messages (generated by `id_type!`).
id_type!(MessageId);
/// A message posted to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time. `get_channel_messages` reads it back `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Deduplication key. Creating a message with an existing nonce returns the
    /// existing message's id instead of inserting a new row.
    pub nonce: Uuid,
}
1418
/// One entry in a user's contact list, seen from that user's point of view.
///
/// `get_contacts` also lists the user themself as an `Accepted` entry.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// An established contact with `user_id`.
    Accepted {
        user_id: UserId,
        /// Whether this user has an undismissed notification about the entry
        /// (cleared by `dismiss_contact_notification`).
        should_notify: bool,
    },
    /// A request this user sent to `user_id` that has not been answered yet.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request that `user_id` sent to this user.
    Incoming {
        user_id: UserId,
        /// Whether this user has an undismissed notification about the request.
        should_notify: bool,
    },
}
1433
1434impl Contact {
1435 pub fn user_id(&self) -> UserId {
1436 match self {
1437 Contact::Accepted { user_id, .. } => *user_id,
1438 Contact::Outgoing { user_id } => *user_id,
1439 Contact::Incoming { user_id, .. } => *user_id,
1440 }
1441 }
1442}
1443
/// A pending contact request received from `requester_id`.
///
/// The derived `Ord` compares `requester_id` first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    /// Presumably whether the recipient should still be notified about this
    /// request (TODO confirm).
    pub should_notify: bool,
}
1449
/// Builds a SQL `LIKE` pattern that matches any string containing the
/// alphanumeric characters of `string` in order, with anything in between.
///
/// Non-alphanumeric characters (including the `LIKE` wildcards `%` and `_`) are
/// dropped. For example, `"x y"` becomes `"%x%y%"` and `""` becomes `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|ch| ch.is_alphanumeric())
            .flat_map(|ch| ['%', ch]),
    );
    pattern.push('%');
    pattern
}
1461
1462fn random_invite_code() -> String {
1463 nanoid::nanoid!(16)
1464}
1465
1466#[cfg(test)]
1467pub mod tests {
1468 use super::*;
1469 use anyhow::anyhow;
1470 use collections::BTreeMap;
1471 use gpui::executor::{Background, Deterministic};
1472 use lazy_static::lazy_static;
1473 use parking_lot::Mutex;
1474 use rand::prelude::*;
1475 use sqlx::{
1476 migrate::{MigrateDatabase, Migrator},
1477 Postgres,
1478 };
1479 use std::{path::Path, sync::Arc};
1480 use util::post_inc;
1481
1482 #[tokio::test(flavor = "multi_thread")]
1483 async fn test_get_users_by_ids() {
1484 for test_db in [
1485 TestDb::postgres().await,
1486 TestDb::fake(build_background_executor()),
1487 ] {
1488 let db = test_db.db();
1489
1490 let user = db.create_user("user", None, false).await.unwrap();
1491 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1492 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1493 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1494
1495 assert_eq!(
1496 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1497 .await
1498 .unwrap(),
1499 vec![
1500 User {
1501 id: user,
1502 github_login: "user".to_string(),
1503 admin: false,
1504 ..Default::default()
1505 },
1506 User {
1507 id: friend1,
1508 github_login: "friend-1".to_string(),
1509 admin: false,
1510 ..Default::default()
1511 },
1512 User {
1513 id: friend2,
1514 github_login: "friend-2".to_string(),
1515 admin: false,
1516 ..Default::default()
1517 },
1518 User {
1519 id: friend3,
1520 github_login: "friend-3".to_string(),
1521 admin: false,
1522 ..Default::default()
1523 }
1524 ]
1525 );
1526 }
1527 }
1528
1529 #[tokio::test(flavor = "multi_thread")]
1530 async fn test_create_users() {
1531 let db = TestDb::postgres().await;
1532 let db = db.db();
1533
1534 // Create the first batch of users, ensuring invite counts are assigned
1535 // correctly and the respective invite codes are unique.
1536 let user_ids_batch_1 = db
1537 .create_users(vec![
1538 ("user1".to_string(), "hi@user1.com".to_string(), 5),
1539 ("user2".to_string(), "hi@user2.com".to_string(), 4),
1540 ("user3".to_string(), "hi@user3.com".to_string(), 3),
1541 ])
1542 .await
1543 .unwrap();
1544 assert_eq!(user_ids_batch_1.len(), 3);
1545
1546 let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1547 assert_eq!(users.len(), 3);
1548 assert_eq!(users[0].github_login, "user1");
1549 assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1550 assert_eq!(users[0].invite_count, 5);
1551 assert_eq!(users[1].github_login, "user2");
1552 assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1553 assert_eq!(users[1].invite_count, 4);
1554 assert_eq!(users[2].github_login, "user3");
1555 assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1556 assert_eq!(users[2].invite_count, 3);
1557
1558 let invite_code_1 = users[0].invite_code.clone().unwrap();
1559 let invite_code_2 = users[1].invite_code.clone().unwrap();
1560 let invite_code_3 = users[2].invite_code.clone().unwrap();
1561 assert_ne!(invite_code_1, invite_code_2);
1562 assert_ne!(invite_code_1, invite_code_3);
1563 assert_ne!(invite_code_2, invite_code_3);
1564
1565 // Create the second batch of users and include a user that is already in the database, ensuring
1566 // the invite count for the existing user is updated without changing their invite code.
1567 let user_ids_batch_2 = db
1568 .create_users(vec![
1569 ("user2".to_string(), "hi@user2.com".to_string(), 10),
1570 ("user4".to_string(), "hi@user4.com".to_string(), 2),
1571 ])
1572 .await
1573 .unwrap();
1574 assert_eq!(user_ids_batch_2.len(), 2);
1575 assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1576
1577 let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1578 assert_eq!(users.len(), 2);
1579 assert_eq!(users[0].github_login, "user2");
1580 assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1581 assert_eq!(users[0].invite_count, 10);
1582 assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1583 assert_eq!(users[1].github_login, "user4");
1584 assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1585 assert_eq!(users[1].invite_count, 2);
1586
1587 let invite_code_4 = users[1].invite_code.clone().unwrap();
1588 assert_ne!(invite_code_4, invite_code_1);
1589 assert_ne!(invite_code_4, invite_code_2);
1590 assert_ne!(invite_code_4, invite_code_3);
1591 }
1592
1593 #[tokio::test(flavor = "multi_thread")]
1594 async fn test_worktree_extensions() {
1595 let test_db = TestDb::postgres().await;
1596 let db = test_db.db();
1597
1598 let user = db.create_user("user_1", None, false).await.unwrap();
1599 let project = db.register_project(user).await.unwrap();
1600
1601 db.update_worktree_extensions(project, 100, Default::default())
1602 .await
1603 .unwrap();
1604 db.update_worktree_extensions(
1605 project,
1606 100,
1607 [("rs".to_string(), 5), ("md".to_string(), 3)]
1608 .into_iter()
1609 .collect(),
1610 )
1611 .await
1612 .unwrap();
1613 db.update_worktree_extensions(
1614 project,
1615 100,
1616 [("rs".to_string(), 6), ("md".to_string(), 5)]
1617 .into_iter()
1618 .collect(),
1619 )
1620 .await
1621 .unwrap();
1622 db.update_worktree_extensions(
1623 project,
1624 101,
1625 [("ts".to_string(), 2), ("md".to_string(), 1)]
1626 .into_iter()
1627 .collect(),
1628 )
1629 .await
1630 .unwrap();
1631
1632 assert_eq!(
1633 db.get_project_extensions(project).await.unwrap(),
1634 [
1635 (
1636 100,
1637 [("rs".into(), 6), ("md".into(), 5),]
1638 .into_iter()
1639 .collect::<HashMap<_, _>>()
1640 ),
1641 (
1642 101,
1643 [("ts".into(), 2), ("md".into(), 1),]
1644 .into_iter()
1645 .collect::<HashMap<_, _>>()
1646 )
1647 ]
1648 .into_iter()
1649 .collect()
1650 );
1651 }
1652
1653 #[tokio::test(flavor = "multi_thread")]
1654 async fn test_user_activity() {
1655 let test_db = TestDb::postgres().await;
1656 let db = test_db.db();
1657
1658 let user_1 = db.create_user("user_1", None, false).await.unwrap();
1659 let user_2 = db.create_user("user_2", None, false).await.unwrap();
1660 let user_3 = db.create_user("user_3", None, false).await.unwrap();
1661 let project_1 = db.register_project(user_1).await.unwrap();
1662 db.update_worktree_extensions(
1663 project_1,
1664 1,
1665 HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
1666 )
1667 .await
1668 .unwrap();
1669 let project_2 = db.register_project(user_2).await.unwrap();
1670 let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
1671
1672 // User 2 opens a project
1673 let t1 = t0 + Duration::from_secs(10);
1674 db.record_user_activity(t0..t1, &[(user_2, project_2)])
1675 .await
1676 .unwrap();
1677
1678 let t2 = t1 + Duration::from_secs(10);
1679 db.record_user_activity(t1..t2, &[(user_2, project_2)])
1680 .await
1681 .unwrap();
1682
1683 // User 1 joins the project
1684 let t3 = t2 + Duration::from_secs(10);
1685 db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
1686 .await
1687 .unwrap();
1688
1689 // User 1 opens another project
1690 let t4 = t3 + Duration::from_secs(10);
1691 db.record_user_activity(
1692 t3..t4,
1693 &[
1694 (user_2, project_2),
1695 (user_1, project_2),
1696 (user_1, project_1),
1697 ],
1698 )
1699 .await
1700 .unwrap();
1701
1702 // User 3 joins that project
1703 let t5 = t4 + Duration::from_secs(10);
1704 db.record_user_activity(
1705 t4..t5,
1706 &[
1707 (user_2, project_2),
1708 (user_1, project_2),
1709 (user_1, project_1),
1710 (user_3, project_1),
1711 ],
1712 )
1713 .await
1714 .unwrap();
1715
1716 // User 2 leaves
1717 let t6 = t5 + Duration::from_secs(5);
1718 db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
1719 .await
1720 .unwrap();
1721
1722 let t7 = t6 + Duration::from_secs(60);
1723 let t8 = t7 + Duration::from_secs(10);
1724 db.record_user_activity(t7..t8, &[(user_1, project_1)])
1725 .await
1726 .unwrap();
1727
1728 assert_eq!(
1729 db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
1730 &[
1731 UserActivitySummary {
1732 id: user_1,
1733 github_login: "user_1".to_string(),
1734 project_activity: vec![
1735 ProjectActivitySummary {
1736 id: project_1,
1737 duration: Duration::from_secs(25),
1738 max_collaborators: 2
1739 },
1740 ProjectActivitySummary {
1741 id: project_2,
1742 duration: Duration::from_secs(30),
1743 max_collaborators: 2
1744 }
1745 ]
1746 },
1747 UserActivitySummary {
1748 id: user_2,
1749 github_login: "user_2".to_string(),
1750 project_activity: vec![ProjectActivitySummary {
1751 id: project_2,
1752 duration: Duration::from_secs(50),
1753 max_collaborators: 2
1754 }]
1755 },
1756 UserActivitySummary {
1757 id: user_3,
1758 github_login: "user_3".to_string(),
1759 project_activity: vec![ProjectActivitySummary {
1760 id: project_1,
1761 duration: Duration::from_secs(15),
1762 max_collaborators: 2
1763 }]
1764 },
1765 ]
1766 );
1767
1768 assert_eq!(
1769 db.get_active_user_count(t0..t6, Duration::from_secs(56), false)
1770 .await
1771 .unwrap(),
1772 0
1773 );
1774 assert_eq!(
1775 db.get_active_user_count(t0..t6, Duration::from_secs(56), true)
1776 .await
1777 .unwrap(),
1778 0
1779 );
1780 assert_eq!(
1781 db.get_active_user_count(t0..t6, Duration::from_secs(54), false)
1782 .await
1783 .unwrap(),
1784 1
1785 );
1786 assert_eq!(
1787 db.get_active_user_count(t0..t6, Duration::from_secs(54), true)
1788 .await
1789 .unwrap(),
1790 1
1791 );
1792 assert_eq!(
1793 db.get_active_user_count(t0..t6, Duration::from_secs(30), false)
1794 .await
1795 .unwrap(),
1796 2
1797 );
1798 assert_eq!(
1799 db.get_active_user_count(t0..t6, Duration::from_secs(30), true)
1800 .await
1801 .unwrap(),
1802 2
1803 );
1804 assert_eq!(
1805 db.get_active_user_count(t0..t6, Duration::from_secs(10), false)
1806 .await
1807 .unwrap(),
1808 3
1809 );
1810 assert_eq!(
1811 db.get_active_user_count(t0..t6, Duration::from_secs(10), true)
1812 .await
1813 .unwrap(),
1814 3
1815 );
1816 assert_eq!(
1817 db.get_active_user_count(t0..t1, Duration::from_secs(5), false)
1818 .await
1819 .unwrap(),
1820 1
1821 );
1822 assert_eq!(
1823 db.get_active_user_count(t0..t1, Duration::from_secs(5), true)
1824 .await
1825 .unwrap(),
1826 0
1827 );
1828
1829 assert_eq!(
1830 db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
1831 &[
1832 UserActivityPeriod {
1833 project_id: project_1,
1834 start: t3,
1835 end: t6,
1836 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1837 },
1838 UserActivityPeriod {
1839 project_id: project_2,
1840 start: t3,
1841 end: t5,
1842 extensions: Default::default(),
1843 },
1844 ]
1845 );
1846 assert_eq!(
1847 db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
1848 &[
1849 UserActivityPeriod {
1850 project_id: project_2,
1851 start: t2,
1852 end: t5,
1853 extensions: Default::default(),
1854 },
1855 UserActivityPeriod {
1856 project_id: project_1,
1857 start: t3,
1858 end: t6,
1859 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1860 },
1861 UserActivityPeriod {
1862 project_id: project_1,
1863 start: t7,
1864 end: t8,
1865 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1866 },
1867 ]
1868 );
1869 }
1870
1871 #[tokio::test(flavor = "multi_thread")]
1872 async fn test_recent_channel_messages() {
1873 for test_db in [
1874 TestDb::postgres().await,
1875 TestDb::fake(build_background_executor()),
1876 ] {
1877 let db = test_db.db();
1878 let user = db.create_user("user", None, false).await.unwrap();
1879 let org = db.create_org("org", "org").await.unwrap();
1880 let channel = db.create_org_channel(org, "channel").await.unwrap();
1881 for i in 0..10 {
1882 db.create_channel_message(
1883 channel,
1884 user,
1885 &i.to_string(),
1886 OffsetDateTime::now_utc(),
1887 i,
1888 )
1889 .await
1890 .unwrap();
1891 }
1892
1893 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1894 assert_eq!(
1895 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1896 ["5", "6", "7", "8", "9"]
1897 );
1898
1899 let prev_messages = db
1900 .get_channel_messages(channel, 4, Some(messages[0].id))
1901 .await
1902 .unwrap();
1903 assert_eq!(
1904 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1905 ["1", "2", "3", "4"]
1906 );
1907 }
1908 }
1909
1910 #[tokio::test(flavor = "multi_thread")]
1911 async fn test_channel_message_nonces() {
1912 for test_db in [
1913 TestDb::postgres().await,
1914 TestDb::fake(build_background_executor()),
1915 ] {
1916 let db = test_db.db();
1917 let user = db.create_user("user", None, false).await.unwrap();
1918 let org = db.create_org("org", "org").await.unwrap();
1919 let channel = db.create_org_channel(org, "channel").await.unwrap();
1920
1921 let msg1_id = db
1922 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1923 .await
1924 .unwrap();
1925 let msg2_id = db
1926 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1927 .await
1928 .unwrap();
1929 let msg3_id = db
1930 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1931 .await
1932 .unwrap();
1933 let msg4_id = db
1934 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1935 .await
1936 .unwrap();
1937
1938 assert_ne!(msg1_id, msg2_id);
1939 assert_eq!(msg1_id, msg3_id);
1940 assert_eq!(msg2_id, msg4_id);
1941 }
1942 }
1943
1944 #[tokio::test(flavor = "multi_thread")]
1945 async fn test_create_access_tokens() {
1946 let test_db = TestDb::postgres().await;
1947 let db = test_db.db();
1948 let user = db.create_user("the-user", None, false).await.unwrap();
1949
1950 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1951 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1952 assert_eq!(
1953 db.get_access_token_hashes(user).await.unwrap(),
1954 &["h2".to_string(), "h1".to_string()]
1955 );
1956
1957 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1958 assert_eq!(
1959 db.get_access_token_hashes(user).await.unwrap(),
1960 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1961 );
1962
1963 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1964 assert_eq!(
1965 db.get_access_token_hashes(user).await.unwrap(),
1966 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1967 );
1968
1969 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1970 assert_eq!(
1971 db.get_access_token_hashes(user).await.unwrap(),
1972 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1973 );
1974 }
1975
1976 #[test]
1977 fn test_fuzzy_like_string() {
1978 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1979 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1980 assert_eq!(fuzzy_like_string(" z "), "%z%");
1981 }
1982
1983 #[tokio::test(flavor = "multi_thread")]
1984 async fn test_fuzzy_search_users() {
1985 let test_db = TestDb::postgres().await;
1986 let db = test_db.db();
1987 for github_login in [
1988 "California",
1989 "colorado",
1990 "oregon",
1991 "washington",
1992 "florida",
1993 "delaware",
1994 "rhode-island",
1995 ] {
1996 db.create_user(github_login, None, false).await.unwrap();
1997 }
1998
1999 assert_eq!(
2000 fuzzy_search_user_names(db, "clr").await,
2001 &["colorado", "California"]
2002 );
2003 assert_eq!(
2004 fuzzy_search_user_names(db, "ro").await,
2005 &["rhode-island", "colorado", "oregon"],
2006 );
2007
2008 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
2009 db.fuzzy_search_users(query, 10)
2010 .await
2011 .unwrap()
2012 .into_iter()
2013 .map(|user| user.github_login)
2014 .collect::<Vec<_>>()
2015 }
2016 }
2017
2018 #[tokio::test(flavor = "multi_thread")]
2019 async fn test_add_contacts() {
2020 for test_db in [
2021 TestDb::postgres().await,
2022 TestDb::fake(build_background_executor()),
2023 ] {
2024 let db = test_db.db();
2025
2026 let user_1 = db.create_user("user1", None, false).await.unwrap();
2027 let user_2 = db.create_user("user2", None, false).await.unwrap();
2028 let user_3 = db.create_user("user3", None, false).await.unwrap();
2029
2030 // User starts with no contacts
2031 assert_eq!(
2032 db.get_contacts(user_1).await.unwrap(),
2033 vec![Contact::Accepted {
2034 user_id: user_1,
2035 should_notify: false
2036 }],
2037 );
2038
2039 // User requests a contact. Both users see the pending request.
2040 db.send_contact_request(user_1, user_2).await.unwrap();
2041 assert!(!db.has_contact(user_1, user_2).await.unwrap());
2042 assert!(!db.has_contact(user_2, user_1).await.unwrap());
2043 assert_eq!(
2044 db.get_contacts(user_1).await.unwrap(),
2045 &[
2046 Contact::Accepted {
2047 user_id: user_1,
2048 should_notify: false
2049 },
2050 Contact::Outgoing { user_id: user_2 }
2051 ],
2052 );
2053 assert_eq!(
2054 db.get_contacts(user_2).await.unwrap(),
2055 &[
2056 Contact::Incoming {
2057 user_id: user_1,
2058 should_notify: true
2059 },
2060 Contact::Accepted {
2061 user_id: user_2,
2062 should_notify: false
2063 },
2064 ]
2065 );
2066
2067 // User 2 dismisses the contact request notification without accepting or rejecting.
2068 // We shouldn't notify them again.
2069 db.dismiss_contact_notification(user_1, user_2)
2070 .await
2071 .unwrap_err();
2072 db.dismiss_contact_notification(user_2, user_1)
2073 .await
2074 .unwrap();
2075 assert_eq!(
2076 db.get_contacts(user_2).await.unwrap(),
2077 &[
2078 Contact::Incoming {
2079 user_id: user_1,
2080 should_notify: false
2081 },
2082 Contact::Accepted {
2083 user_id: user_2,
2084 should_notify: false
2085 },
2086 ]
2087 );
2088
2089 // User can't accept their own contact request
2090 db.respond_to_contact_request(user_1, user_2, true)
2091 .await
2092 .unwrap_err();
2093
2094 // User accepts a contact request. Both users see the contact.
2095 db.respond_to_contact_request(user_2, user_1, true)
2096 .await
2097 .unwrap();
2098 assert_eq!(
2099 db.get_contacts(user_1).await.unwrap(),
2100 &[
2101 Contact::Accepted {
2102 user_id: user_1,
2103 should_notify: false
2104 },
2105 Contact::Accepted {
2106 user_id: user_2,
2107 should_notify: true
2108 }
2109 ],
2110 );
2111 assert!(db.has_contact(user_1, user_2).await.unwrap());
2112 assert!(db.has_contact(user_2, user_1).await.unwrap());
2113 assert_eq!(
2114 db.get_contacts(user_2).await.unwrap(),
2115 &[
2116 Contact::Accepted {
2117 user_id: user_1,
2118 should_notify: false,
2119 },
2120 Contact::Accepted {
2121 user_id: user_2,
2122 should_notify: false,
2123 },
2124 ]
2125 );
2126
2127 // Users cannot re-request existing contacts.
2128 db.send_contact_request(user_1, user_2).await.unwrap_err();
2129 db.send_contact_request(user_2, user_1).await.unwrap_err();
2130
2131 // Users can't dismiss notifications of them accepting other users' requests.
2132 db.dismiss_contact_notification(user_2, user_1)
2133 .await
2134 .unwrap_err();
2135 assert_eq!(
2136 db.get_contacts(user_1).await.unwrap(),
2137 &[
2138 Contact::Accepted {
2139 user_id: user_1,
2140 should_notify: false
2141 },
2142 Contact::Accepted {
2143 user_id: user_2,
2144 should_notify: true,
2145 },
2146 ]
2147 );
2148
2149 // Users can dismiss notifications of other users accepting their requests.
2150 db.dismiss_contact_notification(user_1, user_2)
2151 .await
2152 .unwrap();
2153 assert_eq!(
2154 db.get_contacts(user_1).await.unwrap(),
2155 &[
2156 Contact::Accepted {
2157 user_id: user_1,
2158 should_notify: false
2159 },
2160 Contact::Accepted {
2161 user_id: user_2,
2162 should_notify: false,
2163 },
2164 ]
2165 );
2166
2167 // Users send each other concurrent contact requests and
2168 // see that they are immediately accepted.
2169 db.send_contact_request(user_1, user_3).await.unwrap();
2170 db.send_contact_request(user_3, user_1).await.unwrap();
2171 assert_eq!(
2172 db.get_contacts(user_1).await.unwrap(),
2173 &[
2174 Contact::Accepted {
2175 user_id: user_1,
2176 should_notify: false
2177 },
2178 Contact::Accepted {
2179 user_id: user_2,
2180 should_notify: false,
2181 },
2182 Contact::Accepted {
2183 user_id: user_3,
2184 should_notify: false
2185 },
2186 ]
2187 );
2188 assert_eq!(
2189 db.get_contacts(user_3).await.unwrap(),
2190 &[
2191 Contact::Accepted {
2192 user_id: user_1,
2193 should_notify: false
2194 },
2195 Contact::Accepted {
2196 user_id: user_3,
2197 should_notify: false
2198 }
2199 ],
2200 );
2201
2202 // User declines a contact request. Both users see that it is gone.
2203 db.send_contact_request(user_2, user_3).await.unwrap();
2204 db.respond_to_contact_request(user_3, user_2, false)
2205 .await
2206 .unwrap();
2207 assert!(!db.has_contact(user_2, user_3).await.unwrap());
2208 assert!(!db.has_contact(user_3, user_2).await.unwrap());
2209 assert_eq!(
2210 db.get_contacts(user_2).await.unwrap(),
2211 &[
2212 Contact::Accepted {
2213 user_id: user_1,
2214 should_notify: false
2215 },
2216 Contact::Accepted {
2217 user_id: user_2,
2218 should_notify: false
2219 }
2220 ]
2221 );
2222 assert_eq!(
2223 db.get_contacts(user_3).await.unwrap(),
2224 &[
2225 Contact::Accepted {
2226 user_id: user_1,
2227 should_notify: false
2228 },
2229 Contact::Accepted {
2230 user_id: user_3,
2231 should_notify: false
2232 }
2233 ],
2234 );
2235 }
2236 }
2237
2238 #[tokio::test(flavor = "multi_thread")]
2239 async fn test_invite_codes() {
2240 let postgres = TestDb::postgres().await;
2241 let db = postgres.db();
2242 let user1 = db.create_user("user-1", None, false).await.unwrap();
2243
2244 // Initially, user 1 has no invite code
2245 assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);
2246
2247 // Setting invite count to 0 when no code is assigned does not assign a new code
2248 db.set_invite_count(user1, 0).await.unwrap();
2249 assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());
2250
2251 // User 1 creates an invite code that can be used twice.
2252 db.set_invite_count(user1, 2).await.unwrap();
2253 let (invite_code, invite_count) =
2254 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2255 assert_eq!(invite_count, 2);
2256
2257 // User 2 redeems the invite code and becomes a contact of user 1.
2258 let user2 = db
2259 .redeem_invite_code(&invite_code, "user-2", None)
2260 .await
2261 .unwrap();
2262 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2263 assert_eq!(invite_count, 1);
2264 assert_eq!(
2265 db.get_contacts(user1).await.unwrap(),
2266 [
2267 Contact::Accepted {
2268 user_id: user1,
2269 should_notify: false
2270 },
2271 Contact::Accepted {
2272 user_id: user2,
2273 should_notify: true
2274 }
2275 ]
2276 );
2277 assert_eq!(
2278 db.get_contacts(user2).await.unwrap(),
2279 [
2280 Contact::Accepted {
2281 user_id: user1,
2282 should_notify: false
2283 },
2284 Contact::Accepted {
2285 user_id: user2,
2286 should_notify: false
2287 }
2288 ]
2289 );
2290
2291 // User 3 redeems the invite code and becomes a contact of user 1.
2292 let user3 = db
2293 .redeem_invite_code(&invite_code, "user-3", None)
2294 .await
2295 .unwrap();
2296 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2297 assert_eq!(invite_count, 0);
2298 assert_eq!(
2299 db.get_contacts(user1).await.unwrap(),
2300 [
2301 Contact::Accepted {
2302 user_id: user1,
2303 should_notify: false
2304 },
2305 Contact::Accepted {
2306 user_id: user2,
2307 should_notify: true
2308 },
2309 Contact::Accepted {
2310 user_id: user3,
2311 should_notify: true
2312 }
2313 ]
2314 );
2315 assert_eq!(
2316 db.get_contacts(user3).await.unwrap(),
2317 [
2318 Contact::Accepted {
2319 user_id: user1,
2320 should_notify: false
2321 },
2322 Contact::Accepted {
2323 user_id: user3,
2324 should_notify: false
2325 },
2326 ]
2327 );
2328
2329 // Trying to reedem the code for the third time results in an error.
2330 db.redeem_invite_code(&invite_code, "user-4", None)
2331 .await
2332 .unwrap_err();
2333
2334 // Invite count can be updated after the code has been created.
2335 db.set_invite_count(user1, 2).await.unwrap();
2336 let (latest_code, invite_count) =
2337 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2338 assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
2339 assert_eq!(invite_count, 2);
2340
2341 // User 4 can now redeem the invite code and becomes a contact of user 1.
2342 let user4 = db
2343 .redeem_invite_code(&invite_code, "user-4", None)
2344 .await
2345 .unwrap();
2346 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2347 assert_eq!(invite_count, 1);
2348 assert_eq!(
2349 db.get_contacts(user1).await.unwrap(),
2350 [
2351 Contact::Accepted {
2352 user_id: user1,
2353 should_notify: false
2354 },
2355 Contact::Accepted {
2356 user_id: user2,
2357 should_notify: true
2358 },
2359 Contact::Accepted {
2360 user_id: user3,
2361 should_notify: true
2362 },
2363 Contact::Accepted {
2364 user_id: user4,
2365 should_notify: true
2366 }
2367 ]
2368 );
2369 assert_eq!(
2370 db.get_contacts(user4).await.unwrap(),
2371 [
2372 Contact::Accepted {
2373 user_id: user1,
2374 should_notify: false
2375 },
2376 Contact::Accepted {
2377 user_id: user4,
2378 should_notify: false
2379 },
2380 ]
2381 );
2382
2383 // An existing user cannot redeem invite codes.
2384 db.redeem_invite_code(&invite_code, "user-2", None)
2385 .await
2386 .unwrap_err();
2387 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2388 assert_eq!(invite_count, 1);
2389
2390 // Ensure invited users get invite codes too.
2391 assert_eq!(
2392 db.get_invite_code_for_user(user2).await.unwrap().unwrap().1,
2393 5
2394 );
2395 assert_eq!(
2396 db.get_invite_code_for_user(user3).await.unwrap().unwrap().1,
2397 5
2398 );
2399 assert_eq!(
2400 db.get_invite_code_for_user(user4).await.unwrap().unwrap().1,
2401 5
2402 );
2403 }
2404
2405 pub struct TestDb {
2406 pub db: Option<Arc<dyn Db>>,
2407 pub url: String,
2408 }
2409
2410 impl TestDb {
2411 pub async fn postgres() -> Self {
2412 lazy_static! {
2413 static ref LOCK: Mutex<()> = Mutex::new(());
2414 }
2415
2416 let _guard = LOCK.lock();
2417 let mut rng = StdRng::from_entropy();
2418 let name = format!("zed-test-{}", rng.gen::<u128>());
2419 let url = format!("postgres://postgres@localhost/{}", name);
2420 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2421 Postgres::create_database(&url)
2422 .await
2423 .expect("failed to create test db");
2424 let db = PostgresDb::new(&url, 5).await.unwrap();
2425 let migrator = Migrator::new(migrations_path).await.unwrap();
2426 migrator.run(&db.pool).await.unwrap();
2427 Self {
2428 db: Some(Arc::new(db)),
2429 url,
2430 }
2431 }
2432
2433 pub fn fake(background: Arc<Background>) -> Self {
2434 Self {
2435 db: Some(Arc::new(FakeDb::new(background))),
2436 url: Default::default(),
2437 }
2438 }
2439
2440 pub fn db(&self) -> &Arc<dyn Db> {
2441 self.db.as_ref().unwrap()
2442 }
2443 }
2444
2445 impl Drop for TestDb {
2446 fn drop(&mut self) {
2447 if let Some(db) = self.db.take() {
2448 futures::executor::block_on(db.teardown(&self.url));
2449 }
2450 }
2451 }
2452
2453 pub struct FakeDb {
2454 background: Arc<Background>,
2455 pub users: Mutex<BTreeMap<UserId, User>>,
2456 pub projects: Mutex<BTreeMap<ProjectId, Project>>,
2457 pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
2458 pub orgs: Mutex<BTreeMap<OrgId, Org>>,
2459 pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
2460 pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
2461 pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
2462 pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
2463 pub contacts: Mutex<Vec<FakeContact>>,
2464 next_channel_message_id: Mutex<i32>,
2465 next_user_id: Mutex<i32>,
2466 next_org_id: Mutex<i32>,
2467 next_channel_id: Mutex<i32>,
2468 next_project_id: Mutex<i32>,
2469 }
2470
2471 #[derive(Debug)]
2472 pub struct FakeContact {
2473 pub requester_id: UserId,
2474 pub responder_id: UserId,
2475 pub accepted: bool,
2476 pub should_notify: bool,
2477 }
2478
2479 impl FakeDb {
2480 pub fn new(background: Arc<Background>) -> Self {
2481 Self {
2482 background,
2483 users: Default::default(),
2484 next_user_id: Mutex::new(0),
2485 projects: Default::default(),
2486 worktree_extensions: Default::default(),
2487 next_project_id: Mutex::new(1),
2488 orgs: Default::default(),
2489 next_org_id: Mutex::new(1),
2490 org_memberships: Default::default(),
2491 channels: Default::default(),
2492 next_channel_id: Mutex::new(1),
2493 channel_memberships: Default::default(),
2494 channel_messages: Default::default(),
2495 next_channel_message_id: Mutex::new(1),
2496 contacts: Default::default(),
2497 }
2498 }
2499 }
2500
2501 #[async_trait]
2502 impl Db for FakeDb {
2503 async fn create_user(
2504 &self,
2505 github_login: &str,
2506 email_address: Option<&str>,
2507 admin: bool,
2508 ) -> Result<UserId> {
2509 self.background.simulate_random_delay().await;
2510
2511 let mut users = self.users.lock();
2512 if let Some(user) = users
2513 .values()
2514 .find(|user| user.github_login == github_login)
2515 {
2516 Ok(user.id)
2517 } else {
2518 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2519 users.insert(
2520 user_id,
2521 User {
2522 id: user_id,
2523 github_login: github_login.to_string(),
2524 email_address: email_address.map(str::to_string),
2525 admin,
2526 invite_code: None,
2527 invite_count: 0,
2528 connected_once: false,
2529 },
2530 );
2531 Ok(user_id)
2532 }
2533 }
2534
2535 async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
2536 unimplemented!()
2537 }
2538
2539 async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
2540 unimplemented!()
2541 }
2542
2543 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
2544 unimplemented!()
2545 }
2546
2547 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2548 self.background.simulate_random_delay().await;
2549 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2550 }
2551
2552 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2553 self.background.simulate_random_delay().await;
2554 let users = self.users.lock();
2555 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2556 }
2557
2558 async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
2559 unimplemented!()
2560 }
2561
2562 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2563 self.background.simulate_random_delay().await;
2564 Ok(self
2565 .users
2566 .lock()
2567 .values()
2568 .find(|user| user.github_login == github_login)
2569 .cloned())
2570 }
2571
2572 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
2573 unimplemented!()
2574 }
2575
2576 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2577 self.background.simulate_random_delay().await;
2578 let mut users = self.users.lock();
2579 let mut user = users
2580 .get_mut(&id)
2581 .ok_or_else(|| anyhow!("user not found"))?;
2582 user.connected_once = connected_once;
2583 Ok(())
2584 }
2585
2586 async fn destroy_user(&self, _id: UserId) -> Result<()> {
2587 unimplemented!()
2588 }
2589
2590 // invite codes
2591
2592 async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
2593 unimplemented!()
2594 }
2595
2596 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
2597 self.background.simulate_random_delay().await;
2598 Ok(None)
2599 }
2600
2601 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
2602 unimplemented!()
2603 }
2604
2605 async fn redeem_invite_code(
2606 &self,
2607 _code: &str,
2608 _login: &str,
2609 _email_address: Option<&str>,
2610 ) -> Result<UserId> {
2611 unimplemented!()
2612 }
2613
2614 // projects
2615
2616 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2617 self.background.simulate_random_delay().await;
2618 if !self.users.lock().contains_key(&host_user_id) {
2619 Err(anyhow!("no such user"))?;
2620 }
2621
2622 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2623 self.projects.lock().insert(
2624 project_id,
2625 Project {
2626 id: project_id,
2627 host_user_id,
2628 unregistered: false,
2629 },
2630 );
2631 Ok(project_id)
2632 }
2633
2634 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2635 self.background.simulate_random_delay().await;
2636 self.projects
2637 .lock()
2638 .get_mut(&project_id)
2639 .ok_or_else(|| anyhow!("no such project"))?
2640 .unregistered = true;
2641 Ok(())
2642 }
2643
2644 async fn update_worktree_extensions(
2645 &self,
2646 project_id: ProjectId,
2647 worktree_id: u64,
2648 extensions: HashMap<String, u32>,
2649 ) -> Result<()> {
2650 self.background.simulate_random_delay().await;
2651 if !self.projects.lock().contains_key(&project_id) {
2652 Err(anyhow!("no such project"))?;
2653 }
2654
2655 for (extension, count) in extensions {
2656 self.worktree_extensions
2657 .lock()
2658 .insert((project_id, worktree_id, extension), count);
2659 }
2660
2661 Ok(())
2662 }
2663
2664 async fn get_project_extensions(
2665 &self,
2666 _project_id: ProjectId,
2667 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
2668 unimplemented!()
2669 }
2670
2671 async fn record_user_activity(
2672 &self,
2673 _time_period: Range<OffsetDateTime>,
2674 _active_projects: &[(UserId, ProjectId)],
2675 ) -> Result<()> {
2676 unimplemented!()
2677 }
2678
2679 async fn get_active_user_count(
2680 &self,
2681 _time_period: Range<OffsetDateTime>,
2682 _min_duration: Duration,
2683 _only_collaborative: bool,
2684 ) -> Result<usize> {
2685 unimplemented!()
2686 }
2687
2688 async fn get_top_users_activity_summary(
2689 &self,
2690 _time_period: Range<OffsetDateTime>,
2691 _limit: usize,
2692 ) -> Result<Vec<UserActivitySummary>> {
2693 unimplemented!()
2694 }
2695
2696 async fn get_user_activity_timeline(
2697 &self,
2698 _time_period: Range<OffsetDateTime>,
2699 _user_id: UserId,
2700 ) -> Result<Vec<UserActivityPeriod>> {
2701 unimplemented!()
2702 }
2703
2704 // contacts
2705
2706 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2707 self.background.simulate_random_delay().await;
2708 let mut contacts = vec![Contact::Accepted {
2709 user_id: id,
2710 should_notify: false,
2711 }];
2712
2713 for contact in self.contacts.lock().iter() {
2714 if contact.requester_id == id {
2715 if contact.accepted {
2716 contacts.push(Contact::Accepted {
2717 user_id: contact.responder_id,
2718 should_notify: contact.should_notify,
2719 });
2720 } else {
2721 contacts.push(Contact::Outgoing {
2722 user_id: contact.responder_id,
2723 });
2724 }
2725 } else if contact.responder_id == id {
2726 if contact.accepted {
2727 contacts.push(Contact::Accepted {
2728 user_id: contact.requester_id,
2729 should_notify: false,
2730 });
2731 } else {
2732 contacts.push(Contact::Incoming {
2733 user_id: contact.requester_id,
2734 should_notify: contact.should_notify,
2735 });
2736 }
2737 }
2738 }
2739
2740 contacts.sort_unstable_by_key(|contact| contact.user_id());
2741 Ok(contacts)
2742 }
2743
2744 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2745 self.background.simulate_random_delay().await;
2746 Ok(self.contacts.lock().iter().any(|contact| {
2747 contact.accepted
2748 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2749 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2750 }))
2751 }
2752
2753 async fn send_contact_request(
2754 &self,
2755 requester_id: UserId,
2756 responder_id: UserId,
2757 ) -> Result<()> {
2758 self.background.simulate_random_delay().await;
2759 let mut contacts = self.contacts.lock();
2760 for contact in contacts.iter_mut() {
2761 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2762 if contact.accepted {
2763 Err(anyhow!("contact already exists"))?;
2764 } else {
2765 Err(anyhow!("contact already requested"))?;
2766 }
2767 }
2768 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2769 if contact.accepted {
2770 Err(anyhow!("contact already exists"))?;
2771 } else {
2772 contact.accepted = true;
2773 contact.should_notify = false;
2774 return Ok(());
2775 }
2776 }
2777 }
2778 contacts.push(FakeContact {
2779 requester_id,
2780 responder_id,
2781 accepted: false,
2782 should_notify: true,
2783 });
2784 Ok(())
2785 }
2786
2787 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2788 self.background.simulate_random_delay().await;
2789 self.contacts.lock().retain(|contact| {
2790 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2791 });
2792 Ok(())
2793 }
2794
2795 async fn dismiss_contact_notification(
2796 &self,
2797 user_id: UserId,
2798 contact_user_id: UserId,
2799 ) -> Result<()> {
2800 self.background.simulate_random_delay().await;
2801 let mut contacts = self.contacts.lock();
2802 for contact in contacts.iter_mut() {
2803 if contact.requester_id == contact_user_id
2804 && contact.responder_id == user_id
2805 && !contact.accepted
2806 {
2807 contact.should_notify = false;
2808 return Ok(());
2809 }
2810 if contact.requester_id == user_id
2811 && contact.responder_id == contact_user_id
2812 && contact.accepted
2813 {
2814 contact.should_notify = false;
2815 return Ok(());
2816 }
2817 }
2818 Err(anyhow!("no such notification"))?
2819 }
2820
2821 async fn respond_to_contact_request(
2822 &self,
2823 responder_id: UserId,
2824 requester_id: UserId,
2825 accept: bool,
2826 ) -> Result<()> {
2827 self.background.simulate_random_delay().await;
2828 let mut contacts = self.contacts.lock();
2829 for (ix, contact) in contacts.iter_mut().enumerate() {
2830 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2831 if contact.accepted {
2832 Err(anyhow!("contact already confirmed"))?;
2833 }
2834 if accept {
2835 contact.accepted = true;
2836 contact.should_notify = true;
2837 } else {
2838 contacts.remove(ix);
2839 }
2840 return Ok(());
2841 }
2842 }
2843 Err(anyhow!("no such contact request"))?
2844 }
2845
2846 async fn create_access_token_hash(
2847 &self,
2848 _user_id: UserId,
2849 _access_token_hash: &str,
2850 _max_access_token_count: usize,
2851 ) -> Result<()> {
2852 unimplemented!()
2853 }
2854
2855 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
2856 unimplemented!()
2857 }
2858
2859 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2860 unimplemented!()
2861 }
2862
2863 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2864 self.background.simulate_random_delay().await;
2865 let mut orgs = self.orgs.lock();
2866 if orgs.values().any(|org| org.slug == slug) {
2867 Err(anyhow!("org already exists"))?
2868 } else {
2869 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2870 orgs.insert(
2871 org_id,
2872 Org {
2873 id: org_id,
2874 name: name.to_string(),
2875 slug: slug.to_string(),
2876 },
2877 );
2878 Ok(org_id)
2879 }
2880 }
2881
2882 async fn add_org_member(
2883 &self,
2884 org_id: OrgId,
2885 user_id: UserId,
2886 is_admin: bool,
2887 ) -> Result<()> {
2888 self.background.simulate_random_delay().await;
2889 if !self.orgs.lock().contains_key(&org_id) {
2890 Err(anyhow!("org does not exist"))?;
2891 }
2892 if !self.users.lock().contains_key(&user_id) {
2893 Err(anyhow!("user does not exist"))?;
2894 }
2895
2896 self.org_memberships
2897 .lock()
2898 .entry((org_id, user_id))
2899 .or_insert(is_admin);
2900 Ok(())
2901 }
2902
2903 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2904 self.background.simulate_random_delay().await;
2905 if !self.orgs.lock().contains_key(&org_id) {
2906 Err(anyhow!("org does not exist"))?;
2907 }
2908
2909 let mut channels = self.channels.lock();
2910 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2911 channels.insert(
2912 channel_id,
2913 Channel {
2914 id: channel_id,
2915 name: name.to_string(),
2916 owner_id: org_id.0,
2917 owner_is_user: false,
2918 },
2919 );
2920 Ok(channel_id)
2921 }
2922
2923 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2924 self.background.simulate_random_delay().await;
2925 Ok(self
2926 .channels
2927 .lock()
2928 .values()
2929 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2930 .cloned()
2931 .collect())
2932 }
2933
2934 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2935 self.background.simulate_random_delay().await;
2936 let channels = self.channels.lock();
2937 let memberships = self.channel_memberships.lock();
2938 Ok(channels
2939 .values()
2940 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2941 .cloned()
2942 .collect())
2943 }
2944
2945 async fn can_user_access_channel(
2946 &self,
2947 user_id: UserId,
2948 channel_id: ChannelId,
2949 ) -> Result<bool> {
2950 self.background.simulate_random_delay().await;
2951 Ok(self
2952 .channel_memberships
2953 .lock()
2954 .contains_key(&(channel_id, user_id)))
2955 }
2956
2957 async fn add_channel_member(
2958 &self,
2959 channel_id: ChannelId,
2960 user_id: UserId,
2961 is_admin: bool,
2962 ) -> Result<()> {
2963 self.background.simulate_random_delay().await;
2964 if !self.channels.lock().contains_key(&channel_id) {
2965 Err(anyhow!("channel does not exist"))?;
2966 }
2967 if !self.users.lock().contains_key(&user_id) {
2968 Err(anyhow!("user does not exist"))?;
2969 }
2970
2971 self.channel_memberships
2972 .lock()
2973 .entry((channel_id, user_id))
2974 .or_insert(is_admin);
2975 Ok(())
2976 }
2977
2978 async fn create_channel_message(
2979 &self,
2980 channel_id: ChannelId,
2981 sender_id: UserId,
2982 body: &str,
2983 timestamp: OffsetDateTime,
2984 nonce: u128,
2985 ) -> Result<MessageId> {
2986 self.background.simulate_random_delay().await;
2987 if !self.channels.lock().contains_key(&channel_id) {
2988 Err(anyhow!("channel does not exist"))?;
2989 }
2990 if !self.users.lock().contains_key(&sender_id) {
2991 Err(anyhow!("user does not exist"))?;
2992 }
2993
2994 let mut messages = self.channel_messages.lock();
2995 if let Some(message) = messages
2996 .values()
2997 .find(|message| message.nonce.as_u128() == nonce)
2998 {
2999 Ok(message.id)
3000 } else {
3001 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
3002 messages.insert(
3003 message_id,
3004 ChannelMessage {
3005 id: message_id,
3006 channel_id,
3007 sender_id,
3008 body: body.to_string(),
3009 sent_at: timestamp,
3010 nonce: Uuid::from_u128(nonce),
3011 },
3012 );
3013 Ok(message_id)
3014 }
3015 }
3016
3017 async fn get_channel_messages(
3018 &self,
3019 channel_id: ChannelId,
3020 count: usize,
3021 before_id: Option<MessageId>,
3022 ) -> Result<Vec<ChannelMessage>> {
3023 self.background.simulate_random_delay().await;
3024 let mut messages = self
3025 .channel_messages
3026 .lock()
3027 .values()
3028 .rev()
3029 .filter(|message| {
3030 message.channel_id == channel_id
3031 && message.id < before_id.unwrap_or(MessageId::MAX)
3032 })
3033 .take(count)
3034 .cloned()
3035 .collect::<Vec<_>>();
3036 messages.sort_unstable_by_key(|message| message.id);
3037 Ok(messages)
3038 }
3039
3040 async fn teardown(&self, _: &str) {}
3041
3042 #[cfg(test)]
3043 fn as_fake(&self) -> Option<&FakeDb> {
3044 Some(self)
3045 }
3046 }
3047
3048 fn build_background_executor() -> Arc<Background> {
3049 Deterministic::new(0).build_background()
3050 }
3051}