1use std::{cmp, ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use nanoid::nanoid;
10use serde::{Deserialize, Serialize};
11pub use sqlx::postgres::PgPoolOptions as DbOptions;
12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
13use time::{OffsetDateTime, PrimitiveDateTime};
14
15#[async_trait]
16pub trait Db: Send + Sync {
17 async fn create_user(
18 &self,
19 github_login: &str,
20 email_address: Option<&str>,
21 admin: bool,
22 ) -> Result<UserId>;
23 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
24 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
25 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
26 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
27 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
28 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
29 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
30 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
31 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
32 async fn destroy_user(&self, id: UserId) -> Result<()>;
33
34 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
35 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
36 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
37 async fn redeem_invite_code(
38 &self,
39 code: &str,
40 login: &str,
41 email_address: Option<&str>,
42 ) -> Result<UserId>;
43
44 /// Registers a new project for the given user.
45 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
46
47 /// Unregisters a project for the given project id.
48 async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
49
50 /// Update file counts by extension for the given project and worktree.
51 async fn update_worktree_extensions(
52 &self,
53 project_id: ProjectId,
54 worktree_id: u64,
55 extensions: HashMap<String, u32>,
56 ) -> Result<()>;
57
58 /// Get the file counts on the given project keyed by their worktree and extension.
59 async fn get_project_extensions(
60 &self,
61 project_id: ProjectId,
62 ) -> Result<HashMap<u64, HashMap<String, usize>>>;
63
64 /// Record which users have been active in which projects during
65 /// a given period of time.
66 async fn record_user_activity(
67 &self,
68 time_period: Range<OffsetDateTime>,
69 active_projects: &[(UserId, ProjectId)],
70 ) -> Result<()>;
71
72 /// Get the number of users who have been active in the given
73 /// time period for at least the given time duration.
74 async fn get_active_user_count(
75 &self,
76 time_period: Range<OffsetDateTime>,
77 min_duration: Duration,
78 ) -> Result<usize>;
79
80 /// Get the users that have been most active during the given time period,
81 /// along with the amount of time they have been active in each project.
82 async fn get_top_users_activity_summary(
83 &self,
84 time_period: Range<OffsetDateTime>,
85 max_user_count: usize,
86 ) -> Result<Vec<UserActivitySummary>>;
87
88 /// Get the project activity for the given user and time period.
89 async fn get_user_activity_timeline(
90 &self,
91 time_period: Range<OffsetDateTime>,
92 user_id: UserId,
93 ) -> Result<Vec<UserActivityPeriod>>;
94
95 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
96 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
97 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
98 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
99 async fn dismiss_contact_notification(
100 &self,
101 responder_id: UserId,
102 requester_id: UserId,
103 ) -> Result<()>;
104 async fn respond_to_contact_request(
105 &self,
106 responder_id: UserId,
107 requester_id: UserId,
108 accept: bool,
109 ) -> Result<()>;
110
111 async fn create_access_token_hash(
112 &self,
113 user_id: UserId,
114 access_token_hash: &str,
115 max_access_token_count: usize,
116 ) -> Result<()>;
117 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
118 #[cfg(any(test, feature = "seed-support"))]
119
120 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
121 #[cfg(any(test, feature = "seed-support"))]
122 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
123 #[cfg(any(test, feature = "seed-support"))]
124 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
125 #[cfg(any(test, feature = "seed-support"))]
126 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
127 #[cfg(any(test, feature = "seed-support"))]
128
129 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
130 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
131 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
132 -> Result<bool>;
133 #[cfg(any(test, feature = "seed-support"))]
134 async fn add_channel_member(
135 &self,
136 channel_id: ChannelId,
137 user_id: UserId,
138 is_admin: bool,
139 ) -> Result<()>;
140 async fn create_channel_message(
141 &self,
142 channel_id: ChannelId,
143 sender_id: UserId,
144 body: &str,
145 timestamp: OffsetDateTime,
146 nonce: u128,
147 ) -> Result<MessageId>;
148 async fn get_channel_messages(
149 &self,
150 channel_id: ChannelId,
151 count: usize,
152 before_id: Option<MessageId>,
153 ) -> Result<Vec<ChannelMessage>>;
154 #[cfg(test)]
155 async fn teardown(&self, url: &str);
156 #[cfg(test)]
157 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
158}
159
160pub struct PostgresDb {
161 pool: sqlx::PgPool,
162}
163
164impl PostgresDb {
165 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
166 let pool = DbOptions::new()
167 .max_connections(max_connections)
168 .connect(&url)
169 .await
170 .context("failed to connect to postgres database")?;
171 Ok(Self { pool })
172 }
173}
174
175#[async_trait]
176impl Db for PostgresDb {
177 // users
178
179 async fn create_user(
180 &self,
181 github_login: &str,
182 email_address: Option<&str>,
183 admin: bool,
184 ) -> Result<UserId> {
185 let query = "
186 INSERT INTO users (github_login, email_address, admin)
187 VALUES ($1, $2, $3)
188 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
189 RETURNING id
190 ";
191 Ok(sqlx::query_scalar(query)
192 .bind(github_login)
193 .bind(email_address)
194 .bind(admin)
195 .fetch_one(&self.pool)
196 .await
197 .map(UserId)?)
198 }
199
200 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
201 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
202 Ok(sqlx::query_as(query)
203 .bind(limit as i32)
204 .bind((page * limit) as i32)
205 .fetch_all(&self.pool)
206 .await?)
207 }
208
209 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
210 let mut query = QueryBuilder::new(
211 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
212 );
213 query.push_values(
214 users,
215 |mut query, (github_login, email_address, invite_count)| {
216 query
217 .push_bind(github_login)
218 .push_bind(email_address)
219 .push_bind(false)
220 .push_bind(nanoid!(16))
221 .push_bind(invite_count as i32);
222 },
223 );
224 query.push(
225 "
226 ON CONFLICT (github_login) DO UPDATE SET
227 github_login = excluded.github_login,
228 invite_count = excluded.invite_count,
229 invite_code = CASE WHEN users.invite_code IS NULL
230 THEN excluded.invite_code
231 ELSE users.invite_code
232 END
233 RETURNING id
234 ",
235 );
236
237 let rows = query.build().fetch_all(&self.pool).await?;
238 Ok(rows
239 .into_iter()
240 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
241 .collect())
242 }
243
244 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
245 let like_string = fuzzy_like_string(name_query);
246 let query = "
247 SELECT users.*
248 FROM users
249 WHERE github_login ILIKE $1
250 ORDER BY github_login <-> $2
251 LIMIT $3
252 ";
253 Ok(sqlx::query_as(query)
254 .bind(like_string)
255 .bind(name_query)
256 .bind(limit as i32)
257 .fetch_all(&self.pool)
258 .await?)
259 }
260
261 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
262 let users = self.get_users_by_ids(vec![id]).await?;
263 Ok(users.into_iter().next())
264 }
265
266 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
267 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
268 let query = "
269 SELECT users.*
270 FROM users
271 WHERE users.id = ANY ($1)
272 ";
273 Ok(sqlx::query_as(query)
274 .bind(&ids)
275 .fetch_all(&self.pool)
276 .await?)
277 }
278
279 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
280 let query = format!(
281 "
282 SELECT users.*
283 FROM users
284 WHERE invite_count = 0
285 AND inviter_id IS{} NULL
286 ",
287 if invited_by_another_user { " NOT" } else { "" }
288 );
289
290 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
291 }
292
293 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
294 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
295 Ok(sqlx::query_as(query)
296 .bind(github_login)
297 .fetch_optional(&self.pool)
298 .await?)
299 }
300
301 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
302 let query = "UPDATE users SET admin = $1 WHERE id = $2";
303 Ok(sqlx::query(query)
304 .bind(is_admin)
305 .bind(id.0)
306 .execute(&self.pool)
307 .await
308 .map(drop)?)
309 }
310
311 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
312 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
313 Ok(sqlx::query(query)
314 .bind(connected_once)
315 .bind(id.0)
316 .execute(&self.pool)
317 .await
318 .map(drop)?)
319 }
320
321 async fn destroy_user(&self, id: UserId) -> Result<()> {
322 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
323 sqlx::query(query)
324 .bind(id.0)
325 .execute(&self.pool)
326 .await
327 .map(drop)?;
328 let query = "DELETE FROM users WHERE id = $1;";
329 Ok(sqlx::query(query)
330 .bind(id.0)
331 .execute(&self.pool)
332 .await
333 .map(drop)?)
334 }
335
336 // invite codes
337
338 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
339 let mut tx = self.pool.begin().await?;
340 if count > 0 {
341 sqlx::query(
342 "
343 UPDATE users
344 SET invite_code = $1
345 WHERE id = $2 AND invite_code IS NULL
346 ",
347 )
348 .bind(nanoid!(16))
349 .bind(id)
350 .execute(&mut tx)
351 .await?;
352 }
353
354 sqlx::query(
355 "
356 UPDATE users
357 SET invite_count = $1
358 WHERE id = $2
359 ",
360 )
361 .bind(count as i32)
362 .bind(id)
363 .execute(&mut tx)
364 .await?;
365 tx.commit().await?;
366 Ok(())
367 }
368
369 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
370 let result: Option<(String, i32)> = sqlx::query_as(
371 "
372 SELECT invite_code, invite_count
373 FROM users
374 WHERE id = $1 AND invite_code IS NOT NULL
375 ",
376 )
377 .bind(id)
378 .fetch_optional(&self.pool)
379 .await?;
380 if let Some((code, count)) = result {
381 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
382 } else {
383 Ok(None)
384 }
385 }
386
387 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
388 sqlx::query_as(
389 "
390 SELECT *
391 FROM users
392 WHERE invite_code = $1
393 ",
394 )
395 .bind(code)
396 .fetch_optional(&self.pool)
397 .await?
398 .ok_or_else(|| {
399 Error::Http(
400 StatusCode::NOT_FOUND,
401 "that invite code does not exist".to_string(),
402 )
403 })
404 }
405
406 async fn redeem_invite_code(
407 &self,
408 code: &str,
409 login: &str,
410 email_address: Option<&str>,
411 ) -> Result<UserId> {
412 let mut tx = self.pool.begin().await?;
413
414 let inviter_id: Option<UserId> = sqlx::query_scalar(
415 "
416 UPDATE users
417 SET invite_count = invite_count - 1
418 WHERE
419 invite_code = $1 AND
420 invite_count > 0
421 RETURNING id
422 ",
423 )
424 .bind(code)
425 .fetch_optional(&mut tx)
426 .await?;
427
428 let inviter_id = match inviter_id {
429 Some(inviter_id) => inviter_id,
430 None => {
431 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
432 .bind(code)
433 .fetch_optional(&mut tx)
434 .await?
435 .is_some()
436 {
437 Err(Error::Http(
438 StatusCode::UNAUTHORIZED,
439 "no invites remaining".to_string(),
440 ))?
441 } else {
442 Err(Error::Http(
443 StatusCode::NOT_FOUND,
444 "invite code not found".to_string(),
445 ))?
446 }
447 }
448 };
449
450 let invitee_id = sqlx::query_scalar(
451 "
452 INSERT INTO users
453 (github_login, email_address, admin, inviter_id)
454 VALUES
455 ($1, $2, 'f', $3)
456 RETURNING id
457 ",
458 )
459 .bind(login)
460 .bind(email_address)
461 .bind(inviter_id)
462 .fetch_one(&mut tx)
463 .await
464 .map(UserId)?;
465
466 sqlx::query(
467 "
468 INSERT INTO contacts
469 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
470 VALUES
471 ($1, $2, 't', 't', 't')
472 ",
473 )
474 .bind(inviter_id)
475 .bind(invitee_id)
476 .execute(&mut tx)
477 .await?;
478
479 tx.commit().await?;
480 Ok(invitee_id)
481 }
482
483 // projects
484
485 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
486 Ok(sqlx::query_scalar(
487 "
488 INSERT INTO projects(host_user_id)
489 VALUES ($1)
490 RETURNING id
491 ",
492 )
493 .bind(host_user_id)
494 .fetch_one(&self.pool)
495 .await
496 .map(ProjectId)?)
497 }
498
499 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
500 sqlx::query(
501 "
502 UPDATE projects
503 SET unregistered = 't'
504 WHERE id = $1
505 ",
506 )
507 .bind(project_id)
508 .execute(&self.pool)
509 .await?;
510 Ok(())
511 }
512
513 async fn update_worktree_extensions(
514 &self,
515 project_id: ProjectId,
516 worktree_id: u64,
517 extensions: HashMap<String, u32>,
518 ) -> Result<()> {
519 if extensions.is_empty() {
520 return Ok(());
521 }
522
523 let mut query = QueryBuilder::new(
524 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
525 );
526 query.push_values(extensions, |mut query, (extension, count)| {
527 query
528 .push_bind(project_id)
529 .push_bind(worktree_id as i32)
530 .push_bind(extension)
531 .push_bind(count as i32);
532 });
533 query.push(
534 "
535 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
536 count = excluded.count
537 ",
538 );
539 query.build().execute(&self.pool).await?;
540
541 Ok(())
542 }
543
544 async fn get_project_extensions(
545 &self,
546 project_id: ProjectId,
547 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
548 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
549 struct WorktreeExtension {
550 worktree_id: i32,
551 extension: String,
552 count: i32,
553 }
554
555 let query = "
556 SELECT worktree_id, extension, count
557 FROM worktree_extensions
558 WHERE project_id = $1
559 ";
560 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
561 .bind(&project_id)
562 .fetch_all(&self.pool)
563 .await?;
564
565 let mut extension_counts = HashMap::default();
566 for count in counts {
567 extension_counts
568 .entry(count.worktree_id as u64)
569 .or_insert(HashMap::default())
570 .insert(count.extension, count.count as usize);
571 }
572 Ok(extension_counts)
573 }
574
575 async fn record_user_activity(
576 &self,
577 time_period: Range<OffsetDateTime>,
578 projects: &[(UserId, ProjectId)],
579 ) -> Result<()> {
580 let query = "
581 INSERT INTO project_activity_periods
582 (ended_at, duration_millis, user_id, project_id)
583 VALUES
584 ($1, $2, $3, $4);
585 ";
586
587 let mut tx = self.pool.begin().await?;
588 let duration_millis =
589 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
590 for (user_id, project_id) in projects {
591 sqlx::query(query)
592 .bind(time_period.end)
593 .bind(duration_millis)
594 .bind(user_id)
595 .bind(project_id)
596 .execute(&mut tx)
597 .await?;
598 }
599 tx.commit().await?;
600
601 Ok(())
602 }
603
604 async fn get_active_user_count(
605 &self,
606 time_period: Range<OffsetDateTime>,
607 min_duration: Duration,
608 ) -> Result<usize> {
609 let query = "
610 WITH
611 project_durations AS (
612 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
613 FROM project_activity_periods
614 WHERE $1 < ended_at AND ended_at <= $2
615 GROUP BY user_id, project_id
616 ),
617 user_durations AS (
618 SELECT user_id, SUM(project_duration) as total_duration
619 FROM project_durations
620 GROUP BY user_id
621 ORDER BY total_duration DESC
622 LIMIT $3
623 )
624 SELECT count(user_durations.user_id)
625 FROM user_durations
626 WHERE user_durations.total_duration >= $3
627 ";
628
629 let count: i64 = sqlx::query_scalar(query)
630 .bind(time_period.start)
631 .bind(time_period.end)
632 .bind(min_duration.as_millis() as i64)
633 .fetch_one(&self.pool)
634 .await?;
635 Ok(count as usize)
636 }
637
638 async fn get_top_users_activity_summary(
639 &self,
640 time_period: Range<OffsetDateTime>,
641 max_user_count: usize,
642 ) -> Result<Vec<UserActivitySummary>> {
643 let query = "
644 WITH
645 project_durations AS (
646 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
647 FROM project_activity_periods
648 WHERE $1 < ended_at AND ended_at <= $2
649 GROUP BY user_id, project_id
650 ),
651 user_durations AS (
652 SELECT user_id, SUM(project_duration) as total_duration
653 FROM project_durations
654 GROUP BY user_id
655 ORDER BY total_duration DESC
656 LIMIT $3
657 ),
658 project_collaborators as (
659 SELECT project_id, COUNT(DISTINCT user_id) as project_collaborators
660 FROM project_durations
661 GROUP BY project_id
662 )
663 SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, project_collaborators
664 FROM user_durations, project_durations, project_collaborators, users
665 WHERE
666 user_durations.user_id = project_durations.user_id AND
667 user_durations.user_id = users.id AND
668 project_durations.project_id = project_collaborators.project_id
669 ORDER BY total_duration DESC, user_id ASC
670 ";
671
672 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
673 .bind(time_period.start)
674 .bind(time_period.end)
675 .bind(max_user_count as i32)
676 .fetch(&self.pool);
677
678 let mut result = Vec::<UserActivitySummary>::new();
679 while let Some(row) = rows.next().await {
680 let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
681 let project_id = project_id;
682 let duration = Duration::from_millis(duration_millis as u64);
683 let project_activity = ProjectActivitySummary {
684 id: project_id,
685 duration,
686 max_collaborators: project_collaborators as usize,
687 };
688 if let Some(last_summary) = result.last_mut() {
689 if last_summary.id == user_id {
690 last_summary.project_activity.push(project_activity);
691 continue;
692 }
693 }
694 result.push(UserActivitySummary {
695 id: user_id,
696 project_activity: vec![project_activity],
697 github_login,
698 });
699 }
700
701 Ok(result)
702 }
703
704 async fn get_user_activity_timeline(
705 &self,
706 time_period: Range<OffsetDateTime>,
707 user_id: UserId,
708 ) -> Result<Vec<UserActivityPeriod>> {
709 const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
710
711 let query = "
712 SELECT
713 project_activity_periods.ended_at,
714 project_activity_periods.duration_millis,
715 project_activity_periods.project_id,
716 worktree_extensions.extension,
717 worktree_extensions.count
718 FROM project_activity_periods
719 LEFT OUTER JOIN
720 worktree_extensions
721 ON
722 project_activity_periods.project_id = worktree_extensions.project_id
723 WHERE
724 project_activity_periods.user_id = $1 AND
725 $2 < project_activity_periods.ended_at AND
726 project_activity_periods.ended_at <= $3
727 ORDER BY project_activity_periods.id ASC
728 ";
729
730 let mut rows = sqlx::query_as::<
731 _,
732 (
733 PrimitiveDateTime,
734 i32,
735 ProjectId,
736 Option<String>,
737 Option<i32>,
738 ),
739 >(query)
740 .bind(user_id)
741 .bind(time_period.start)
742 .bind(time_period.end)
743 .fetch(&self.pool);
744
745 let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
746 while let Some(row) = rows.next().await {
747 let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
748 let ended_at = ended_at.assume_utc();
749 let duration = Duration::from_millis(duration_millis as u64);
750 let started_at = ended_at - duration;
751 let project_time_periods = time_periods.entry(project_id).or_default();
752
753 if let Some(prev_duration) = project_time_periods.last_mut() {
754 if started_at <= prev_duration.end + COALESCE_THRESHOLD
755 && ended_at >= prev_duration.start
756 {
757 prev_duration.end = cmp::max(prev_duration.end, ended_at);
758 } else {
759 project_time_periods.push(UserActivityPeriod {
760 project_id,
761 start: started_at,
762 end: ended_at,
763 extensions: Default::default(),
764 });
765 }
766 } else {
767 project_time_periods.push(UserActivityPeriod {
768 project_id,
769 start: started_at,
770 end: ended_at,
771 extensions: Default::default(),
772 });
773 }
774
775 if let Some((extension, extension_count)) = extension.zip(extension_count) {
776 project_time_periods
777 .last_mut()
778 .unwrap()
779 .extensions
780 .insert(extension, extension_count as usize);
781 }
782 }
783
784 let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
785 durations.sort_unstable_by_key(|duration| duration.start);
786 Ok(durations)
787 }
788
789 // contacts
790
791 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
792 let query = "
793 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
794 FROM contacts
795 WHERE user_id_a = $1 OR user_id_b = $1;
796 ";
797
798 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
799 .bind(user_id)
800 .fetch(&self.pool);
801
802 let mut contacts = vec![Contact::Accepted {
803 user_id,
804 should_notify: false,
805 }];
806 while let Some(row) = rows.next().await {
807 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
808
809 if user_id_a == user_id {
810 if accepted {
811 contacts.push(Contact::Accepted {
812 user_id: user_id_b,
813 should_notify: should_notify && a_to_b,
814 });
815 } else if a_to_b {
816 contacts.push(Contact::Outgoing { user_id: user_id_b })
817 } else {
818 contacts.push(Contact::Incoming {
819 user_id: user_id_b,
820 should_notify,
821 });
822 }
823 } else {
824 if accepted {
825 contacts.push(Contact::Accepted {
826 user_id: user_id_a,
827 should_notify: should_notify && !a_to_b,
828 });
829 } else if a_to_b {
830 contacts.push(Contact::Incoming {
831 user_id: user_id_a,
832 should_notify,
833 });
834 } else {
835 contacts.push(Contact::Outgoing { user_id: user_id_a });
836 }
837 }
838 }
839
840 contacts.sort_unstable_by_key(|contact| contact.user_id());
841
842 Ok(contacts)
843 }
844
845 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
846 let (id_a, id_b) = if user_id_1 < user_id_2 {
847 (user_id_1, user_id_2)
848 } else {
849 (user_id_2, user_id_1)
850 };
851
852 let query = "
853 SELECT 1 FROM contacts
854 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
855 LIMIT 1
856 ";
857 Ok(sqlx::query_scalar::<_, i32>(query)
858 .bind(id_a.0)
859 .bind(id_b.0)
860 .fetch_optional(&self.pool)
861 .await?
862 .is_some())
863 }
864
865 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
866 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
867 (sender_id, receiver_id, true)
868 } else {
869 (receiver_id, sender_id, false)
870 };
871 let query = "
872 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
873 VALUES ($1, $2, $3, 'f', 't')
874 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
875 SET
876 accepted = 't',
877 should_notify = 'f'
878 WHERE
879 NOT contacts.accepted AND
880 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
881 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
882 ";
883 let result = sqlx::query(query)
884 .bind(id_a.0)
885 .bind(id_b.0)
886 .bind(a_to_b)
887 .execute(&self.pool)
888 .await?;
889
890 if result.rows_affected() == 1 {
891 Ok(())
892 } else {
893 Err(anyhow!("contact already requested"))?
894 }
895 }
896
897 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
898 let (id_a, id_b) = if responder_id < requester_id {
899 (responder_id, requester_id)
900 } else {
901 (requester_id, responder_id)
902 };
903 let query = "
904 DELETE FROM contacts
905 WHERE user_id_a = $1 AND user_id_b = $2;
906 ";
907 let result = sqlx::query(query)
908 .bind(id_a.0)
909 .bind(id_b.0)
910 .execute(&self.pool)
911 .await?;
912
913 if result.rows_affected() == 1 {
914 Ok(())
915 } else {
916 Err(anyhow!("no such contact"))?
917 }
918 }
919
920 async fn dismiss_contact_notification(
921 &self,
922 user_id: UserId,
923 contact_user_id: UserId,
924 ) -> Result<()> {
925 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
926 (user_id, contact_user_id, true)
927 } else {
928 (contact_user_id, user_id, false)
929 };
930
931 let query = "
932 UPDATE contacts
933 SET should_notify = 'f'
934 WHERE
935 user_id_a = $1 AND user_id_b = $2 AND
936 (
937 (a_to_b = $3 AND accepted) OR
938 (a_to_b != $3 AND NOT accepted)
939 );
940 ";
941
942 let result = sqlx::query(query)
943 .bind(id_a.0)
944 .bind(id_b.0)
945 .bind(a_to_b)
946 .execute(&self.pool)
947 .await?;
948
949 if result.rows_affected() == 0 {
950 Err(anyhow!("no such contact request"))?;
951 }
952
953 Ok(())
954 }
955
956 async fn respond_to_contact_request(
957 &self,
958 responder_id: UserId,
959 requester_id: UserId,
960 accept: bool,
961 ) -> Result<()> {
962 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
963 (responder_id, requester_id, false)
964 } else {
965 (requester_id, responder_id, true)
966 };
967 let result = if accept {
968 let query = "
969 UPDATE contacts
970 SET accepted = 't', should_notify = 't'
971 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
972 ";
973 sqlx::query(query)
974 .bind(id_a.0)
975 .bind(id_b.0)
976 .bind(a_to_b)
977 .execute(&self.pool)
978 .await?
979 } else {
980 let query = "
981 DELETE FROM contacts
982 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
983 ";
984 sqlx::query(query)
985 .bind(id_a.0)
986 .bind(id_b.0)
987 .bind(a_to_b)
988 .execute(&self.pool)
989 .await?
990 };
991 if result.rows_affected() == 1 {
992 Ok(())
993 } else {
994 Err(anyhow!("no such contact request"))?
995 }
996 }
997
998 // access tokens
999
1000 async fn create_access_token_hash(
1001 &self,
1002 user_id: UserId,
1003 access_token_hash: &str,
1004 max_access_token_count: usize,
1005 ) -> Result<()> {
1006 let insert_query = "
1007 INSERT INTO access_tokens (user_id, hash)
1008 VALUES ($1, $2);
1009 ";
1010 let cleanup_query = "
1011 DELETE FROM access_tokens
1012 WHERE id IN (
1013 SELECT id from access_tokens
1014 WHERE user_id = $1
1015 ORDER BY id DESC
1016 OFFSET $3
1017 )
1018 ";
1019
1020 let mut tx = self.pool.begin().await?;
1021 sqlx::query(insert_query)
1022 .bind(user_id.0)
1023 .bind(access_token_hash)
1024 .execute(&mut tx)
1025 .await?;
1026 sqlx::query(cleanup_query)
1027 .bind(user_id.0)
1028 .bind(access_token_hash)
1029 .bind(max_access_token_count as i32)
1030 .execute(&mut tx)
1031 .await?;
1032 Ok(tx.commit().await?)
1033 }
1034
1035 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1036 let query = "
1037 SELECT hash
1038 FROM access_tokens
1039 WHERE user_id = $1
1040 ORDER BY id DESC
1041 ";
1042 Ok(sqlx::query_scalar(query)
1043 .bind(user_id.0)
1044 .fetch_all(&self.pool)
1045 .await?)
1046 }
1047
1048 // orgs
1049
1050 #[allow(unused)] // Help rust-analyzer
1051 #[cfg(any(test, feature = "seed-support"))]
1052 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1053 let query = "
1054 SELECT *
1055 FROM orgs
1056 WHERE slug = $1
1057 ";
1058 Ok(sqlx::query_as(query)
1059 .bind(slug)
1060 .fetch_optional(&self.pool)
1061 .await?)
1062 }
1063
1064 #[cfg(any(test, feature = "seed-support"))]
1065 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1066 let query = "
1067 INSERT INTO orgs (name, slug)
1068 VALUES ($1, $2)
1069 RETURNING id
1070 ";
1071 Ok(sqlx::query_scalar(query)
1072 .bind(name)
1073 .bind(slug)
1074 .fetch_one(&self.pool)
1075 .await
1076 .map(OrgId)?)
1077 }
1078
1079 #[cfg(any(test, feature = "seed-support"))]
1080 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1081 let query = "
1082 INSERT INTO org_memberships (org_id, user_id, admin)
1083 VALUES ($1, $2, $3)
1084 ON CONFLICT DO NOTHING
1085 ";
1086 Ok(sqlx::query(query)
1087 .bind(org_id.0)
1088 .bind(user_id.0)
1089 .bind(is_admin)
1090 .execute(&self.pool)
1091 .await
1092 .map(drop)?)
1093 }
1094
1095 // channels
1096
1097 #[cfg(any(test, feature = "seed-support"))]
1098 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1099 let query = "
1100 INSERT INTO channels (owner_id, owner_is_user, name)
1101 VALUES ($1, false, $2)
1102 RETURNING id
1103 ";
1104 Ok(sqlx::query_scalar(query)
1105 .bind(org_id.0)
1106 .bind(name)
1107 .fetch_one(&self.pool)
1108 .await
1109 .map(ChannelId)?)
1110 }
1111
1112 #[allow(unused)] // Help rust-analyzer
1113 #[cfg(any(test, feature = "seed-support"))]
1114 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1115 let query = "
1116 SELECT *
1117 FROM channels
1118 WHERE
1119 channels.owner_is_user = false AND
1120 channels.owner_id = $1
1121 ";
1122 Ok(sqlx::query_as(query)
1123 .bind(org_id.0)
1124 .fetch_all(&self.pool)
1125 .await?)
1126 }
1127
1128 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1129 let query = "
1130 SELECT
1131 channels.*
1132 FROM
1133 channel_memberships, channels
1134 WHERE
1135 channel_memberships.user_id = $1 AND
1136 channel_memberships.channel_id = channels.id
1137 ";
1138 Ok(sqlx::query_as(query)
1139 .bind(user_id.0)
1140 .fetch_all(&self.pool)
1141 .await?)
1142 }
1143
1144 async fn can_user_access_channel(
1145 &self,
1146 user_id: UserId,
1147 channel_id: ChannelId,
1148 ) -> Result<bool> {
1149 let query = "
1150 SELECT id
1151 FROM channel_memberships
1152 WHERE user_id = $1 AND channel_id = $2
1153 LIMIT 1
1154 ";
1155 Ok(sqlx::query_scalar::<_, i32>(query)
1156 .bind(user_id.0)
1157 .bind(channel_id.0)
1158 .fetch_optional(&self.pool)
1159 .await
1160 .map(|e| e.is_some())?)
1161 }
1162
1163 #[cfg(any(test, feature = "seed-support"))]
1164 async fn add_channel_member(
1165 &self,
1166 channel_id: ChannelId,
1167 user_id: UserId,
1168 is_admin: bool,
1169 ) -> Result<()> {
1170 let query = "
1171 INSERT INTO channel_memberships (channel_id, user_id, admin)
1172 VALUES ($1, $2, $3)
1173 ON CONFLICT DO NOTHING
1174 ";
1175 Ok(sqlx::query(query)
1176 .bind(channel_id.0)
1177 .bind(user_id.0)
1178 .bind(is_admin)
1179 .execute(&self.pool)
1180 .await
1181 .map(drop)?)
1182 }
1183
1184 // messages
1185
1186 async fn create_channel_message(
1187 &self,
1188 channel_id: ChannelId,
1189 sender_id: UserId,
1190 body: &str,
1191 timestamp: OffsetDateTime,
1192 nonce: u128,
1193 ) -> Result<MessageId> {
1194 let query = "
1195 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1196 VALUES ($1, $2, $3, $4, $5)
1197 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1198 RETURNING id
1199 ";
1200 Ok(sqlx::query_scalar(query)
1201 .bind(channel_id.0)
1202 .bind(sender_id.0)
1203 .bind(body)
1204 .bind(timestamp)
1205 .bind(Uuid::from_u128(nonce))
1206 .fetch_one(&self.pool)
1207 .await
1208 .map(MessageId)?)
1209 }
1210
1211 async fn get_channel_messages(
1212 &self,
1213 channel_id: ChannelId,
1214 count: usize,
1215 before_id: Option<MessageId>,
1216 ) -> Result<Vec<ChannelMessage>> {
1217 let query = r#"
1218 SELECT * FROM (
1219 SELECT
1220 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1221 FROM
1222 channel_messages
1223 WHERE
1224 channel_id = $1 AND
1225 id < $2
1226 ORDER BY id DESC
1227 LIMIT $3
1228 ) as recent_messages
1229 ORDER BY id ASC
1230 "#;
1231 Ok(sqlx::query_as(query)
1232 .bind(channel_id.0)
1233 .bind(before_id.unwrap_or(MessageId::MAX))
1234 .bind(count as i64)
1235 .fetch_all(&self.pool)
1236 .await?)
1237 }
1238
1239 #[cfg(test)]
1240 async fn teardown(&self, url: &str) {
1241 use util::ResultExt;
1242
1243 let query = "
1244 SELECT pg_terminate_backend(pg_stat_activity.pid)
1245 FROM pg_stat_activity
1246 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1247 ";
1248 sqlx::query(query).execute(&self.pool).await.log_err();
1249 self.pool.close().await;
1250 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1251 .await
1252 .log_err();
1253 }
1254
1255 #[cfg(test)]
1256 fn as_fake(&self) -> Option<&tests::FakeDb> {
1257 None
1258 }
1259}
1260
/// Defines `pub struct $name(pub i32)`, a newtype wrapper for an integer row id.
///
/// Because of `#[sqlx(transparent)]` and `#[serde(transparent)]`, the generated
/// type is encoded by sqlx and serde exactly like its inner `i32`.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        // Not every generated id type uses every helper below, hence `allow(unused)`.
        impl $name {
            // Largest representable id. Used, for example, as the default upper
            // bound when paging channel messages without a `before_id`.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Converts a wire-protocol id.
            // NOTE(review): `as i32` silently truncates values above `i32::MAX`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Converts to a wire-protocol id.
            // NOTE(review): a negative id would wrap to a huge `u64` here.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        // Displays the id as its bare integer value.
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1303
id_type!(UserId);
/// A row of the `users` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    pub admin: bool,
    // `None` until an invite code has been assigned to this user.
    pub invite_code: Option<String>,
    // Presumably the remaining number of uses of `invite_code`; verify against the schema.
    pub invite_count: i32,
    pub connected_once: bool,
}
1315
id_type!(ProjectId);
/// A row of the projects table: the project id, its host, and whether it has
/// been unregistered.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub unregistered: bool,
}
1323
/// Aggregated activity of one user over a time range, broken down per project.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    pub project_activity: Vec<ProjectActivitySummary>,
}
1330
/// A user's activity in a single project.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    id: ProjectId,
    // Total time spent in the project within the queried range.
    duration: Duration,
    // Largest number of users observed in the project at the same time.
    max_collaborators: usize,
}
1337
/// A contiguous span of time a user spent in one project.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    project_id: ProjectId,
    // Timestamps are serialized as ISO 8601 strings.
    #[serde(with = "time::serde::iso8601")]
    start: OffsetDateTime,
    #[serde(with = "time::serde::iso8601")]
    end: OffsetDateTime,
    // File-extension counts recorded for the project's worktrees
    // (e.g. "rs" -> 5).
    extensions: HashMap<String, usize>,
}
1347
id_type!(OrgId);
/// A row of the orgs table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
1355
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    // Either a user id or an org id, depending on `owner_is_user`.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1364
id_type!(MessageId);
/// A row of the `channel_messages` table.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    pub sent_at: OffsetDateTime,
    // Client-supplied nonce (a `u128` stored as a UUID). Re-sending the same
    // nonce yields the existing message instead of a new one.
    pub nonce: Uuid,
}
1375
/// One entry of a user's contact list, as seen from that user's perspective.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    // An established contact. The tests show a user's list also contains an
    // `Accepted` entry for the user themselves.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    // A request this user sent to `user_id` that has not been answered yet.
    Outgoing {
        user_id: UserId,
    },
    // A request `user_id` sent to this user that has not been answered yet.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1390
