1use std::{cmp, ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use serde::{Deserialize, Serialize};
10pub use sqlx::postgres::PgPoolOptions as DbOptions;
11use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
12use time::{OffsetDateTime, PrimitiveDateTime};
13
/// Storage interface used by the collaboration server.
///
/// `PostgresDb` (below) implements it against a Postgres pool; tests can
/// also provide an in-memory implementation (see `as_fake`).
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user with the given GitHub login and returns its id.
    ///
    /// NOTE(review): the Postgres implementation returns the existing user's
    /// id (leaving email/admin untouched) when the login is already taken.
    async fn create_user(
        &self,
        github_login: &str,
        email_address: Option<&str>,
        admin: bool,
    ) -> Result<UserId>;
    /// Returns page `page` (zero-based) of users, with `limit` users per page.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Bulk-creates users from `(github_login, email_address, invite_count)`
    /// tuples and returns their ids.
    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
    /// Returns at most `limit` users whose GitHub login fuzzily matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Returns the user with the given id, if one exists.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose ids appear in `ids`; unknown ids are skipped.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns users with no invites left, restricted to users that either do
    /// (`invited_by_another_user == true`) or do not have an inviter.
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    /// Returns the user with the given GitHub login, if one exists.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Grants or revokes admin status for the user.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Records whether the user has connected at least once.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes the user along with their access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Sets how many invites the user may still send. In the Postgres
    /// implementation a positive count also assigns an invite code if the user
    /// does not have one yet.
    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, or `None`
    /// if the user has no invite code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user who owns the invite `code`; errors if no user has it.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Consumes one invite belonging to `code`, creates a user for `login`,
    /// connects the inviter and the new user as accepted contacts, and returns
    /// the new user's id.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, u32>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_user_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the number of users who have been active in the given
    /// time period for at least the given time duration.
    async fn get_active_user_count(
        &self,
        time_period: Range<OffsetDateTime>,
        min_duration: Duration,
        only_collaborative: bool,
    ) -> Result<usize>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn get_top_users_activity_summary(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Get the project activity for the given user and time period.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>>;

    /// Returns every contact relationship of the user: accepted contacts plus
    /// pending incoming and outgoing requests.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users are accepted contacts.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Sends a contact request from `requester_id` to `responder_id`. In the
    /// Postgres implementation, a pending request in the opposite direction is
    /// accepted instead.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Removes the contact relationship between the two users.
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Clears the pending notification flag on the contact relationship
    /// between the two users.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    /// Accepts (when `accept` is true) or declines the contact request that
    /// `requester_id` sent to `responder_id`.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    /// Stores a new access-token hash for the user, keeping at most
    /// `max_access_token_count` of their most recent hashes.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the user's access-token hashes, ordered by descending id.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    /// Looks up an organization by its slug (tests / seeding only).
    #[cfg(any(test, feature = "seed-support"))]

    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an organization and returns its id (tests / seeding only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds a user to an organization (tests / seeding only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Creates a channel owned by an organization (tests / seeding only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Lists the channels owned by an organization (tests / seeding only).
    #[cfg(any(test, feature = "seed-support"))]

    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Lists the channels the user is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether the user is a member of the channel.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds a user to a channel (tests / seeding only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a channel message and returns its id. In the Postgres
    /// implementation, re-sending with the same `nonce` returns the existing
    /// message's id instead of inserting a duplicate.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` channel messages with ids below `before_id` (or
    /// the newest messages when `None`), ordered by ascending id.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only cleanup hook for the database at `url`.
    #[cfg(test)]
    async fn teardown(&self, url: &str);
    /// Test-only downcast: returns the fake implementation if this is one.
    #[cfg(test)]
    fn as_fake(&self) -> Option<&tests::FakeDb>;
}
159
160pub struct PostgresDb {
161 pool: sqlx::PgPool,
162}
163
164impl PostgresDb {
165 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
166 let pool = DbOptions::new()
167 .max_connections(max_connections)
168 .connect(url)
169 .await
170 .context("failed to connect to postgres database")?;
171 Ok(Self { pool })
172 }
173}
174
175#[async_trait]
176impl Db for PostgresDb {
177 // users
178
179 async fn create_user(
180 &self,
181 github_login: &str,
182 email_address: Option<&str>,
183 admin: bool,
184 ) -> Result<UserId> {
185 let query = "
186 INSERT INTO users (github_login, email_address, admin)
187 VALUES ($1, $2, $3)
188 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
189 RETURNING id
190 ";
191 Ok(sqlx::query_scalar(query)
192 .bind(github_login)
193 .bind(email_address)
194 .bind(admin)
195 .fetch_one(&self.pool)
196 .await
197 .map(UserId)?)
198 }
199
200 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
201 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
202 Ok(sqlx::query_as(query)
203 .bind(limit as i32)
204 .bind((page * limit) as i32)
205 .fetch_all(&self.pool)
206 .await?)
207 }
208
209 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
210 let mut query = QueryBuilder::new(
211 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
212 );
213 query.push_values(
214 users,
215 |mut query, (github_login, email_address, invite_count)| {
216 query
217 .push_bind(github_login)
218 .push_bind(email_address)
219 .push_bind(false)
220 .push_bind(random_invite_code())
221 .push_bind(invite_count as i32);
222 },
223 );
224 query.push(
225 "
226 ON CONFLICT (github_login) DO UPDATE SET
227 github_login = excluded.github_login,
228 invite_count = excluded.invite_count,
229 invite_code = CASE WHEN users.invite_code IS NULL
230 THEN excluded.invite_code
231 ELSE users.invite_code
232 END
233 RETURNING id
234 ",
235 );
236
237 let rows = query.build().fetch_all(&self.pool).await?;
238 Ok(rows
239 .into_iter()
240 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
241 .collect())
242 }
243
244 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
245 let like_string = fuzzy_like_string(name_query);
246 let query = "
247 SELECT users.*
248 FROM users
249 WHERE github_login ILIKE $1
250 ORDER BY github_login <-> $2
251 LIMIT $3
252 ";
253 Ok(sqlx::query_as(query)
254 .bind(like_string)
255 .bind(name_query)
256 .bind(limit as i32)
257 .fetch_all(&self.pool)
258 .await?)
259 }
260
261 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
262 let users = self.get_users_by_ids(vec![id]).await?;
263 Ok(users.into_iter().next())
264 }
265
266 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
267 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
268 let query = "
269 SELECT users.*
270 FROM users
271 WHERE users.id = ANY ($1)
272 ";
273 Ok(sqlx::query_as(query)
274 .bind(&ids)
275 .fetch_all(&self.pool)
276 .await?)
277 }
278
279 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
280 let query = format!(
281 "
282 SELECT users.*
283 FROM users
284 WHERE invite_count = 0
285 AND inviter_id IS{} NULL
286 ",
287 if invited_by_another_user { " NOT" } else { "" }
288 );
289
290 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
291 }
292
293 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
294 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
295 Ok(sqlx::query_as(query)
296 .bind(github_login)
297 .fetch_optional(&self.pool)
298 .await?)
299 }
300
301 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
302 let query = "UPDATE users SET admin = $1 WHERE id = $2";
303 Ok(sqlx::query(query)
304 .bind(is_admin)
305 .bind(id.0)
306 .execute(&self.pool)
307 .await
308 .map(drop)?)
309 }
310
311 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
312 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
313 Ok(sqlx::query(query)
314 .bind(connected_once)
315 .bind(id.0)
316 .execute(&self.pool)
317 .await
318 .map(drop)?)
319 }
320
321 async fn destroy_user(&self, id: UserId) -> Result<()> {
322 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
323 sqlx::query(query)
324 .bind(id.0)
325 .execute(&self.pool)
326 .await
327 .map(drop)?;
328 let query = "DELETE FROM users WHERE id = $1;";
329 Ok(sqlx::query(query)
330 .bind(id.0)
331 .execute(&self.pool)
332 .await
333 .map(drop)?)
334 }
335
336 // invite codes
337
338 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
339 let mut tx = self.pool.begin().await?;
340 if count > 0 {
341 sqlx::query(
342 "
343 UPDATE users
344 SET invite_code = $1
345 WHERE id = $2 AND invite_code IS NULL
346 ",
347 )
348 .bind(random_invite_code())
349 .bind(id)
350 .execute(&mut tx)
351 .await?;
352 }
353
354 sqlx::query(
355 "
356 UPDATE users
357 SET invite_count = $1
358 WHERE id = $2
359 ",
360 )
361 .bind(count as i32)
362 .bind(id)
363 .execute(&mut tx)
364 .await?;
365 tx.commit().await?;
366 Ok(())
367 }
368
369 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
370 let result: Option<(String, i32)> = sqlx::query_as(
371 "
372 SELECT invite_code, invite_count
373 FROM users
374 WHERE id = $1 AND invite_code IS NOT NULL
375 ",
376 )
377 .bind(id)
378 .fetch_optional(&self.pool)
379 .await?;
380 if let Some((code, count)) = result {
381 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
382 } else {
383 Ok(None)
384 }
385 }
386
387 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
388 sqlx::query_as(
389 "
390 SELECT *
391 FROM users
392 WHERE invite_code = $1
393 ",
394 )
395 .bind(code)
396 .fetch_optional(&self.pool)
397 .await?
398 .ok_or_else(|| {
399 Error::Http(
400 StatusCode::NOT_FOUND,
401 "that invite code does not exist".to_string(),
402 )
403 })
404 }
405
406 async fn redeem_invite_code(
407 &self,
408 code: &str,
409 login: &str,
410 email_address: Option<&str>,
411 ) -> Result<UserId> {
412 let mut tx = self.pool.begin().await?;
413
414 let inviter_id: Option<UserId> = sqlx::query_scalar(
415 "
416 UPDATE users
417 SET invite_count = invite_count - 1
418 WHERE
419 invite_code = $1 AND
420 invite_count > 0
421 RETURNING id
422 ",
423 )
424 .bind(code)
425 .fetch_optional(&mut tx)
426 .await?;
427
428 let inviter_id = match inviter_id {
429 Some(inviter_id) => inviter_id,
430 None => {
431 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
432 .bind(code)
433 .fetch_optional(&mut tx)
434 .await?
435 .is_some()
436 {
437 Err(Error::Http(
438 StatusCode::UNAUTHORIZED,
439 "no invites remaining".to_string(),
440 ))?
441 } else {
442 Err(Error::Http(
443 StatusCode::NOT_FOUND,
444 "invite code not found".to_string(),
445 ))?
446 }
447 }
448 };
449
450 let invitee_id = sqlx::query_scalar(
451 "
452 INSERT INTO users
453 (github_login, email_address, admin, inviter_id, invite_code, invite_count)
454 VALUES
455 ($1, $2, 'f', $3, $4, $5)
456 RETURNING id
457 ",
458 )
459 .bind(login)
460 .bind(email_address)
461 .bind(inviter_id)
462 .bind(random_invite_code())
463 .bind(5)
464 .fetch_one(&mut tx)
465 .await
466 .map(UserId)?;
467
468 sqlx::query(
469 "
470 INSERT INTO contacts
471 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
472 VALUES
473 ($1, $2, 't', 't', 't')
474 ",
475 )
476 .bind(inviter_id)
477 .bind(invitee_id)
478 .execute(&mut tx)
479 .await?;
480
481 tx.commit().await?;
482 Ok(invitee_id)
483 }
484
485 // projects
486
487 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
488 Ok(sqlx::query_scalar(
489 "
490 INSERT INTO projects(host_user_id)
491 VALUES ($1)
492 RETURNING id
493 ",
494 )
495 .bind(host_user_id)
496 .fetch_one(&self.pool)
497 .await
498 .map(ProjectId)?)
499 }
500
501 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
502 sqlx::query(
503 "
504 UPDATE projects
505 SET unregistered = 't'
506 WHERE id = $1
507 ",
508 )
509 .bind(project_id)
510 .execute(&self.pool)
511 .await?;
512 Ok(())
513 }
514
515 async fn update_worktree_extensions(
516 &self,
517 project_id: ProjectId,
518 worktree_id: u64,
519 extensions: HashMap<String, u32>,
520 ) -> Result<()> {
521 if extensions.is_empty() {
522 return Ok(());
523 }
524
525 let mut query = QueryBuilder::new(
526 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
527 );
528 query.push_values(extensions, |mut query, (extension, count)| {
529 query
530 .push_bind(project_id)
531 .push_bind(worktree_id as i32)
532 .push_bind(extension)
533 .push_bind(count as i32);
534 });
535 query.push(
536 "
537 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
538 count = excluded.count
539 ",
540 );
541 query.build().execute(&self.pool).await?;
542
543 Ok(())
544 }
545
546 async fn get_project_extensions(
547 &self,
548 project_id: ProjectId,
549 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
550 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
551 struct WorktreeExtension {
552 worktree_id: i32,
553 extension: String,
554 count: i32,
555 }
556
557 let query = "
558 SELECT worktree_id, extension, count
559 FROM worktree_extensions
560 WHERE project_id = $1
561 ";
562 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
563 .bind(&project_id)
564 .fetch_all(&self.pool)
565 .await?;
566
567 let mut extension_counts = HashMap::default();
568 for count in counts {
569 extension_counts
570 .entry(count.worktree_id as u64)
571 .or_insert_with(HashMap::default)
572 .insert(count.extension, count.count as usize);
573 }
574 Ok(extension_counts)
575 }
576
577 async fn record_user_activity(
578 &self,
579 time_period: Range<OffsetDateTime>,
580 projects: &[(UserId, ProjectId)],
581 ) -> Result<()> {
582 let query = "
583 INSERT INTO project_activity_periods
584 (ended_at, duration_millis, user_id, project_id)
585 VALUES
586 ($1, $2, $3, $4);
587 ";
588
589 let mut tx = self.pool.begin().await?;
590 let duration_millis =
591 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
592 for (user_id, project_id) in projects {
593 sqlx::query(query)
594 .bind(time_period.end)
595 .bind(duration_millis)
596 .bind(user_id)
597 .bind(project_id)
598 .execute(&mut tx)
599 .await?;
600 }
601 tx.commit().await?;
602
603 Ok(())
604 }
605
606 async fn get_active_user_count(
607 &self,
608 time_period: Range<OffsetDateTime>,
609 min_duration: Duration,
610 only_collaborative: bool,
611 ) -> Result<usize> {
612 let mut with_clause = String::new();
613 with_clause.push_str("WITH\n");
614 with_clause.push_str(
615 "
616 project_durations AS (
617 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
618 FROM project_activity_periods
619 WHERE $1 < ended_at AND ended_at <= $2
620 GROUP BY user_id, project_id
621 ),
622 ",
623 );
624 with_clause.push_str(
625 "
626 project_collaborators as (
627 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
628 FROM project_durations
629 GROUP BY project_id
630 ),
631 ",
632 );
633
634 if only_collaborative {
635 with_clause.push_str(
636 "
637 user_durations AS (
638 SELECT user_id, SUM(project_duration) as total_duration
639 FROM project_durations, project_collaborators
640 WHERE
641 project_durations.project_id = project_collaborators.project_id AND
642 max_collaborators > 1
643 GROUP BY user_id
644 ORDER BY total_duration DESC
645 LIMIT $3
646 )
647 ",
648 );
649 } else {
650 with_clause.push_str(
651 "
652 user_durations AS (
653 SELECT user_id, SUM(project_duration) as total_duration
654 FROM project_durations
655 GROUP BY user_id
656 ORDER BY total_duration DESC
657 LIMIT $3
658 )
659 ",
660 );
661 }
662
663 let query = format!(
664 "
665 {with_clause}
666 SELECT count(user_durations.user_id)
667 FROM user_durations
668 WHERE user_durations.total_duration >= $3
669 "
670 );
671
672 let count: i64 = sqlx::query_scalar(&query)
673 .bind(time_period.start)
674 .bind(time_period.end)
675 .bind(min_duration.as_millis() as i64)
676 .fetch_one(&self.pool)
677 .await?;
678 Ok(count as usize)
679 }
680
681 async fn get_top_users_activity_summary(
682 &self,
683 time_period: Range<OffsetDateTime>,
684 max_user_count: usize,
685 ) -> Result<Vec<UserActivitySummary>> {
686 let query = "
687 WITH
688 project_durations AS (
689 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
690 FROM project_activity_periods
691 WHERE $1 < ended_at AND ended_at <= $2
692 GROUP BY user_id, project_id
693 ),
694 user_durations AS (
695 SELECT user_id, SUM(project_duration) as total_duration
696 FROM project_durations
697 GROUP BY user_id
698 ORDER BY total_duration DESC
699 LIMIT $3
700 ),
701 project_collaborators as (
702 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
703 FROM project_durations
704 GROUP BY project_id
705 )
706 SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
707 FROM user_durations, project_durations, project_collaborators, users
708 WHERE
709 user_durations.user_id = project_durations.user_id AND
710 user_durations.user_id = users.id AND
711 project_durations.project_id = project_collaborators.project_id
712 ORDER BY total_duration DESC, user_id ASC, project_id ASC
713 ";
714
715 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
716 .bind(time_period.start)
717 .bind(time_period.end)
718 .bind(max_user_count as i32)
719 .fetch(&self.pool);
720
721 let mut result = Vec::<UserActivitySummary>::new();
722 while let Some(row) = rows.next().await {
723 let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
724 let project_id = project_id;
725 let duration = Duration::from_millis(duration_millis as u64);
726 let project_activity = ProjectActivitySummary {
727 id: project_id,
728 duration,
729 max_collaborators: project_collaborators as usize,
730 };
731 if let Some(last_summary) = result.last_mut() {
732 if last_summary.id == user_id {
733 last_summary.project_activity.push(project_activity);
734 continue;
735 }
736 }
737 result.push(UserActivitySummary {
738 id: user_id,
739 project_activity: vec![project_activity],
740 github_login,
741 });
742 }
743
744 Ok(result)
745 }
746
747 async fn get_user_activity_timeline(
748 &self,
749 time_period: Range<OffsetDateTime>,
750 user_id: UserId,
751 ) -> Result<Vec<UserActivityPeriod>> {
752 const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
753
754 let query = "
755 SELECT
756 project_activity_periods.ended_at,
757 project_activity_periods.duration_millis,
758 project_activity_periods.project_id,
759 worktree_extensions.extension,
760 worktree_extensions.count
761 FROM project_activity_periods
762 LEFT OUTER JOIN
763 worktree_extensions
764 ON
765 project_activity_periods.project_id = worktree_extensions.project_id
766 WHERE
767 project_activity_periods.user_id = $1 AND
768 $2 < project_activity_periods.ended_at AND
769 project_activity_periods.ended_at <= $3
770 ORDER BY project_activity_periods.id ASC
771 ";
772
773 let mut rows = sqlx::query_as::<
774 _,
775 (
776 PrimitiveDateTime,
777 i32,
778 ProjectId,
779 Option<String>,
780 Option<i32>,
781 ),
782 >(query)
783 .bind(user_id)
784 .bind(time_period.start)
785 .bind(time_period.end)
786 .fetch(&self.pool);
787
788 let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
789 while let Some(row) = rows.next().await {
790 let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
791 let ended_at = ended_at.assume_utc();
792 let duration = Duration::from_millis(duration_millis as u64);
793 let started_at = ended_at - duration;
794 let project_time_periods = time_periods.entry(project_id).or_default();
795
796 if let Some(prev_duration) = project_time_periods.last_mut() {
797 if started_at <= prev_duration.end + COALESCE_THRESHOLD
798 && ended_at >= prev_duration.start
799 {
800 prev_duration.end = cmp::max(prev_duration.end, ended_at);
801 } else {
802 project_time_periods.push(UserActivityPeriod {
803 project_id,
804 start: started_at,
805 end: ended_at,
806 extensions: Default::default(),
807 });
808 }
809 } else {
810 project_time_periods.push(UserActivityPeriod {
811 project_id,
812 start: started_at,
813 end: ended_at,
814 extensions: Default::default(),
815 });
816 }
817
818 if let Some((extension, extension_count)) = extension.zip(extension_count) {
819 project_time_periods
820 .last_mut()
821 .unwrap()
822 .extensions
823 .insert(extension, extension_count as usize);
824 }
825 }
826
827 let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
828 durations.sort_unstable_by_key(|duration| duration.start);
829 Ok(durations)
830 }
831
832 // contacts
833
834 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
835 let query = "
836 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
837 FROM contacts
838 WHERE user_id_a = $1 OR user_id_b = $1;
839 ";
840
841 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
842 .bind(user_id)
843 .fetch(&self.pool);
844
845 let mut contacts = Vec::new();
846 while let Some(row) = rows.next().await {
847 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
848
849 if user_id_a == user_id {
850 if accepted {
851 contacts.push(Contact::Accepted {
852 user_id: user_id_b,
853 should_notify: should_notify && a_to_b,
854 });
855 } else if a_to_b {
856 contacts.push(Contact::Outgoing { user_id: user_id_b })
857 } else {
858 contacts.push(Contact::Incoming {
859 user_id: user_id_b,
860 should_notify,
861 });
862 }
863 } else if accepted {
864 contacts.push(Contact::Accepted {
865 user_id: user_id_a,
866 should_notify: should_notify && !a_to_b,
867 });
868 } else if a_to_b {
869 contacts.push(Contact::Incoming {
870 user_id: user_id_a,
871 should_notify,
872 });
873 } else {
874 contacts.push(Contact::Outgoing { user_id: user_id_a });
875 }
876 }
877
878 contacts.sort_unstable_by_key(|contact| contact.user_id());
879
880 Ok(contacts)
881 }
882
883 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
884 let (id_a, id_b) = if user_id_1 < user_id_2 {
885 (user_id_1, user_id_2)
886 } else {
887 (user_id_2, user_id_1)
888 };
889
890 let query = "
891 SELECT 1 FROM contacts
892 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
893 LIMIT 1
894 ";
895 Ok(sqlx::query_scalar::<_, i32>(query)
896 .bind(id_a.0)
897 .bind(id_b.0)
898 .fetch_optional(&self.pool)
899 .await?
900 .is_some())
901 }
902
903 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
904 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
905 (sender_id, receiver_id, true)
906 } else {
907 (receiver_id, sender_id, false)
908 };
909 let query = "
910 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
911 VALUES ($1, $2, $3, 'f', 't')
912 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
913 SET
914 accepted = 't',
915 should_notify = 'f'
916 WHERE
917 NOT contacts.accepted AND
918 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
919 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
920 ";
921 let result = sqlx::query(query)
922 .bind(id_a.0)
923 .bind(id_b.0)
924 .bind(a_to_b)
925 .execute(&self.pool)
926 .await?;
927
928 if result.rows_affected() == 1 {
929 Ok(())
930 } else {
931 Err(anyhow!("contact already requested"))?
932 }
933 }
934
935 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
936 let (id_a, id_b) = if responder_id < requester_id {
937 (responder_id, requester_id)
938 } else {
939 (requester_id, responder_id)
940 };
941 let query = "
942 DELETE FROM contacts
943 WHERE user_id_a = $1 AND user_id_b = $2;
944 ";
945 let result = sqlx::query(query)
946 .bind(id_a.0)
947 .bind(id_b.0)
948 .execute(&self.pool)
949 .await?;
950
951 if result.rows_affected() == 1 {
952 Ok(())
953 } else {
954 Err(anyhow!("no such contact"))?
955 }
956 }
957
958 async fn dismiss_contact_notification(
959 &self,
960 user_id: UserId,
961 contact_user_id: UserId,
962 ) -> Result<()> {
963 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
964 (user_id, contact_user_id, true)
965 } else {
966 (contact_user_id, user_id, false)
967 };
968
969 let query = "
970 UPDATE contacts
971 SET should_notify = 'f'
972 WHERE
973 user_id_a = $1 AND user_id_b = $2 AND
974 (
975 (a_to_b = $3 AND accepted) OR
976 (a_to_b != $3 AND NOT accepted)
977 );
978 ";
979
980 let result = sqlx::query(query)
981 .bind(id_a.0)
982 .bind(id_b.0)
983 .bind(a_to_b)
984 .execute(&self.pool)
985 .await?;
986
987 if result.rows_affected() == 0 {
988 Err(anyhow!("no such contact request"))?;
989 }
990
991 Ok(())
992 }
993
994 async fn respond_to_contact_request(
995 &self,
996 responder_id: UserId,
997 requester_id: UserId,
998 accept: bool,
999 ) -> Result<()> {
1000 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1001 (responder_id, requester_id, false)
1002 } else {
1003 (requester_id, responder_id, true)
1004 };
1005 let result = if accept {
1006 let query = "
1007 UPDATE contacts
1008 SET accepted = 't', should_notify = 't'
1009 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1010 ";
1011 sqlx::query(query)
1012 .bind(id_a.0)
1013 .bind(id_b.0)
1014 .bind(a_to_b)
1015 .execute(&self.pool)
1016 .await?
1017 } else {
1018 let query = "
1019 DELETE FROM contacts
1020 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1021 ";
1022 sqlx::query(query)
1023 .bind(id_a.0)
1024 .bind(id_b.0)
1025 .bind(a_to_b)
1026 .execute(&self.pool)
1027 .await?
1028 };
1029 if result.rows_affected() == 1 {
1030 Ok(())
1031 } else {
1032 Err(anyhow!("no such contact request"))?
1033 }
1034 }
1035
1036 // access tokens
1037
1038 async fn create_access_token_hash(
1039 &self,
1040 user_id: UserId,
1041 access_token_hash: &str,
1042 max_access_token_count: usize,
1043 ) -> Result<()> {
1044 let insert_query = "
1045 INSERT INTO access_tokens (user_id, hash)
1046 VALUES ($1, $2);
1047 ";
1048 let cleanup_query = "
1049 DELETE FROM access_tokens
1050 WHERE id IN (
1051 SELECT id from access_tokens
1052 WHERE user_id = $1
1053 ORDER BY id DESC
1054 OFFSET $3
1055 )
1056 ";
1057
1058 let mut tx = self.pool.begin().await?;
1059 sqlx::query(insert_query)
1060 .bind(user_id.0)
1061 .bind(access_token_hash)
1062 .execute(&mut tx)
1063 .await?;
1064 sqlx::query(cleanup_query)
1065 .bind(user_id.0)
1066 .bind(access_token_hash)
1067 .bind(max_access_token_count as i32)
1068 .execute(&mut tx)
1069 .await?;
1070 Ok(tx.commit().await?)
1071 }
1072
1073 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1074 let query = "
1075 SELECT hash
1076 FROM access_tokens
1077 WHERE user_id = $1
1078 ORDER BY id DESC
1079 ";
1080 Ok(sqlx::query_scalar(query)
1081 .bind(user_id.0)
1082 .fetch_all(&self.pool)
1083 .await?)
1084 }
1085
1086 // orgs
1087
1088 #[allow(unused)] // Help rust-analyzer
1089 #[cfg(any(test, feature = "seed-support"))]
1090 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1091 let query = "
1092 SELECT *
1093 FROM orgs
1094 WHERE slug = $1
1095 ";
1096 Ok(sqlx::query_as(query)
1097 .bind(slug)
1098 .fetch_optional(&self.pool)
1099 .await?)
1100 }
1101
1102 #[cfg(any(test, feature = "seed-support"))]
1103 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1104 let query = "
1105 INSERT INTO orgs (name, slug)
1106 VALUES ($1, $2)
1107 RETURNING id
1108 ";
1109 Ok(sqlx::query_scalar(query)
1110 .bind(name)
1111 .bind(slug)
1112 .fetch_one(&self.pool)
1113 .await
1114 .map(OrgId)?)
1115 }
1116
1117 #[cfg(any(test, feature = "seed-support"))]
1118 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1119 let query = "
1120 INSERT INTO org_memberships (org_id, user_id, admin)
1121 VALUES ($1, $2, $3)
1122 ON CONFLICT DO NOTHING
1123 ";
1124 Ok(sqlx::query(query)
1125 .bind(org_id.0)
1126 .bind(user_id.0)
1127 .bind(is_admin)
1128 .execute(&self.pool)
1129 .await
1130 .map(drop)?)
1131 }
1132
1133 // channels
1134
1135 #[cfg(any(test, feature = "seed-support"))]
1136 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1137 let query = "
1138 INSERT INTO channels (owner_id, owner_is_user, name)
1139 VALUES ($1, false, $2)
1140 RETURNING id
1141 ";
1142 Ok(sqlx::query_scalar(query)
1143 .bind(org_id.0)
1144 .bind(name)
1145 .fetch_one(&self.pool)
1146 .await
1147 .map(ChannelId)?)
1148 }
1149
1150 #[allow(unused)] // Help rust-analyzer
1151 #[cfg(any(test, feature = "seed-support"))]
1152 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1153 let query = "
1154 SELECT *
1155 FROM channels
1156 WHERE
1157 channels.owner_is_user = false AND
1158 channels.owner_id = $1
1159 ";
1160 Ok(sqlx::query_as(query)
1161 .bind(org_id.0)
1162 .fetch_all(&self.pool)
1163 .await?)
1164 }
1165
1166 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1167 let query = "
1168 SELECT
1169 channels.*
1170 FROM
1171 channel_memberships, channels
1172 WHERE
1173 channel_memberships.user_id = $1 AND
1174 channel_memberships.channel_id = channels.id
1175 ";
1176 Ok(sqlx::query_as(query)
1177 .bind(user_id.0)
1178 .fetch_all(&self.pool)
1179 .await?)
1180 }
1181
1182 async fn can_user_access_channel(
1183 &self,
1184 user_id: UserId,
1185 channel_id: ChannelId,
1186 ) -> Result<bool> {
1187 let query = "
1188 SELECT id
1189 FROM channel_memberships
1190 WHERE user_id = $1 AND channel_id = $2
1191 LIMIT 1
1192 ";
1193 Ok(sqlx::query_scalar::<_, i32>(query)
1194 .bind(user_id.0)
1195 .bind(channel_id.0)
1196 .fetch_optional(&self.pool)
1197 .await
1198 .map(|e| e.is_some())?)
1199 }
1200
1201 #[cfg(any(test, feature = "seed-support"))]
1202 async fn add_channel_member(
1203 &self,
1204 channel_id: ChannelId,
1205 user_id: UserId,
1206 is_admin: bool,
1207 ) -> Result<()> {
1208 let query = "
1209 INSERT INTO channel_memberships (channel_id, user_id, admin)
1210 VALUES ($1, $2, $3)
1211 ON CONFLICT DO NOTHING
1212 ";
1213 Ok(sqlx::query(query)
1214 .bind(channel_id.0)
1215 .bind(user_id.0)
1216 .bind(is_admin)
1217 .execute(&self.pool)
1218 .await
1219 .map(drop)?)
1220 }
1221
1222 // messages
1223
1224 async fn create_channel_message(
1225 &self,
1226 channel_id: ChannelId,
1227 sender_id: UserId,
1228 body: &str,
1229 timestamp: OffsetDateTime,
1230 nonce: u128,
1231 ) -> Result<MessageId> {
1232 let query = "
1233 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1234 VALUES ($1, $2, $3, $4, $5)
1235 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1236 RETURNING id
1237 ";
1238 Ok(sqlx::query_scalar(query)
1239 .bind(channel_id.0)
1240 .bind(sender_id.0)
1241 .bind(body)
1242 .bind(timestamp)
1243 .bind(Uuid::from_u128(nonce))
1244 .fetch_one(&self.pool)
1245 .await
1246 .map(MessageId)?)
1247 }
1248
1249 async fn get_channel_messages(
1250 &self,
1251 channel_id: ChannelId,
1252 count: usize,
1253 before_id: Option<MessageId>,
1254 ) -> Result<Vec<ChannelMessage>> {
1255 let query = r#"
1256 SELECT * FROM (
1257 SELECT
1258 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1259 FROM
1260 channel_messages
1261 WHERE
1262 channel_id = $1 AND
1263 id < $2
1264 ORDER BY id DESC
1265 LIMIT $3
1266 ) as recent_messages
1267 ORDER BY id ASC
1268 "#;
1269 Ok(sqlx::query_as(query)
1270 .bind(channel_id.0)
1271 .bind(before_id.unwrap_or(MessageId::MAX))
1272 .bind(count as i64)
1273 .fetch_all(&self.pool)
1274 .await?)
1275 }
1276
1277 #[cfg(test)]
1278 async fn teardown(&self, url: &str) {
1279 use util::ResultExt;
1280
1281 let query = "
1282 SELECT pg_terminate_backend(pg_stat_activity.pid)
1283 FROM pg_stat_activity
1284 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1285 ";
1286 sqlx::query(query).execute(&self.pool).await.log_err();
1287 self.pool.close().await;
1288 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1289 .await
1290 .log_err();
1291 }
1292
    /// Test-only downcast hook: the real Postgres-backed database is never a `FakeDb`, so this
    /// always returns `None`.
    #[cfg(test)]
    fn as_fake(&self) -> Option<&tests::FakeDb> {
        None
    }