1391impl Contact {
1392 pub fn user_id(&self) -> UserId {
1393 match self {
1394 Contact::Accepted { user_id, .. } => *user_id,
1395 Contact::Outgoing { user_id } => *user_id,
1396 Contact::Incoming { user_id, .. } => *user_id,
1397 }
1398 }
1399}
1400
/// A pending contact request addressed to the current user.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    // Whether the recipient should still be notified about this request.
    pub should_notify: bool,
}
1406
/// Builds a SQL `LIKE` pattern that matches any string containing the
/// alphanumeric characters of `string`, in order, with anything in between.
///
/// Non-alphanumeric characters (including the `LIKE` wildcards `%` and `_`)
/// are dropped. For example, `"x y"` becomes `"%x%y%"` and `""` becomes `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    // Each kept char costs at most its own bytes plus one '%', plus a trailing '%'.
    let capacity = string.len() * 2 + 1;
    let mut pattern = string
        .chars()
        .filter(|ch| ch.is_alphanumeric())
        .fold(String::with_capacity(capacity), |mut acc, ch| {
            acc.push('%');
            acc.push(ch);
            acc
        });
    pattern.push('%');
    pattern
}
1418
1419#[cfg(test)]
1420pub mod tests {
1421 use super::*;
1422 use anyhow::anyhow;
1423 use collections::BTreeMap;
1424 use gpui::executor::{Background, Deterministic};
1425 use lazy_static::lazy_static;
1426 use parking_lot::Mutex;
1427 use rand::prelude::*;
1428 use sqlx::{
1429 migrate::{MigrateDatabase, Migrator},
1430 Postgres,
1431 };
1432 use std::{path::Path, sync::Arc};
1433 use util::post_inc;
1434
1435 #[tokio::test(flavor = "multi_thread")]
1436 async fn test_get_users_by_ids() {
1437 for test_db in [
1438 TestDb::postgres().await,
1439 TestDb::fake(build_background_executor()),
1440 ] {
1441 let db = test_db.db();
1442
1443 let user = db.create_user("user", None, false).await.unwrap();
1444 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1445 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1446 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1447
1448 assert_eq!(
1449 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1450 .await
1451 .unwrap(),
1452 vec![
1453 User {
1454 id: user,
1455 github_login: "user".to_string(),
1456 admin: false,
1457 ..Default::default()
1458 },
1459 User {
1460 id: friend1,
1461 github_login: "friend-1".to_string(),
1462 admin: false,
1463 ..Default::default()
1464 },
1465 User {
1466 id: friend2,
1467 github_login: "friend-2".to_string(),
1468 admin: false,
1469 ..Default::default()
1470 },
1471 User {
1472 id: friend3,
1473 github_login: "friend-3".to_string(),
1474 admin: false,
1475 ..Default::default()
1476 }
1477 ]
1478 );
1479 }
1480 }
1481
    /// Batch user creation: invite counts are applied, invite codes are unique,
    /// and re-creating an existing user updates its invite count while keeping
    /// both its id and its invite code.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_users() {
        let db = TestDb::postgres().await;
        let db = db.db();

        // Create the first batch of users, ensuring invite counts are assigned
        // correctly and the respective invite codes are unique.
        let user_ids_batch_1 = db
            .create_users(vec![
                ("user1".to_string(), "hi@user1.com".to_string(), 5),
                ("user2".to_string(), "hi@user2.com".to_string(), 4),
                ("user3".to_string(), "hi@user3.com".to_string(), 3),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_1.len(), 3);

        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
        assert_eq!(users.len(), 3);
        assert_eq!(users[0].github_login, "user1");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
        assert_eq!(users[0].invite_count, 5);
        assert_eq!(users[1].github_login, "user2");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[1].invite_count, 4);
        assert_eq!(users[2].github_login, "user3");
        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
        assert_eq!(users[2].invite_count, 3);

        let invite_code_1 = users[0].invite_code.clone().unwrap();
        let invite_code_2 = users[1].invite_code.clone().unwrap();
        let invite_code_3 = users[2].invite_code.clone().unwrap();
        assert_ne!(invite_code_1, invite_code_2);
        assert_ne!(invite_code_1, invite_code_3);
        assert_ne!(invite_code_2, invite_code_3);

        // Create the second batch of users and include a user that is already in the database, ensuring
        // the invite count for the existing user is updated without changing their invite code.
        let user_ids_batch_2 = db
            .create_users(vec![
                ("user2".to_string(), "hi@user2.com".to_string(), 10),
                ("user4".to_string(), "hi@user4.com".to_string(), 2),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_2.len(), 2);
        // The existing "user2" keeps the id it was given in the first batch.
        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);

        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0].github_login, "user2");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[0].invite_count, 10);
        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
        assert_eq!(users[1].github_login, "user4");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
        assert_eq!(users[1].invite_count, 2);

        // The newly created user gets a code distinct from all earlier ones.
        let invite_code_4 = users[1].invite_code.clone().unwrap();
        assert_ne!(invite_code_4, invite_code_1);
        assert_ne!(invite_code_4, invite_code_2);
        assert_ne!(invite_code_4, invite_code_3);
    }
1545
    /// Worktree extension counts: later updates for the same worktree replace
    /// earlier counts, and each worktree id is tracked independently.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_worktree_extensions() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();

        let user = db.create_user("user_1", None, false).await.unwrap();
        let project = db.register_project(user).await.unwrap();

        // An empty update must succeed.
        db.update_worktree_extensions(project, 100, Default::default())
            .await
            .unwrap();
        db.update_worktree_extensions(
            project,
            100,
            [("rs".to_string(), 5), ("md".to_string(), 3)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();
        // A second update to worktree 100; the expected result below shows these
        // counts replacing the previous ones.
        db.update_worktree_extensions(
            project,
            100,
            [("rs".to_string(), 6), ("md".to_string(), 5)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();
        db.update_worktree_extensions(
            project,
            101,
            [("ts".to_string(), 2), ("md".to_string(), 1)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();

        // Expected: worktree id -> (extension -> count), using the latest counts.
        assert_eq!(
            db.get_project_extensions(project).await.unwrap(),
            [
                (
                    100,
                    [("rs".into(), 6), ("md".into(), 5),]
                        .into_iter()
                        .collect::<HashMap<_, _>>()
                ),
                (
                    101,
                    [("ts".into(), 2), ("md".into(), 1),]
                        .into_iter()
                        .collect::<HashMap<_, _>>()
                )
            ]
            .into_iter()
            .collect()
        );
    }
1605
    /// Records a scripted sequence of activity snapshots, then checks the
    /// per-user summary, the active-user counts, and the activity timelines.
    ///
    /// Within t0..t6: user_2 spends 50s in project_2; user_1 spends 30s in
    /// project_2 plus 25s in project_1 (55s total); user_3 spends 15s in
    /// project_1.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_user_activity() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();

        let user_1 = db.create_user("user_1", None, false).await.unwrap();
        let user_2 = db.create_user("user_2", None, false).await.unwrap();
        let user_3 = db.create_user("user_3", None, false).await.unwrap();
        let project_1 = db.register_project(user_1).await.unwrap();
        // Project 1's extensions show up in the timeline entries asserted below.
        db.update_worktree_extensions(
            project_1,
            1,
            HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
        )
        .await
        .unwrap();
        let project_2 = db.register_project(user_2).await.unwrap();
        let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);

        // User 2 opens a project
        let t1 = t0 + Duration::from_secs(10);
        db.record_user_activity(t0..t1, &[(user_2, project_2)])
            .await
            .unwrap();

        let t2 = t1 + Duration::from_secs(10);
        db.record_user_activity(t1..t2, &[(user_2, project_2)])
            .await
            .unwrap();

        // User 1 joins the project
        let t3 = t2 + Duration::from_secs(10);
        db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
            .await
            .unwrap();

        // User 1 opens another project
        let t4 = t3 + Duration::from_secs(10);
        db.record_user_activity(
            t3..t4,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
            ],
        )
        .await
        .unwrap();

        // User 3 joins that project
        let t5 = t4 + Duration::from_secs(10);
        db.record_user_activity(
            t4..t5,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
                (user_3, project_1),
            ],
        )
        .await
        .unwrap();

        // User 2 leaves
        let t6 = t5 + Duration::from_secs(5);
        db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
            .await
            .unwrap();

        // After a 60s gap, user 1 is active again in project 1.
        let t7 = t6 + Duration::from_secs(60);
        let t8 = t7 + Duration::from_secs(10);
        db.record_user_activity(t7..t8, &[(user_1, project_1)])
            .await
            .unwrap();

        assert_eq!(
            db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
            &[
                UserActivitySummary {
                    id: user_1,
                    github_login: "user_1".to_string(),
                    project_activity: vec![
                        ProjectActivitySummary {
                            id: project_1,
                            duration: Duration::from_secs(25),
                            max_collaborators: 2
                        },
                        ProjectActivitySummary {
                            id: project_2,
                            duration: Duration::from_secs(30),
                            max_collaborators: 2
                        }
                    ]
                },
                UserActivitySummary {
                    id: user_2,
                    github_login: "user_2".to_string(),
                    project_activity: vec![ProjectActivitySummary {
                        id: project_2,
                        duration: Duration::from_secs(50),
                        max_collaborators: 2
                    }]
                },
                UserActivitySummary {
                    id: user_3,
                    github_login: "user_3".to_string(),
                    project_activity: vec![ProjectActivitySummary {
                        id: project_1,
                        duration: Duration::from_secs(15),
                        max_collaborators: 2
                    }]
                },
            ]
        );

        // Totals in t0..t6 are 55s (user_1), 50s (user_2) and 15s (user_3), so
        // the count drops as the minimum duration rises past each total.
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(56))
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(54))
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(30))
                .await
                .unwrap(),
            2
        );
        assert_eq!(
            db.get_active_user_count(t0..t6, Duration::from_secs(10))
                .await
                .unwrap(),
            3
        );

        // Querying from t3: the project_2 period, which really started at t2,
        // is reported as starting at t3.
        assert_eq!(
            db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
            &[
                UserActivityPeriod {
                    project_id: project_1,
                    start: t3,
                    end: t6,
                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
                },
                UserActivityPeriod {
                    project_id: project_2,
                    start: t3,
                    end: t5,
                    extensions: Default::default(),
                },
            ]
        );
        // Over the full range, the 60s gap between t6 and t7 yields two
        // separate project_1 periods.
        assert_eq!(
            db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
            &[
                UserActivityPeriod {
                    project_id: project_2,
                    start: t2,
                    end: t5,
                    extensions: Default::default(),
                },
                UserActivityPeriod {
                    project_id: project_1,
                    start: t3,
                    end: t6,
                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
                },
                UserActivityPeriod {
                    project_id: project_1,
                    start: t7,
                    end: t8,
                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
                },
            ]
        );
    }