1297}
1298
// Declares a newtype id wrapping an `i32` database key.
//
// The generated type is `sqlx(transparent)` and `serde(transparent)`, so it is stored and
// serialized exactly like the inner `i32`. It also gets a `MAX` constant,
// `from_proto`/`to_proto` conversions to and from `u64`, and a `Display` impl that prints the
// bare number.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id; used e.g. as an "everything before" pagination cursor.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // NOTE: `as` silently truncates values above `i32::MAX`; callers are expected to
            // pass ids that originated from `to_proto`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // NOTE: a negative inner value would sign-extend into a huge `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1341
// Identifier of a `User`.
id_type!(UserId);
/// A user account row.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    pub admin: bool,
    /// The user's invite code, if one has been assigned.
    pub invite_code: Option<String>,
    /// How many more times `invite_code` may be redeemed.
    pub invite_count: i32,
    pub connected_once: bool,
}
1353
// Identifier of a `Project`.
id_type!(ProjectId);
/// A project registered by its host user.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    /// Presumably set once the project has been unregistered (see `unregister_project`) —
    /// confirm against that method's implementation.
    pub unregistered: bool,
}
1361
/// Per-user activity over a time range: the user's identity plus one entry for each project
/// they were active in.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    pub project_activity: Vec<ProjectActivitySummary>,
}
1368
/// A user's activity in a single project within a queried time range.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    id: ProjectId,
    /// Total time the user was recorded as active in this project.
    duration: Duration,
    /// Largest number of users seen in the project at once while this user was active.
    max_collaborators: usize,
}
1375
/// One contiguous stretch of activity by a user in a project.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    project_id: ProjectId,
    #[serde(with = "time::serde::iso8601")]
    start: OffsetDateTime,
    #[serde(with = "time::serde::iso8601")]
    end: OffsetDateTime,
    /// File-extension counts for the project (e.g. `"rs" -> 5`).
    extensions: HashMap<String, usize>,
}
1385
// Identifier of an `Org`.
id_type!(OrgId);
/// An organization row.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
1393
// Identifier of a `Channel`.
id_type!(ChannelId);
/// A chat channel row.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Org id when `owner_is_user` is false (as written by `create_org_channel`); presumably a
    /// user id otherwise.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1402
// Identifier of a `ChannelMessage`; ids increase with insertion order (pagination relies on it).
id_type!(MessageId);
/// A message posted to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Read back by `get_channel_messages` via `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Caller-supplied nonce used to deduplicate repeated sends of the same message.
    pub nonce: Uuid,
}
1413
/// A contact relationship as seen from one user's perspective. `user_id` is always the *other*
/// user.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both users are contacts. `should_notify` is whether the viewing user still has an
    /// undismissed notification about the acceptance.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// The viewing user sent a request that has not been answered yet.
    Outgoing {
        user_id: UserId,
    },
    /// The viewing user received a request that has not been answered yet. `should_notify` is
    /// false once the request notification has been dismissed.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1428
1429impl Contact {
1430 pub fn user_id(&self) -> UserId {
1431 match self {
1432 Contact::Accepted { user_id, .. } => *user_id,
1433 Contact::Outgoing { user_id } => *user_id,
1434 Contact::Incoming { user_id, .. } => *user_id,
1435 }
1436 }
1437}
1438
/// A pending contact request addressed to some user.
// Field order matters: the derived `Ord` sorts by `requester_id` first.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    /// Whether the recipient has yet to dismiss the notification for this request.
    pub should_notify: bool,
}
1444
/// Turns a search query into a SQL `LIKE` pattern that matches any string containing the
/// query's alphanumeric characters in order, with anything in between.
///
/// Non-alphanumeric characters are dropped, including `%`, `_` and `\`, so user input cannot
/// inject its own wildcards. For example `"x y"` becomes `"%x%y%"` and `""` becomes `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern: String = string
        .chars()
        .filter(|ch| ch.is_alphanumeric())
        .flat_map(|ch| ['%', ch])
        .collect();
    pattern.push('%');
    pattern
}
1456
/// Generates a new random invite code, 16 characters long, using the `nanoid` crate.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}
1460
1461#[cfg(test)]
1462pub mod tests {
1463 use super::*;
1464 use anyhow::anyhow;
1465 use collections::BTreeMap;
1466 use gpui::executor::{Background, Deterministic};
1467 use lazy_static::lazy_static;
1468 use parking_lot::Mutex;
1469 use rand::prelude::*;
1470 use sqlx::{
1471 migrate::{MigrateDatabase, Migrator},
1472 Postgres,
1473 };
1474 use std::{path::Path, sync::Arc};
1475 use util::post_inc;
1476
    /// Runs against both the Postgres-backed and the fake database. Checks that
    /// `get_users_by_ids` returns exactly the requested users. The result must equal the
    /// expected list element-for-element, and the ids are both created and requested in
    /// ascending order.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_get_users_by_ids() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(build_background_executor()),
        ] {
            let db = test_db.db();

            let user = db.create_user("user", None, false).await.unwrap();
            let friend1 = db.create_user("friend-1", None, false).await.unwrap();
            let friend2 = db.create_user("friend-2", None, false).await.unwrap();
            let friend3 = db.create_user("friend-3", None, false).await.unwrap();

            assert_eq!(
                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
                    .await
                    .unwrap(),
                vec![
                    User {
                        id: user,
                        github_login: "user".to_string(),
                        admin: false,
                        ..Default::default()
                    },
                    User {
                        id: friend1,
                        github_login: "friend-1".to_string(),
                        admin: false,
                        ..Default::default()
                    },
                    User {
                        id: friend2,
                        github_login: "friend-2".to_string(),
                        admin: false,
                        ..Default::default()
                    },
                    User {
                        id: friend3,
                        github_login: "friend-3".to_string(),
                        admin: false,
                        ..Default::default()
                    }
                ]
            );
        }
    }