1787
1788 #[tokio::test(flavor = "multi_thread")]
1789 async fn test_recent_channel_messages() {
1790 for test_db in [
1791 TestDb::postgres().await,
1792 TestDb::fake(build_background_executor()),
1793 ] {
1794 let db = test_db.db();
1795 let user = db.create_user("user", None, false).await.unwrap();
1796 let org = db.create_org("org", "org").await.unwrap();
1797 let channel = db.create_org_channel(org, "channel").await.unwrap();
1798 for i in 0..10 {
1799 db.create_channel_message(
1800 channel,
1801 user,
1802 &i.to_string(),
1803 OffsetDateTime::now_utc(),
1804 i,
1805 )
1806 .await
1807 .unwrap();
1808 }
1809
1810 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1811 assert_eq!(
1812 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1813 ["5", "6", "7", "8", "9"]
1814 );
1815
1816 let prev_messages = db
1817 .get_channel_messages(channel, 4, Some(messages[0].id))
1818 .await
1819 .unwrap();
1820 assert_eq!(
1821 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1822 ["1", "2", "3", "4"]
1823 );
1824 }
1825 }
1826
1827 #[tokio::test(flavor = "multi_thread")]
1828 async fn test_channel_message_nonces() {
1829 for test_db in [
1830 TestDb::postgres().await,
1831 TestDb::fake(build_background_executor()),
1832 ] {
1833 let db = test_db.db();
1834 let user = db.create_user("user", None, false).await.unwrap();
1835 let org = db.create_org("org", "org").await.unwrap();
1836 let channel = db.create_org_channel(org, "channel").await.unwrap();
1837
1838 let msg1_id = db
1839 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1840 .await
1841 .unwrap();
1842 let msg2_id = db
1843 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1844 .await
1845 .unwrap();
1846 let msg3_id = db
1847 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1848 .await
1849 .unwrap();
1850 let msg4_id = db
1851 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1852 .await
1853 .unwrap();
1854
1855 assert_ne!(msg1_id, msg2_id);
1856 assert_eq!(msg1_id, msg3_id);
1857 assert_eq!(msg2_id, msg4_id);
1858 }
1859 }
1860
1861 #[tokio::test(flavor = "multi_thread")]
1862 async fn test_create_access_tokens() {
1863 let test_db = TestDb::postgres().await;
1864 let db = test_db.db();
1865 let user = db.create_user("the-user", None, false).await.unwrap();
1866
1867 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1868 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1869 assert_eq!(
1870 db.get_access_token_hashes(user).await.unwrap(),
1871 &["h2".to_string(), "h1".to_string()]
1872 );
1873
1874 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1875 assert_eq!(
1876 db.get_access_token_hashes(user).await.unwrap(),
1877 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1878 );
1879
1880 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1881 assert_eq!(
1882 db.get_access_token_hashes(user).await.unwrap(),
1883 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1884 );
1885
1886 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1887 assert_eq!(
1888 db.get_access_token_hashes(user).await.unwrap(),
1889 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1890 );
1891 }
1892
1893 #[test]
1894 fn test_fuzzy_like_string() {
1895 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1896 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1897 assert_eq!(fuzzy_like_string(" z "), "%z%");
1898 }
1899
1900 #[tokio::test(flavor = "multi_thread")]
1901 async fn test_fuzzy_search_users() {
1902 let test_db = TestDb::postgres().await;
1903 let db = test_db.db();
1904 for github_login in [
1905 "California",
1906 "colorado",
1907 "oregon",
1908 "washington",
1909 "florida",
1910 "delaware",
1911 "rhode-island",
1912 ] {
1913 db.create_user(github_login, None, false).await.unwrap();
1914 }
1915
1916 assert_eq!(
1917 fuzzy_search_user_names(db, "clr").await,
1918 &["colorado", "California"]
1919 );
1920 assert_eq!(
1921 fuzzy_search_user_names(db, "ro").await,
1922 &["rhode-island", "colorado", "oregon"],
1923 );
1924
1925 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1926 db.fuzzy_search_users(query, 10)
1927 .await
1928 .unwrap()
1929 .into_iter()
1930 .map(|user| user.github_login)
1931 .collect::<Vec<_>>()
1932 }
1933 }
1934
    /// Walks the contact-request lifecycle (request, dismiss, accept, mutual
    /// request, decline) against both the real and the fake database.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(build_background_executor()),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // User starts with no contacts other than themselves.
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again.
            // The requester (user 1) cannot dismiss it, so this call must fail.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            // The requester (user 1) is notified of the acceptance; the accepter is not.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