1523
    /// Bulk user creation (Postgres only). Each new user gets the requested invite count and a
    /// distinct invite code. Re-creating an existing login returns that user's existing id,
    /// updates the invite count, and keeps the original invite code.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_users() {
        let db = TestDb::postgres().await;
        let db = db.db();

        // Create the first batch of users, ensuring invite counts are assigned
        // correctly and the respective invite codes are unique.
        let user_ids_batch_1 = db
            .create_users(vec![
                ("user1".to_string(), "hi@user1.com".to_string(), 5),
                ("user2".to_string(), "hi@user2.com".to_string(), 4),
                ("user3".to_string(), "hi@user3.com".to_string(), 3),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_1.len(), 3);

        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
        assert_eq!(users.len(), 3);
        assert_eq!(users[0].github_login, "user1");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
        assert_eq!(users[0].invite_count, 5);
        assert_eq!(users[1].github_login, "user2");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[1].invite_count, 4);
        assert_eq!(users[2].github_login, "user3");
        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
        assert_eq!(users[2].invite_count, 3);

        let invite_code_1 = users[0].invite_code.clone().unwrap();
        let invite_code_2 = users[1].invite_code.clone().unwrap();
        let invite_code_3 = users[2].invite_code.clone().unwrap();
        assert_ne!(invite_code_1, invite_code_2);
        assert_ne!(invite_code_1, invite_code_3);
        assert_ne!(invite_code_2, invite_code_3);

        // Create the second batch of users and include a user that is already in the database, ensuring
        // the invite count for the existing user is updated without changing their invite code.
        let user_ids_batch_2 = db
            .create_users(vec![
                ("user2".to_string(), "hi@user2.com".to_string(), 10),
                ("user4".to_string(), "hi@user4.com".to_string(), 2),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_2.len(), 2);
        // "user2" resolves to the id it was given in the first batch.
        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);

        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0].github_login, "user2");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[0].invite_count, 10);
        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
        assert_eq!(users[1].github_login, "user4");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
        assert_eq!(users[1].invite_count, 2);

        let invite_code_4 = users[1].invite_code.clone().unwrap();
        assert_ne!(invite_code_4, invite_code_1);
        assert_ne!(invite_code_4, invite_code_2);
        assert_ne!(invite_code_4, invite_code_3);
    }
1587
    /// Worktree extension counts (Postgres only). Counts are keyed per worktree id. A later
    /// update for the same worktree overwrites its earlier counts, and an empty update leaves
    /// nothing behind.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_worktree_extensions() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();

        let user = db.create_user("user_1", None, false).await.unwrap();
        let project = db.register_project(user).await.unwrap();

        db.update_worktree_extensions(project, 100, Default::default())
            .await
            .unwrap();
        db.update_worktree_extensions(
            project,
            100,
            [("rs".to_string(), 5), ("md".to_string(), 3)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();
        // Second update for worktree 100: these values should replace 5/3.
        db.update_worktree_extensions(
            project,
            100,
            [("rs".to_string(), 6), ("md".to_string(), 5)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();
        db.update_worktree_extensions(
            project,
            101,
            [("ts".to_string(), 2), ("md".to_string(), 1)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();

        assert_eq!(
            db.get_project_extensions(project).await.unwrap(),
            [
                (
                    100,
                    [("rs".into(), 6), ("md".into(), 5),]
                        .into_iter()
                        .collect::<HashMap<_, _>>()
                ),
                (
                    101,
                    [("ts".into(), 2), ("md".into(), 1),]
                        .into_iter()
                        .collect::<HashMap<_, _>>()
                )
            ]
            .into_iter()
            .collect()
        );
    }
1647
    /// Builds a timeline of three users moving between two projects, then checks the three
    /// activity queries against it (Postgres only): the top-users summary, the active-user
    /// count, and the per-user timeline.
    ///
    /// Per-user activity inside `t0..t6`:
    /// - user_1: 30s on project_2 (t2..t5) plus 25s on project_1 (t3..t6), 55s in total.
    /// - user_2: 50s on project_2 (t0..t5).
    /// - user_3: 15s on project_1 (t4..t6).
    #[tokio::test(flavor = "multi_thread")]
    async fn test_user_activity() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();

        let user_1 = db.create_user("user_1", None, false).await.unwrap();
        let user_2 = db.create_user("user_2", None, false).await.unwrap();
        let user_3 = db.create_user("user_3", None, false).await.unwrap();
        let project_1 = db.register_project(user_1).await.unwrap();
        db.update_worktree_extensions(
            project_1,
            1,
            HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
        )
        .await
        .unwrap();
        let project_2 = db.register_project(user_2).await.unwrap();
        let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);

        // User 2 opens a project
        let t1 = t0 + Duration::from_secs(10);
        db.record_user_activity(t0..t1, &[(user_2, project_2)])
            .await
            .unwrap();

        let t2 = t1 + Duration::from_secs(10);
        db.record_user_activity(t1..t2, &[(user_2, project_2)])
            .await
            .unwrap();

        // User 1 joins the project
        let t3 = t2 + Duration::from_secs(10);
        db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
            .await
            .unwrap();

        // User 1 opens another project
        let t4 = t3 + Duration::from_secs(10);
        db.record_user_activity(
            t3..t4,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
            ],
        )
        .await
        .unwrap();

        // User 3 joins that project
        let t5 = t4 + Duration::from_secs(10);
        db.record_user_activity(
            t4..t5,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
                (user_3, project_1),
            ],
        )
        .await
        .unwrap();

        // User 2 leaves
        let t6 = t5 + Duration::from_secs(5);
        db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
            .await
            .unwrap();

        // A gap of 60s with no activity, then user 1 alone on project 1 for 10s.
        let t7 = t6 + Duration::from_secs(60);
        let t8 = t7 + Duration::from_secs(10);
        db.record_user_activity(t7..t8, &[(user_1, project_1)])
            .await
            .unwrap();

        assert_eq!(
            db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
            &[
                UserActivitySummary {
                    id: user_1,
                    github_login: "user_1".to_string(),
                    project_activity: vec![
                        ProjectActivitySummary {
                            id: project_1,
                            duration: Duration::from_secs(25),
                            max_collaborators: 2
                        },
                        ProjectActivitySummary {
                            id: project_2,
                            duration: Duration::from_secs(30),
                            max_collaborators: 2
                        }
                    ]
                },
                UserActivitySummary {
                    id: user_2,
                    github_login: "user_2".to_string(),
                    project_activity: vec![ProjectActivitySummary {
                        id: project_2,
                        duration: Duration::from_secs(50),
                        max_collaborators: 2
                    }]
                },
                UserActivitySummary {
                    id: user_3,
                    github_login: "user_3".to_string(),
                    project_activity: vec![ProjectActivitySummary {
                        id: project_1,
                        duration: Duration::from_secs(15),
                        max_collaborators: 2
                    }]
                },
            ]
        );

        // Thresholds straddle the totals above: 55s (user_1), 50s (user_2), 15s (user_3).
        // NOTE(review): the boolean argument appears to restrict the count to collaborative
        // activity (see the t0..t1 case, where user_2 was alone) — confirm in the impl.
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(56), false)
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(56), true)
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(54), false)
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(54), true)
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(30), false)
                .await
                .unwrap(),
            2
        );
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(30), true)
                .await
                .unwrap(),
            2
        );
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(10), false)
                .await
                .unwrap(),
            3
        );
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(10), true)
                .await
                .unwrap(),
            3
        );
        assert_eq!(
            db.get_active_user_count(t0..t1, Duration::from_secs(5), false)
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            db.get_active_user_count(t0..t1, Duration::from_secs(5), true)
                .await
                .unwrap(),
            0
        );

        // Periods are clipped to the queried range (project_2's period starts at t3, not t2).
        assert_eq!(
            db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
            &[
                UserActivityPeriod {
                    project_id: project_1,
                    start: t3,
                    end: t6,
                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
                },
                UserActivityPeriod {
                    project_id: project_2,
                    start: t3,
                    end: t5,
                    extensions: Default::default(),
                },
            ]
        );
        // The 60s gap splits user_1's project_1 activity into two separate periods.
        assert_eq!(
            db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
            &[
                UserActivityPeriod {
                    project_id: project_2,
                    start: t2,
                    end: t5,
                    extensions: Default::default(),
                },
                UserActivityPeriod {
                    project_id: project_1,
                    start: t3,
                    end: t6,
                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
                },
                UserActivityPeriod {
                    project_id: project_1,
                    start: t7,
                    end: t8,
                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
                },
            ]
        );
    }