2154
    /// Invite codes: creation via `set_invite_count`, redemption creating
    /// mutual contacts, exhaustion, replenishment, and rejection of existing users.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        // The inviter is notified about the new contact; the invitee is not.
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes.
        // The failed attempt must not consume an invite.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
2307
    /// A database handle for tests that tears the database down when dropped.
    pub struct TestDb {
        // `Option` so `Drop` can take ownership of the handle for teardown.
        pub db: Option<Arc<dyn Db>>,
        // Connection URL of the Postgres test database; empty for the fake.
        pub url: String,
    }
2312
2313 impl TestDb {
2314 pub async fn postgres() -> Self {
2315 lazy_static! {
2316 static ref LOCK: Mutex<()> = Mutex::new(());
2317 }
2318
2319 let _guard = LOCK.lock();
2320 let mut rng = StdRng::from_entropy();
2321 let name = format!("zed-test-{}", rng.gen::<u128>());
2322 let url = format!("postgres://postgres@localhost/{}", name);
2323 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2324 Postgres::create_database(&url)
2325 .await
2326 .expect("failed to create test db");
2327 let db = PostgresDb::new(&url, 5).await.unwrap();
2328 let migrator = Migrator::new(migrations_path).await.unwrap();
2329 migrator.run(&db.pool).await.unwrap();
2330 Self {
2331 db: Some(Arc::new(db)),
2332 url,
2333 }
2334 }
2335
2336 pub fn fake(background: Arc<Background>) -> Self {
2337 Self {
2338 db: Some(Arc::new(FakeDb::new(background))),
2339 url: Default::default(),
2340 }
2341 }
2342
2343 pub fn db(&self) -> &Arc<dyn Db> {
2344 self.db.as_ref().unwrap()
2345 }
2346 }
2347
2348 impl Drop for TestDb {
2349 fn drop(&mut self) {
2350 if let Some(db) = self.db.take() {
2351 futures::executor::block_on(db.teardown(&self.url));
2352 }
2353 }
2354 }
2355
    /// In-memory stand-in for the database, with every table kept in its own `Mutex`.
    pub struct FakeDb {
        // Used to inject random delays that simulate database latency.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        // Presumably keyed by (project, worktree id, extension) -> file count; verify.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // The `bool` is presumably the admin flag of the membership; verify.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        // The `bool` is presumably the admin flag of the membership; verify.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        pub contacts: Mutex<Vec<FakeContact>>,
        // Counters used to allocate the next id of each kind.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2373
    /// A contact relationship between two users in `FakeDb`.
    #[derive(Debug)]
    pub struct FakeContact {
        // The user who sent the contact request.
        pub requester_id: UserId,
        // The user who received it.
        pub responder_id: UserId,
        // Whether the request has been accepted.
        pub accepted: bool,
        // Whether a notification is still pending (which side is notified is not visible here).
        pub should_notify: bool,
    }
2381
2382 impl FakeDb {
2383 pub fn new(background: Arc<Background>) -> Self {
2384 Self {
2385 background,
2386 users: Default::default(),
2387 next_user_id: Mutex::new(0),
2388 projects: Default::default(),
2389 worktree_extensions: Default::default(),
2390 next_project_id: Mutex::new(1),
2391 orgs: Default::default(),
2392 next_org_id: Mutex::new(1),
2393 org_memberships: Default::default(),
2394 channels: Default::default(),
2395 next_channel_id: Mutex::new(1),
2396 channel_memberships: Default::default(),
2397 channel_messages: Default::default(),
2398 next_channel_message_id: Mutex::new(1),
2399 contacts: Default::default(),
2400 }
2401 }
2402 }
2403
2404 #[async_trait]
2405 impl Db for FakeDb {
2406 async fn create_user(
2407 &self,
2408 github_login: &str,
2409 email_address: Option<&str>,
2410 admin: bool,
2411 ) -> Result<UserId> {
2412 self.background.simulate_random_delay().await;
2413
2414 let mut users = self.users.lock();
2415 if let Some(user) = users
2416 .values()
2417 .find(|user| user.github_login == github_login)
2418 {
2419 Ok(user.id)
2420 } else {
2421 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2422 users.insert(
2423 user_id,
2424 User {
2425 id: user_id,
2426 github_login: github_login.to_string(),
2427 email_address: email_address.map(str::to_string),
2428 admin,
2429 invite_code: None,
2430 invite_count: 0,
2431 connected_once: false,
2432 },
2433 );
2434 Ok(user_id)
2435 }
2436 }
2437
2438 async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
2439 unimplemented!()
2440 }
2441
2442 async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
2443 unimplemented!()
2444 }
2445
2446 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
2447 unimplemented!()
2448 }
2449
2450 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2451 self.background.simulate_random_delay().await;
2452 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2453 }
2454
2455 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2456 self.background.simulate_random_delay().await;
2457 let users = self.users.lock();
2458 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2459 }
2460
2461 async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
2462 unimplemented!()
2463 }
2464
2465 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2466 self.background.simulate_random_delay().await;
2467 Ok(self
2468 .users
2469 .lock()
2470 .values()
2471 .find(|user| user.github_login == github_login)
2472 .cloned())
2473 }
2474
2475 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
2476 unimplemented!()
2477 }
2478
2479 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2480 self.background.simulate_random_delay().await;
2481 let mut users = self.users.lock();
2482 let mut user = users
2483 .get_mut(&id)
2484 .ok_or_else(|| anyhow!("user not found"))?;
2485 user.connected_once = connected_once;
2486 Ok(())
2487 }
2488
2489 async fn destroy_user(&self, _id: UserId) -> Result<()> {
2490 unimplemented!()
2491 }
2492
2493 // invite codes
2494
2495 async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
2496 unimplemented!()
2497 }
2498
2499 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
2500 self.background.simulate_random_delay().await;
2501 Ok(None)
2502 }
2503
2504 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
2505 unimplemented!()
2506 }
2507
2508 async fn redeem_invite_code(
2509 &self,
2510 _code: &str,
2511 _login: &str,
2512 _email_address: Option<&str>,
2513 ) -> Result<UserId> {
2514 unimplemented!()
2515 }
2516
2517 // projects
2518
2519 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2520 self.background.simulate_random_delay().await;
2521 if !self.users.lock().contains_key(&host_user_id) {
2522 Err(anyhow!("no such user"))?;
2523 }
2524
2525 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2526 self.projects.lock().insert(
2527 project_id,
2528 Project {
2529 id: project_id,
2530 host_user_id,
2531 unregistered: false,
2532 },
2533 );
2534 Ok(project_id)
2535 }
2536
2537 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2538 self.background.simulate_random_delay().await;
2539 self.projects
2540 .lock()
2541 .get_mut(&project_id)
2542 .ok_or_else(|| anyhow!("no such project"))?
2543 .unregistered = true;
2544 Ok(())
2545 }
2546
2547 async fn update_worktree_extensions(
2548 &self,
2549 project_id: ProjectId,
2550 worktree_id: u64,
2551 extensions: HashMap<String, u32>,
2552 ) -> Result<()> {
2553 self.background.simulate_random_delay().await;
2554 if !self.projects.lock().contains_key(&project_id) {
2555 Err(anyhow!("no such project"))?;
2556 }
2557
2558 for (extension, count) in extensions {
2559 self.worktree_extensions
2560 .lock()
2561 .insert((project_id, worktree_id, extension), count);
2562 }
2563
2564 Ok(())
2565 }
2566
2567 async fn get_project_extensions(
2568 &self,
2569 _project_id: ProjectId,
2570 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
2571 unimplemented!()
2572 }
2573
2574 async fn record_user_activity(
2575 &self,
2576 _time_period: Range<OffsetDateTime>,
2577 _active_projects: &[(UserId, ProjectId)],
2578 ) -> Result<()> {
2579 unimplemented!()
2580 }
2581
2582 async fn get_active_user_count(
2583 &self,
2584 _time_period: Range<OffsetDateTime>,
2585 _min_duration: Duration,
2586 ) -> Result<usize> {
2587 unimplemented!()
2588 }
2589
2590 async fn get_top_users_activity_summary(
2591 &self,
2592 _time_period: Range<OffsetDateTime>,
2593 _limit: usize,
2594 ) -> Result<Vec<UserActivitySummary>> {
2595 unimplemented!()
2596 }
2597
2598 async fn get_user_activity_timeline(
2599 &self,
2600 _time_period: Range<OffsetDateTime>,
2601 _user_id: UserId,
2602 ) -> Result<Vec<UserActivityPeriod>> {
2603 unimplemented!()
2604 }
2605
2606 // contacts
2607
2608 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2609 self.background.simulate_random_delay().await;
2610 let mut contacts = vec![Contact::Accepted {
2611 user_id: id,
2612 should_notify: false,
2613 }];
2614
2615 for contact in self.contacts.lock().iter() {
2616 if contact.requester_id == id {
2617 if contact.accepted {
2618 contacts.push(Contact::Accepted {
2619 user_id: contact.responder_id,
2620 should_notify: contact.should_notify,
2621 });
2622 } else {
2623 contacts.push(Contact::Outgoing {
2624 user_id: contact.responder_id,
2625 });
2626 }
2627 } else if contact.responder_id == id {
2628 if contact.accepted {
2629 contacts.push(Contact::Accepted {
2630 user_id: contact.requester_id,
2631 should_notify: false,
2632 });
2633 } else {
2634 contacts.push(Contact::Incoming {
2635 user_id: contact.requester_id,
2636 should_notify: contact.should_notify,
2637 });
2638 }
2639 }
2640 }
2641
2642 contacts.sort_unstable_by_key(|contact| contact.user_id());
2643 Ok(contacts)
2644 }
2645
2646 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2647 self.background.simulate_random_delay().await;
2648 Ok(self.contacts.lock().iter().any(|contact| {
2649 contact.accepted
2650 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2651 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2652 }))
2653 }
2654
2655 async fn send_contact_request(
2656 &self,
2657 requester_id: UserId,
2658 responder_id: UserId,
2659 ) -> Result<()> {
2660 self.background.simulate_random_delay().await;
2661 let mut contacts = self.contacts.lock();
2662 for contact in contacts.iter_mut() {
2663 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2664 if contact.accepted {
2665 Err(anyhow!("contact already exists"))?;
2666 } else {
2667 Err(anyhow!("contact already requested"))?;
2668 }
2669 }
2670 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2671 if contact.accepted {
2672 Err(anyhow!("contact already exists"))?;
2673 } else {
2674 contact.accepted = true;
2675 contact.should_notify = false;
2676 return Ok(());
2677 }
2678 }
2679 }
2680 contacts.push(FakeContact {
2681 requester_id,
2682 responder_id,
2683 accepted: false,
2684 should_notify: true,
2685 });
2686 Ok(())
2687 }
2688
2689 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2690 self.background.simulate_random_delay().await;
2691 self.contacts.lock().retain(|contact| {
2692 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2693 });
2694 Ok(())
2695 }
2696
2697 async fn dismiss_contact_notification(
2698 &self,
2699 user_id: UserId,
2700 contact_user_id: UserId,
2701 ) -> Result<()> {
2702 self.background.simulate_random_delay().await;
2703 let mut contacts = self.contacts.lock();
2704 for contact in contacts.iter_mut() {
2705 if contact.requester_id == contact_user_id
2706 && contact.responder_id == user_id
2707 && !contact.accepted
2708 {
2709 contact.should_notify = false;
2710 return Ok(());
2711 }
2712 if contact.requester_id == user_id
2713 && contact.responder_id == contact_user_id
2714 && contact.accepted
2715 {
2716 contact.should_notify = false;
2717 return Ok(());
2718 }
2719 }
2720 Err(anyhow!("no such notification"))?
2721 }
2722
2723 async fn respond_to_contact_request(
2724 &self,
2725 responder_id: UserId,
2726 requester_id: UserId,
2727 accept: bool,
2728 ) -> Result<()> {
2729 self.background.simulate_random_delay().await;
2730 let mut contacts = self.contacts.lock();
2731 for (ix, contact) in contacts.iter_mut().enumerate() {
2732 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2733 if contact.accepted {
2734 Err(anyhow!("contact already confirmed"))?;
2735 }
2736 if accept {
2737 contact.accepted = true;
2738 contact.should_notify = true;
2739 } else {
2740 contacts.remove(ix);
2741 }
2742 return Ok(());
2743 }
2744 }
2745 Err(anyhow!("no such contact request"))?
2746 }
2747
2748 async fn create_access_token_hash(
2749 &self,
2750 _user_id: UserId,
2751 _access_token_hash: &str,
2752 _max_access_token_count: usize,
2753 ) -> Result<()> {
2754 unimplemented!()
2755 }
2756
2757 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
2758 unimplemented!()
2759 }
2760
2761 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2762 unimplemented!()
2763 }
2764
2765 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2766 self.background.simulate_random_delay().await;
2767 let mut orgs = self.orgs.lock();
2768 if orgs.values().any(|org| org.slug == slug) {
2769 Err(anyhow!("org already exists"))?
2770 } else {
2771 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2772 orgs.insert(
2773 org_id,
2774 Org {
2775 id: org_id,
2776 name: name.to_string(),
2777 slug: slug.to_string(),
2778 },
2779 );
2780 Ok(org_id)
2781 }
2782 }
2783
2784 async fn add_org_member(
2785 &self,
2786 org_id: OrgId,
2787 user_id: UserId,
2788 is_admin: bool,
2789 ) -> Result<()> {
2790 self.background.simulate_random_delay().await;
2791 if !self.orgs.lock().contains_key(&org_id) {
2792 Err(anyhow!("org does not exist"))?;
2793 }
2794 if !self.users.lock().contains_key(&user_id) {
2795 Err(anyhow!("user does not exist"))?;
2796 }
2797
2798 self.org_memberships
2799 .lock()
2800 .entry((org_id, user_id))
2801 .or_insert(is_admin);
2802 Ok(())
2803 }
2804
2805 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2806 self.background.simulate_random_delay().await;
2807 if !self.orgs.lock().contains_key(&org_id) {
2808 Err(anyhow!("org does not exist"))?;
2809 }
2810
2811 let mut channels = self.channels.lock();
2812 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2813 channels.insert(
2814 channel_id,
2815 Channel {
2816 id: channel_id,
2817 name: name.to_string(),
2818 owner_id: org_id.0,
2819 owner_is_user: false,
2820 },
2821 );
2822 Ok(channel_id)
2823 }
2824
2825 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2826 self.background.simulate_random_delay().await;
2827 Ok(self
2828 .channels
2829 .lock()
2830 .values()
2831 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2832 .cloned()
2833 .collect())
2834 }
2835
2836 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2837 self.background.simulate_random_delay().await;
2838 let channels = self.channels.lock();
2839 let memberships = self.channel_memberships.lock();
2840 Ok(channels
2841 .values()
2842 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2843 .cloned()
2844 .collect())
2845 }
2846
2847 async fn can_user_access_channel(
2848 &self,
2849 user_id: UserId,
2850 channel_id: ChannelId,
2851 ) -> Result<bool> {
2852 self.background.simulate_random_delay().await;
2853 Ok(self
2854 .channel_memberships
2855 .lock()
2856 .contains_key(&(channel_id, user_id)))
2857 }
2858
2859 async fn add_channel_member(
2860 &self,
2861 channel_id: ChannelId,
2862 user_id: UserId,
2863 is_admin: bool,
2864 ) -> Result<()> {
2865 self.background.simulate_random_delay().await;
2866 if !self.channels.lock().contains_key(&channel_id) {
2867 Err(anyhow!("channel does not exist"))?;
2868 }
2869 if !self.users.lock().contains_key(&user_id) {
2870 Err(anyhow!("user does not exist"))?;
2871 }
2872
2873 self.channel_memberships
2874 .lock()
2875 .entry((channel_id, user_id))
2876 .or_insert(is_admin);
2877 Ok(())
2878 }
2879
2880 async fn create_channel_message(
2881 &self,
2882 channel_id: ChannelId,
2883 sender_id: UserId,
2884 body: &str,
2885 timestamp: OffsetDateTime,
2886 nonce: u128,
2887 ) -> Result<MessageId> {
2888 self.background.simulate_random_delay().await;
2889 if !self.channels.lock().contains_key(&channel_id) {
2890 Err(anyhow!("channel does not exist"))?;
2891 }
2892 if !self.users.lock().contains_key(&sender_id) {
2893 Err(anyhow!("user does not exist"))?;
2894 }
2895
2896 let mut messages = self.channel_messages.lock();
2897 if let Some(message) = messages
2898 .values()
2899 .find(|message| message.nonce.as_u128() == nonce)
2900 {
2901 Ok(message.id)
2902 } else {
2903 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2904 messages.insert(
2905 message_id,
2906 ChannelMessage {
2907 id: message_id,
2908 channel_id,
2909 sender_id,
2910 body: body.to_string(),
2911 sent_at: timestamp,
2912 nonce: Uuid::from_u128(nonce),
2913 },
2914 );
2915 Ok(message_id)
2916 }
2917 }
2918
2919 async fn get_channel_messages(
2920 &self,
2921 channel_id: ChannelId,
2922 count: usize,
2923 before_id: Option<MessageId>,
2924 ) -> Result<Vec<ChannelMessage>> {
2925 self.background.simulate_random_delay().await;
2926 let mut messages = self
2927 .channel_messages
2928 .lock()
2929 .values()
2930 .rev()
2931 .filter(|message| {
2932 message.channel_id == channel_id
2933 && message.id < before_id.unwrap_or(MessageId::MAX)
2934 })
2935 .take(count)
2936 .cloned()
2937 .collect::<Vec<_>>();
2938 messages.sort_unstable_by_key(|message| message.id);
2939 Ok(messages)
2940 }
2941
2942 async fn teardown(&self, _: &str) {}
2943
2944 #[cfg(test)]
2945 fn as_fake(&self) -> Option<&FakeDb> {
2946 Some(self)
2947 }
2948 }
2949
2950 fn build_background_executor() -> Arc<Background> {
2951 Deterministic::new(0).build_background()
2952 }
2953}