1865
    /// Paging through channel history on both database implementations. The first page holds
    /// the newest 5 messages in ascending order. Passing the oldest id from that page as
    /// `before_id` returns the messages just before it.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_recent_channel_messages() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(build_background_executor()),
        ] {
            let db = test_db.db();
            let user = db.create_user("user", None, false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();
            // Message bodies "0".."9", each with a distinct nonce.
            for i in 0..10 {
                db.create_channel_message(
                    channel,
                    user,
                    &i.to_string(),
                    OffsetDateTime::now_utc(),
                    i,
                )
                .await
                .unwrap();
            }

            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
            assert_eq!(
                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["5", "6", "7", "8", "9"]
            );

            let prev_messages = db
                .get_channel_messages(channel, 4, Some(messages[0].id))
                .await
                .unwrap();
            assert_eq!(
                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["1", "2", "3", "4"]
            );
        }
    }
1904
    /// Re-sending a message with a nonce that was already used returns the original message's
    /// id instead of creating a new message, even when the body differs.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_channel_message_nonces() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(build_background_executor()),
        ] {
            let db = test_db.db();
            let user = db.create_user("user", None, false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();

            let msg1_id = db
                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg2_id = db
                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();
            let msg3_id = db
                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg4_id = db
                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();

            assert_ne!(msg1_id, msg2_id);
            assert_eq!(msg1_id, msg3_id);
            assert_eq!(msg2_id, msg4_id);
        }
    }
1938
    /// Access token hashes are returned newest first, and at most three are kept per user. The
    /// trailing `3` argument appears to be that cap; once a fourth hash is added, the oldest
    /// one disappears.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_access_tokens() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();
        let user = db.create_user("the-user", None, false).await.unwrap();

        db.create_access_token_hash(user, "h1", 3).await.unwrap();
        db.create_access_token_hash(user, "h2", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h3", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
        );

        db.create_access_token_hash(user, "h4", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
        );

        db.create_access_token_hash(user, "h5", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
        );
    }
1970
1971 #[test]
1972 fn test_fuzzy_like_string() {
1973 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1974 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1975 assert_eq!(fuzzy_like_string(" z "), "%z%");
1976 }
1977
    /// Fuzzy user search (Postgres only). A query matches logins that contain its characters
    /// in order, case-insensitively (`"clr"` matches `"California"`), and results come back in
    /// the order asserted below.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_fuzzy_search_users() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();
        for github_login in [
            "California",
            "colorado",
            "oregon",
            "washington",
            "florida",
            "delaware",
            "rhode-island",
        ] {
            db.create_user(github_login, None, false).await.unwrap();
        }

        assert_eq!(
            fuzzy_search_user_names(db, "clr").await,
            &["colorado", "California"]
        );
        assert_eq!(
            fuzzy_search_user_names(db, "ro").await,
            &["rhode-island", "colorado", "oregon"],
        );

        // Runs a search (limit 10) and keeps only the logins of the results.
        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
            db.fuzzy_search_users(query, 10)
                .await
                .unwrap()
                .into_iter()
                .map(|user| user.github_login)
                .collect::<Vec<_>>()
        }
    }
2012
    /// The full contact-request lifecycle, on both database implementations: sending,
    /// dismissing notifications, accepting, rejecting, and the automatic acceptance of mutual
    /// requests. It also checks which of those operations are rejected with an error.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(build_background_executor()),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // User starts with no contacts
            assert_eq!(db.get_contacts(user_1).await.unwrap(), vec![]);

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[Contact::Outgoing { user_id: user_2 }],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[Contact::Incoming {
                    user_id: user_1,
                    should_notify: true
                }]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again. (The requester has no notification to dismiss.)
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[Contact::Incoming {
                    user_id: user_1,
                    should_notify: false
                }]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[Contact::Accepted {
                    user_id: user_2,
                    should_notify: true
                }],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[Contact::Accepted {
                    user_id: user_1,
                    should_notify: false,
                }]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[Contact::Accepted {
                    user_id: user_2,
                    should_notify: true,
                }]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[Contact::Accepted {
                    user_id: user_2,
                    should_notify: false,
                },]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );
        }
    }
2162
    /// Invite codes end to end (Postgres only):
    /// - A code is assigned when the invite count becomes positive.
    /// - Each redemption decrements the count and makes the inviter and invitee contacts.
    /// - An exhausted code is rejected, and topping the count back up keeps the same code.
    /// - Existing users cannot redeem a code.
    /// - Newly invited users receive 5 invites of their own.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [Contact::Accepted {
                user_id: user2,
                should_notify: true
            }]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [Contact::Accepted {
                user_id: user1,
                should_notify: false
            }]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [Contact::Accepted {
                user_id: user1,
                should_notify: false
            }]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [Contact::Accepted {
                user_id: user1,
                should_notify: false
            }]
        );

        // An existing user cannot redeem invite codes, and the failed attempt doesn't consume one.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);

        // Ensure invited users get invite codes too.
        assert_eq!(
            db.get_invite_code_for_user(user2).await.unwrap().unwrap().1,
            5
        );
        assert_eq!(
            db.get_invite_code_for_user(user3).await.unwrap().unwrap().1,
            5
        );
        assert_eq!(
            db.get_invite_code_for_user(user4).await.unwrap().unwrap().1,
            5
        );
    }
2297
    /// A database handle for tests that cleans itself up on drop.
    pub struct TestDb {
        /// `Some` until `Drop` takes it to run teardown.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the scratch Postgres database; empty for the fake database.
        pub url: String,
    }
2302
    impl TestDb {
        /// Creates a uniquely named scratch database on the local Postgres server (as user
        /// `postgres`), runs this crate's `migrations` directory against it, and wraps it in a
        /// `PostgresDb`. The `5` passed to `PostgresDb::new` is presumably the pool size.
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        // The global LOCK is deliberately held across the awaits below, which serializes
        // database creation and migration between concurrently running tests; that is why the
        // clippy lint is allowed. NOTE(review): the underlying reason (e.g. contention when
        // creating databases in parallel) isn't visible here — confirm.
        #[allow(clippy::await_holding_lock)]
        pub async fn postgres() -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            Postgres::create_database(&url)
                .await
                .expect("failed to create test db");
            let db = PostgresDb::new(&url, 5).await.unwrap();
            let migrator = Migrator::new(migrations_path).await.unwrap();
            migrator.run(&db.pool).await.unwrap();
            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Wraps a fresh in-memory `FakeDb` driven by `background`; `url` is left empty.
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                url: Default::default(),
            }
        }

        /// Borrows the database. Panics if called after teardown.
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }
2338
2339 impl Drop for TestDb {
2340 fn drop(&mut self) {
2341 if let Some(db) = self.db.take() {
2342 futures::executor::block_on(db.teardown(&self.url));
2343 }
2344 }
2345 }
2346
    /// In-memory stand-in for the Postgres database, used by tests. Each table is a map behind
    /// its own mutex, and ids come from per-table counters.
    pub struct FakeDb {
        /// Executor used to inject random delays into operations (see `create_user`).
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        /// Keyed by (project, worktree id, extension); the value is the file count.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// The value is the member's admin flag.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// The value is the member's admin flag.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        pub contacts: Mutex<Vec<FakeContact>>,
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2364
    /// A contact relationship (pending request or accepted contact) in [`FakeDb`].
    #[derive(Debug)]
    pub struct FakeContact {
        /// User who sent the contact request.
        pub requester_id: UserId,
        /// User who received the contact request.
        pub responder_id: UserId,
        /// Whether the request has been accepted.
        pub accepted: bool,
        /// Pending notification flag. While the request is pending it is reported to the
        /// responder as an incoming request; once accepted it is reported to the requester.
        pub should_notify: bool,
    }
2372
2373 impl FakeDb {
2374 pub fn new(background: Arc<Background>) -> Self {
2375 Self {
2376 background,
2377 users: Default::default(),
2378 next_user_id: Mutex::new(0),
2379 projects: Default::default(),
2380 worktree_extensions: Default::default(),
2381 next_project_id: Mutex::new(1),
2382 orgs: Default::default(),
2383 next_org_id: Mutex::new(1),
2384 org_memberships: Default::default(),
2385 channels: Default::default(),
2386 next_channel_id: Mutex::new(1),
2387 channel_memberships: Default::default(),
2388 channel_messages: Default::default(),
2389 next_channel_message_id: Mutex::new(1),
2390 contacts: Default::default(),
2391 }
2392 }
2393 }
2394
2395 #[async_trait]
2396 impl Db for FakeDb {
2397 async fn create_user(
2398 &self,
2399 github_login: &str,
2400 email_address: Option<&str>,
2401 admin: bool,
2402 ) -> Result<UserId> {
2403 self.background.simulate_random_delay().await;
2404
2405 let mut users = self.users.lock();
2406 if let Some(user) = users
2407 .values()
2408 .find(|user| user.github_login == github_login)
2409 {
2410 Ok(user.id)
2411 } else {
2412 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2413 users.insert(
2414 user_id,
2415 User {
2416 id: user_id,
2417 github_login: github_login.to_string(),
2418 email_address: email_address.map(str::to_string),
2419 admin,
2420 invite_code: None,
2421 invite_count: 0,
2422 connected_once: false,
2423 },
2424 );
2425 Ok(user_id)
2426 }
2427 }
2428
2429 async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
2430 unimplemented!()
2431 }
2432
2433 async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
2434 unimplemented!()
2435 }
2436
2437 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
2438 unimplemented!()
2439 }
2440
2441 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2442 self.background.simulate_random_delay().await;
2443 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2444 }
2445
2446 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2447 self.background.simulate_random_delay().await;
2448 let users = self.users.lock();
2449 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2450 }
2451
2452 async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
2453 unimplemented!()
2454 }
2455
2456 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2457 self.background.simulate_random_delay().await;
2458 Ok(self
2459 .users
2460 .lock()
2461 .values()
2462 .find(|user| user.github_login == github_login)
2463 .cloned())
2464 }
2465
2466 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
2467 unimplemented!()
2468 }
2469
2470 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2471 self.background.simulate_random_delay().await;
2472 let mut users = self.users.lock();
2473 let mut user = users
2474 .get_mut(&id)
2475 .ok_or_else(|| anyhow!("user not found"))?;
2476 user.connected_once = connected_once;
2477 Ok(())
2478 }
2479
2480 async fn destroy_user(&self, _id: UserId) -> Result<()> {
2481 unimplemented!()
2482 }
2483
2484 // invite codes
2485
2486 async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
2487 unimplemented!()
2488 }
2489
2490 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
2491 self.background.simulate_random_delay().await;
2492 Ok(None)
2493 }
2494
2495 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
2496 unimplemented!()
2497 }
2498
2499 async fn redeem_invite_code(
2500 &self,
2501 _code: &str,
2502 _login: &str,
2503 _email_address: Option<&str>,
2504 ) -> Result<UserId> {
2505 unimplemented!()
2506 }
2507
2508 // projects
2509
2510 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2511 self.background.simulate_random_delay().await;
2512 if !self.users.lock().contains_key(&host_user_id) {
2513 Err(anyhow!("no such user"))?;
2514 }
2515
2516 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2517 self.projects.lock().insert(
2518 project_id,
2519 Project {
2520 id: project_id,
2521 host_user_id,
2522 unregistered: false,
2523 },
2524 );
2525 Ok(project_id)
2526 }
2527
2528 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2529 self.background.simulate_random_delay().await;
2530 self.projects
2531 .lock()
2532 .get_mut(&project_id)
2533 .ok_or_else(|| anyhow!("no such project"))?
2534 .unregistered = true;
2535 Ok(())
2536 }
2537
2538 async fn update_worktree_extensions(
2539 &self,
2540 project_id: ProjectId,
2541 worktree_id: u64,
2542 extensions: HashMap<String, u32>,
2543 ) -> Result<()> {
2544 self.background.simulate_random_delay().await;
2545 if !self.projects.lock().contains_key(&project_id) {
2546 Err(anyhow!("no such project"))?;
2547 }
2548
2549 for (extension, count) in extensions {
2550 self.worktree_extensions
2551 .lock()
2552 .insert((project_id, worktree_id, extension), count);
2553 }
2554
2555 Ok(())
2556 }
2557
2558 async fn get_project_extensions(
2559 &self,
2560 _project_id: ProjectId,
2561 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
2562 unimplemented!()
2563 }
2564
2565 async fn record_user_activity(
2566 &self,
2567 _time_period: Range<OffsetDateTime>,
2568 _active_projects: &[(UserId, ProjectId)],
2569 ) -> Result<()> {
2570 unimplemented!()
2571 }
2572
2573 async fn get_active_user_count(
2574 &self,
2575 _time_period: Range<OffsetDateTime>,
2576 _min_duration: Duration,
2577 _only_collaborative: bool,
2578 ) -> Result<usize> {
2579 unimplemented!()
2580 }
2581
2582 async fn get_top_users_activity_summary(
2583 &self,
2584 _time_period: Range<OffsetDateTime>,
2585 _limit: usize,
2586 ) -> Result<Vec<UserActivitySummary>> {
2587 unimplemented!()
2588 }
2589
2590 async fn get_user_activity_timeline(
2591 &self,
2592 _time_period: Range<OffsetDateTime>,
2593 _user_id: UserId,
2594 ) -> Result<Vec<UserActivityPeriod>> {
2595 unimplemented!()
2596 }
2597
2598 // contacts
2599
2600 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2601 self.background.simulate_random_delay().await;
2602 let mut contacts = Vec::new();
2603
2604 for contact in self.contacts.lock().iter() {
2605 if contact.requester_id == id {
2606 if contact.accepted {
2607 contacts.push(Contact::Accepted {
2608 user_id: contact.responder_id,
2609 should_notify: contact.should_notify,
2610 });
2611 } else {
2612 contacts.push(Contact::Outgoing {
2613 user_id: contact.responder_id,
2614 });
2615 }
2616 } else if contact.responder_id == id {
2617 if contact.accepted {
2618 contacts.push(Contact::Accepted {
2619 user_id: contact.requester_id,
2620 should_notify: false,
2621 });
2622 } else {
2623 contacts.push(Contact::Incoming {
2624 user_id: contact.requester_id,
2625 should_notify: contact.should_notify,
2626 });
2627 }
2628 }
2629 }
2630
2631 contacts.sort_unstable_by_key(|contact| contact.user_id());
2632 Ok(contacts)
2633 }
2634
2635 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2636 self.background.simulate_random_delay().await;
2637 Ok(self.contacts.lock().iter().any(|contact| {
2638 contact.accepted
2639 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2640 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2641 }))
2642 }
2643
2644 async fn send_contact_request(
2645 &self,
2646 requester_id: UserId,
2647 responder_id: UserId,
2648 ) -> Result<()> {
2649 self.background.simulate_random_delay().await;
2650 let mut contacts = self.contacts.lock();
2651 for contact in contacts.iter_mut() {
2652 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2653 if contact.accepted {
2654 Err(anyhow!("contact already exists"))?;
2655 } else {
2656 Err(anyhow!("contact already requested"))?;
2657 }
2658 }
2659 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2660 if contact.accepted {
2661 Err(anyhow!("contact already exists"))?;
2662 } else {
2663 contact.accepted = true;
2664 contact.should_notify = false;
2665 return Ok(());
2666 }
2667 }
2668 }
2669 contacts.push(FakeContact {
2670 requester_id,
2671 responder_id,
2672 accepted: false,
2673 should_notify: true,
2674 });
2675 Ok(())
2676 }
2677
2678 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2679 self.background.simulate_random_delay().await;
2680 self.contacts.lock().retain(|contact| {
2681 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2682 });
2683 Ok(())
2684 }
2685
2686 async fn dismiss_contact_notification(
2687 &self,
2688 user_id: UserId,
2689 contact_user_id: UserId,
2690 ) -> Result<()> {
2691 self.background.simulate_random_delay().await;
2692 let mut contacts = self.contacts.lock();
2693 for contact in contacts.iter_mut() {
2694 if contact.requester_id == contact_user_id
2695 && contact.responder_id == user_id
2696 && !contact.accepted
2697 {
2698 contact.should_notify = false;
2699 return Ok(());
2700 }
2701 if contact.requester_id == user_id
2702 && contact.responder_id == contact_user_id
2703 && contact.accepted
2704 {
2705 contact.should_notify = false;
2706 return Ok(());
2707 }
2708 }
2709 Err(anyhow!("no such notification"))?
2710 }
2711
2712 async fn respond_to_contact_request(
2713 &self,
2714 responder_id: UserId,
2715 requester_id: UserId,
2716 accept: bool,
2717 ) -> Result<()> {
2718 self.background.simulate_random_delay().await;
2719 let mut contacts = self.contacts.lock();
2720 for (ix, contact) in contacts.iter_mut().enumerate() {
2721 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2722 if contact.accepted {
2723 Err(anyhow!("contact already confirmed"))?;
2724 }
2725 if accept {
2726 contact.accepted = true;
2727 contact.should_notify = true;
2728 } else {
2729 contacts.remove(ix);
2730 }
2731 return Ok(());
2732 }
2733 }
2734 Err(anyhow!("no such contact request"))?
2735 }
2736
2737 async fn create_access_token_hash(
2738 &self,
2739 _user_id: UserId,
2740 _access_token_hash: &str,
2741 _max_access_token_count: usize,
2742 ) -> Result<()> {
2743 unimplemented!()
2744 }
2745
2746 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
2747 unimplemented!()
2748 }
2749
2750 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2751 unimplemented!()
2752 }
2753
2754 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2755 self.background.simulate_random_delay().await;
2756 let mut orgs = self.orgs.lock();
2757 if orgs.values().any(|org| org.slug == slug) {
2758 Err(anyhow!("org already exists"))?
2759 } else {
2760 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2761 orgs.insert(
2762 org_id,
2763 Org {
2764 id: org_id,
2765 name: name.to_string(),
2766 slug: slug.to_string(),
2767 },
2768 );
2769 Ok(org_id)
2770 }
2771 }
2772
2773 async fn add_org_member(
2774 &self,
2775 org_id: OrgId,
2776 user_id: UserId,
2777 is_admin: bool,
2778 ) -> Result<()> {
2779 self.background.simulate_random_delay().await;
2780 if !self.orgs.lock().contains_key(&org_id) {
2781 Err(anyhow!("org does not exist"))?;
2782 }
2783 if !self.users.lock().contains_key(&user_id) {
2784 Err(anyhow!("user does not exist"))?;
2785 }
2786
2787 self.org_memberships
2788 .lock()
2789 .entry((org_id, user_id))
2790 .or_insert(is_admin);
2791 Ok(())
2792 }
2793
2794 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2795 self.background.simulate_random_delay().await;
2796 if !self.orgs.lock().contains_key(&org_id) {
2797 Err(anyhow!("org does not exist"))?;
2798 }
2799
2800 let mut channels = self.channels.lock();
2801 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2802 channels.insert(
2803 channel_id,
2804 Channel {
2805 id: channel_id,
2806 name: name.to_string(),
2807 owner_id: org_id.0,
2808 owner_is_user: false,
2809 },
2810 );
2811 Ok(channel_id)
2812 }
2813
2814 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2815 self.background.simulate_random_delay().await;
2816 Ok(self
2817 .channels
2818 .lock()
2819 .values()
2820 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2821 .cloned()
2822 .collect())
2823 }
2824
2825 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2826 self.background.simulate_random_delay().await;
2827 let channels = self.channels.lock();
2828 let memberships = self.channel_memberships.lock();
2829 Ok(channels
2830 .values()
2831 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2832 .cloned()
2833 .collect())
2834 }
2835
2836 async fn can_user_access_channel(
2837 &self,
2838 user_id: UserId,
2839 channel_id: ChannelId,
2840 ) -> Result<bool> {
2841 self.background.simulate_random_delay().await;
2842 Ok(self
2843 .channel_memberships
2844 .lock()
2845 .contains_key(&(channel_id, user_id)))
2846 }
2847
2848 async fn add_channel_member(
2849 &self,
2850 channel_id: ChannelId,
2851 user_id: UserId,
2852 is_admin: bool,
2853 ) -> Result<()> {
2854 self.background.simulate_random_delay().await;
2855 if !self.channels.lock().contains_key(&channel_id) {
2856 Err(anyhow!("channel does not exist"))?;
2857 }
2858 if !self.users.lock().contains_key(&user_id) {
2859 Err(anyhow!("user does not exist"))?;
2860 }
2861
2862 self.channel_memberships
2863 .lock()
2864 .entry((channel_id, user_id))
2865 .or_insert(is_admin);
2866 Ok(())
2867 }
2868
2869 async fn create_channel_message(
2870 &self,
2871 channel_id: ChannelId,
2872 sender_id: UserId,
2873 body: &str,
2874 timestamp: OffsetDateTime,
2875 nonce: u128,
2876 ) -> Result<MessageId> {
2877 self.background.simulate_random_delay().await;
2878 if !self.channels.lock().contains_key(&channel_id) {
2879 Err(anyhow!("channel does not exist"))?;
2880 }
2881 if !self.users.lock().contains_key(&sender_id) {
2882 Err(anyhow!("user does not exist"))?;
2883 }
2884
2885 let mut messages = self.channel_messages.lock();
2886 if let Some(message) = messages
2887 .values()
2888 .find(|message| message.nonce.as_u128() == nonce)
2889 {
2890 Ok(message.id)
2891 } else {
2892 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2893 messages.insert(
2894 message_id,
2895 ChannelMessage {
2896 id: message_id,
2897 channel_id,
2898 sender_id,
2899 body: body.to_string(),
2900 sent_at: timestamp,
2901 nonce: Uuid::from_u128(nonce),
2902 },
2903 );
2904 Ok(message_id)
2905 }
2906 }
2907
2908 async fn get_channel_messages(
2909 &self,
2910 channel_id: ChannelId,
2911 count: usize,
2912 before_id: Option<MessageId>,
2913 ) -> Result<Vec<ChannelMessage>> {
2914 self.background.simulate_random_delay().await;
2915 let mut messages = self
2916 .channel_messages
2917 .lock()
2918 .values()
2919 .rev()
2920 .filter(|message| {
2921 message.channel_id == channel_id
2922 && message.id < before_id.unwrap_or(MessageId::MAX)
2923 })
2924 .take(count)
2925 .cloned()
2926 .collect::<Vec<_>>();
2927 messages.sort_unstable_by_key(|message| message.id);
2928 Ok(messages)
2929 }
2930
2931 async fn teardown(&self, _: &str) {}
2932
2933 #[cfg(test)]
2934 fn as_fake(&self) -> Option<&FakeDb> {
2935 Some(self)
2936 }
2937 }
2938
2939 fn build_background_executor() -> Arc<Background> {
2940 Deterministic::new(0).build_background()
2941 }
2942}