1use std::{cmp, ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use nanoid::nanoid;
10use serde::{Deserialize, Serialize};
11pub use sqlx::postgres::PgPoolOptions as DbOptions;
12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
13use time::{OffsetDateTime, PrimitiveDateTime};
14
15#[async_trait]
16pub trait Db: Send + Sync {
17 async fn create_user(
18 &self,
19 github_login: &str,
20 email_address: Option<&str>,
21 admin: bool,
22 ) -> Result<UserId>;
23 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
24 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
25 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
26 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
27 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
28 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
29 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
30 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
31 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
32 async fn destroy_user(&self, id: UserId) -> Result<()>;
33
34 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
35 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
36 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
37 async fn redeem_invite_code(
38 &self,
39 code: &str,
40 login: &str,
41 email_address: Option<&str>,
42 ) -> Result<UserId>;
43
44 /// Registers a new project for the given user.
45 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
46
47 /// Unregisters a project for the given project id.
48 async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
49
50 /// Update file counts by extension for the given project and worktree.
51 async fn update_worktree_extensions(
52 &self,
53 project_id: ProjectId,
54 worktree_id: u64,
55 extensions: HashMap<String, u32>,
56 ) -> Result<()>;
57
58 /// Get the file counts on the given project keyed by their worktree and extension.
59 async fn get_project_extensions(
60 &self,
61 project_id: ProjectId,
62 ) -> Result<HashMap<u64, HashMap<String, usize>>>;
63
64 /// Record which users have been active in which projects during
65 /// a given period of time.
66 async fn record_user_activity(
67 &self,
68 time_period: Range<OffsetDateTime>,
69 active_projects: &[(UserId, ProjectId)],
70 ) -> Result<()>;
71
72 /// Get the number of users who have been active in the given
73 /// time period for at least the given time duration.
74 async fn get_active_user_count(
75 &self,
76 time_period: Range<OffsetDateTime>,
77 min_duration: Duration,
78 ) -> Result<usize>;
79
80 /// Get the users that have been most active during the given time period,
81 /// along with the amount of time they have been active in each project.
82 async fn get_top_users_activity_summary(
83 &self,
84 time_period: Range<OffsetDateTime>,
85 max_user_count: usize,
86 ) -> Result<Vec<UserActivitySummary>>;
87
88 /// Get the project activity for the given user and time period.
89 async fn get_user_activity_timeline(
90 &self,
91 time_period: Range<OffsetDateTime>,
92 user_id: UserId,
93 ) -> Result<Vec<UserActivityPeriod>>;
94
95 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
96 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
97 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
98 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
99 async fn dismiss_contact_notification(
100 &self,
101 responder_id: UserId,
102 requester_id: UserId,
103 ) -> Result<()>;
104 async fn respond_to_contact_request(
105 &self,
106 responder_id: UserId,
107 requester_id: UserId,
108 accept: bool,
109 ) -> Result<()>;
110
111 async fn create_access_token_hash(
112 &self,
113 user_id: UserId,
114 access_token_hash: &str,
115 max_access_token_count: usize,
116 ) -> Result<()>;
117 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
118 #[cfg(any(test, feature = "seed-support"))]
119
120 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
121 #[cfg(any(test, feature = "seed-support"))]
122 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
123 #[cfg(any(test, feature = "seed-support"))]
124 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
125 #[cfg(any(test, feature = "seed-support"))]
126 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
127 #[cfg(any(test, feature = "seed-support"))]
128
129 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
130 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
131 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
132 -> Result<bool>;
133 #[cfg(any(test, feature = "seed-support"))]
134 async fn add_channel_member(
135 &self,
136 channel_id: ChannelId,
137 user_id: UserId,
138 is_admin: bool,
139 ) -> Result<()>;
140 async fn create_channel_message(
141 &self,
142 channel_id: ChannelId,
143 sender_id: UserId,
144 body: &str,
145 timestamp: OffsetDateTime,
146 nonce: u128,
147 ) -> Result<MessageId>;
148 async fn get_channel_messages(
149 &self,
150 channel_id: ChannelId,
151 count: usize,
152 before_id: Option<MessageId>,
153 ) -> Result<Vec<ChannelMessage>>;
154 #[cfg(test)]
155 async fn teardown(&self, url: &str);
156 #[cfg(test)]
157 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
158}
159
160pub struct PostgresDb {
161 pool: sqlx::PgPool,
162}
163
164impl PostgresDb {
165 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
166 let pool = DbOptions::new()
167 .max_connections(max_connections)
168 .connect(&url)
169 .await
170 .context("failed to connect to postgres database")?;
171 Ok(Self { pool })
172 }
173}
174
175#[async_trait]
176impl Db for PostgresDb {
177 // users
178
179 async fn create_user(
180 &self,
181 github_login: &str,
182 email_address: Option<&str>,
183 admin: bool,
184 ) -> Result<UserId> {
185 let query = "
186 INSERT INTO users (github_login, email_address, admin)
187 VALUES ($1, $2, $3)
188 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
189 RETURNING id
190 ";
191 Ok(sqlx::query_scalar(query)
192 .bind(github_login)
193 .bind(email_address)
194 .bind(admin)
195 .fetch_one(&self.pool)
196 .await
197 .map(UserId)?)
198 }
199
200 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
201 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
202 Ok(sqlx::query_as(query)
203 .bind(limit as i32)
204 .bind((page * limit) as i32)
205 .fetch_all(&self.pool)
206 .await?)
207 }
208
209 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
210 let mut query = QueryBuilder::new(
211 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
212 );
213 query.push_values(
214 users,
215 |mut query, (github_login, email_address, invite_count)| {
216 query
217 .push_bind(github_login)
218 .push_bind(email_address)
219 .push_bind(false)
220 .push_bind(nanoid!(16))
221 .push_bind(invite_count as i32);
222 },
223 );
224 query.push(
225 "
226 ON CONFLICT (github_login) DO UPDATE SET
227 github_login = excluded.github_login,
228 invite_count = excluded.invite_count,
229 invite_code = CASE WHEN users.invite_code IS NULL
230 THEN excluded.invite_code
231 ELSE users.invite_code
232 END
233 RETURNING id
234 ",
235 );
236
237 let rows = query.build().fetch_all(&self.pool).await?;
238 Ok(rows
239 .into_iter()
240 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
241 .collect())
242 }
243
244 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
245 let like_string = fuzzy_like_string(name_query);
246 let query = "
247 SELECT users.*
248 FROM users
249 WHERE github_login ILIKE $1
250 ORDER BY github_login <-> $2
251 LIMIT $3
252 ";
253 Ok(sqlx::query_as(query)
254 .bind(like_string)
255 .bind(name_query)
256 .bind(limit as i32)
257 .fetch_all(&self.pool)
258 .await?)
259 }
260
261 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
262 let users = self.get_users_by_ids(vec![id]).await?;
263 Ok(users.into_iter().next())
264 }
265
266 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
267 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
268 let query = "
269 SELECT users.*
270 FROM users
271 WHERE users.id = ANY ($1)
272 ";
273 Ok(sqlx::query_as(query)
274 .bind(&ids)
275 .fetch_all(&self.pool)
276 .await?)
277 }
278
279 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
280 let query = format!(
281 "
282 SELECT users.*
283 FROM users
284 WHERE invite_count = 0
285 AND inviter_id IS{} NULL
286 ",
287 if invited_by_another_user { " NOT" } else { "" }
288 );
289
290 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
291 }
292
293 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
294 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
295 Ok(sqlx::query_as(query)
296 .bind(github_login)
297 .fetch_optional(&self.pool)
298 .await?)
299 }
300
301 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
302 let query = "UPDATE users SET admin = $1 WHERE id = $2";
303 Ok(sqlx::query(query)
304 .bind(is_admin)
305 .bind(id.0)
306 .execute(&self.pool)
307 .await
308 .map(drop)?)
309 }
310
311 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
312 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
313 Ok(sqlx::query(query)
314 .bind(connected_once)
315 .bind(id.0)
316 .execute(&self.pool)
317 .await
318 .map(drop)?)
319 }
320
321 async fn destroy_user(&self, id: UserId) -> Result<()> {
322 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
323 sqlx::query(query)
324 .bind(id.0)
325 .execute(&self.pool)
326 .await
327 .map(drop)?;
328 let query = "DELETE FROM users WHERE id = $1;";
329 Ok(sqlx::query(query)
330 .bind(id.0)
331 .execute(&self.pool)
332 .await
333 .map(drop)?)
334 }
335
336 // invite codes
337
338 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
339 let mut tx = self.pool.begin().await?;
340 if count > 0 {
341 sqlx::query(
342 "
343 UPDATE users
344 SET invite_code = $1
345 WHERE id = $2 AND invite_code IS NULL
346 ",
347 )
348 .bind(nanoid!(16))
349 .bind(id)
350 .execute(&mut tx)
351 .await?;
352 }
353
354 sqlx::query(
355 "
356 UPDATE users
357 SET invite_count = $1
358 WHERE id = $2
359 ",
360 )
361 .bind(count as i32)
362 .bind(id)
363 .execute(&mut tx)
364 .await?;
365 tx.commit().await?;
366 Ok(())
367 }
368
369 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
370 let result: Option<(String, i32)> = sqlx::query_as(
371 "
372 SELECT invite_code, invite_count
373 FROM users
374 WHERE id = $1 AND invite_code IS NOT NULL
375 ",
376 )
377 .bind(id)
378 .fetch_optional(&self.pool)
379 .await?;
380 if let Some((code, count)) = result {
381 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
382 } else {
383 Ok(None)
384 }
385 }
386
387 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
388 sqlx::query_as(
389 "
390 SELECT *
391 FROM users
392 WHERE invite_code = $1
393 ",
394 )
395 .bind(code)
396 .fetch_optional(&self.pool)
397 .await?
398 .ok_or_else(|| {
399 Error::Http(
400 StatusCode::NOT_FOUND,
401 "that invite code does not exist".to_string(),
402 )
403 })
404 }
405
406 async fn redeem_invite_code(
407 &self,
408 code: &str,
409 login: &str,
410 email_address: Option<&str>,
411 ) -> Result<UserId> {
412 let mut tx = self.pool.begin().await?;
413
414 let inviter_id: Option<UserId> = sqlx::query_scalar(
415 "
416 UPDATE users
417 SET invite_count = invite_count - 1
418 WHERE
419 invite_code = $1 AND
420 invite_count > 0
421 RETURNING id
422 ",
423 )
424 .bind(code)
425 .fetch_optional(&mut tx)
426 .await?;
427
428 let inviter_id = match inviter_id {
429 Some(inviter_id) => inviter_id,
430 None => {
431 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
432 .bind(code)
433 .fetch_optional(&mut tx)
434 .await?
435 .is_some()
436 {
437 Err(Error::Http(
438 StatusCode::UNAUTHORIZED,
439 "no invites remaining".to_string(),
440 ))?
441 } else {
442 Err(Error::Http(
443 StatusCode::NOT_FOUND,
444 "invite code not found".to_string(),
445 ))?
446 }
447 }
448 };
449
450 let invitee_id = sqlx::query_scalar(
451 "
452 INSERT INTO users
453 (github_login, email_address, admin, inviter_id)
454 VALUES
455 ($1, $2, 'f', $3)
456 RETURNING id
457 ",
458 )
459 .bind(login)
460 .bind(email_address)
461 .bind(inviter_id)
462 .fetch_one(&mut tx)
463 .await
464 .map(UserId)?;
465
466 sqlx::query(
467 "
468 INSERT INTO contacts
469 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
470 VALUES
471 ($1, $2, 't', 't', 't')
472 ",
473 )
474 .bind(inviter_id)
475 .bind(invitee_id)
476 .execute(&mut tx)
477 .await?;
478
479 tx.commit().await?;
480 Ok(invitee_id)
481 }
482
483 // projects
484
485 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
486 Ok(sqlx::query_scalar(
487 "
488 INSERT INTO projects(host_user_id)
489 VALUES ($1)
490 RETURNING id
491 ",
492 )
493 .bind(host_user_id)
494 .fetch_one(&self.pool)
495 .await
496 .map(ProjectId)?)
497 }
498
499 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
500 sqlx::query(
501 "
502 UPDATE projects
503 SET unregistered = 't'
504 WHERE id = $1
505 ",
506 )
507 .bind(project_id)
508 .execute(&self.pool)
509 .await?;
510 Ok(())
511 }
512
513 async fn update_worktree_extensions(
514 &self,
515 project_id: ProjectId,
516 worktree_id: u64,
517 extensions: HashMap<String, u32>,
518 ) -> Result<()> {
519 if extensions.is_empty() {
520 return Ok(());
521 }
522
523 let mut query = QueryBuilder::new(
524 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
525 );
526 query.push_values(extensions, |mut query, (extension, count)| {
527 query
528 .push_bind(project_id)
529 .push_bind(worktree_id as i32)
530 .push_bind(extension)
531 .push_bind(count as i32);
532 });
533 query.push(
534 "
535 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
536 count = excluded.count
537 ",
538 );
539 query.build().execute(&self.pool).await?;
540
541 Ok(())
542 }
543
544 async fn get_project_extensions(
545 &self,
546 project_id: ProjectId,
547 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
548 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
549 struct WorktreeExtension {
550 worktree_id: i32,
551 extension: String,
552 count: i32,
553 }
554
555 let query = "
556 SELECT worktree_id, extension, count
557 FROM worktree_extensions
558 WHERE project_id = $1
559 ";
560 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
561 .bind(&project_id)
562 .fetch_all(&self.pool)
563 .await?;
564
565 let mut extension_counts = HashMap::default();
566 for count in counts {
567 extension_counts
568 .entry(count.worktree_id as u64)
569 .or_insert(HashMap::default())
570 .insert(count.extension, count.count as usize);
571 }
572 Ok(extension_counts)
573 }
574
575 async fn record_user_activity(
576 &self,
577 time_period: Range<OffsetDateTime>,
578 projects: &[(UserId, ProjectId)],
579 ) -> Result<()> {
580 let query = "
581 INSERT INTO project_activity_periods
582 (ended_at, duration_millis, user_id, project_id)
583 VALUES
584 ($1, $2, $3, $4);
585 ";
586
587 let mut tx = self.pool.begin().await?;
588 let duration_millis =
589 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
590 for (user_id, project_id) in projects {
591 sqlx::query(query)
592 .bind(time_period.end)
593 .bind(duration_millis)
594 .bind(user_id)
595 .bind(project_id)
596 .execute(&mut tx)
597 .await?;
598 }
599 tx.commit().await?;
600
601 Ok(())
602 }
603
604 async fn get_active_user_count(
605 &self,
606 time_period: Range<OffsetDateTime>,
607 min_duration: Duration,
608 ) -> Result<usize> {
609 let query = "
610 WITH
611 project_durations AS (
612 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
613 FROM project_activity_periods
614 WHERE $1 < ended_at AND ended_at <= $2
615 GROUP BY user_id, project_id
616 ),
617 user_durations AS (
618 SELECT user_id, SUM(project_duration) as total_duration
619 FROM project_durations
620 GROUP BY user_id
621 ORDER BY total_duration DESC
622 LIMIT $3
623 )
624 SELECT count(user_durations.user_id)
625 FROM user_durations
626 WHERE user_durations.total_duration >= $3
627 ";
628
629 let count: i64 = sqlx::query_scalar(query)
630 .bind(time_period.start)
631 .bind(time_period.end)
632 .bind(min_duration.as_millis() as i64)
633 .fetch_one(&self.pool)
634 .await?;
635 Ok(count as usize)
636 }
637
638 async fn get_top_users_activity_summary(
639 &self,
640 time_period: Range<OffsetDateTime>,
641 max_user_count: usize,
642 ) -> Result<Vec<UserActivitySummary>> {
643 let query = "
644 WITH
645 project_durations AS (
646 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
647 FROM project_activity_periods
648 WHERE $1 < ended_at AND ended_at <= $2
649 GROUP BY user_id, project_id
650 ),
651 user_durations AS (
652 SELECT user_id, SUM(project_duration) as total_duration
653 FROM project_durations
654 GROUP BY user_id
655 ORDER BY total_duration DESC
656 LIMIT $3
657 )
658 SELECT user_durations.user_id, users.github_login, project_id, project_duration
659 FROM user_durations, project_durations, users
660 WHERE
661 user_durations.user_id = project_durations.user_id AND
662 user_durations.user_id = users.id
663 ORDER BY total_duration DESC, user_id ASC
664 ";
665
666 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64)>(query)
667 .bind(time_period.start)
668 .bind(time_period.end)
669 .bind(max_user_count as i32)
670 .fetch(&self.pool);
671
672 let mut result = Vec::<UserActivitySummary>::new();
673 while let Some(row) = rows.next().await {
674 let (user_id, github_login, project_id, duration_millis) = row?;
675 let project_id = project_id;
676 let duration = Duration::from_millis(duration_millis as u64);
677 if let Some(last_summary) = result.last_mut() {
678 if last_summary.id == user_id {
679 last_summary.project_activity.push((project_id, duration));
680 continue;
681 }
682 }
683 result.push(UserActivitySummary {
684 id: user_id,
685 project_activity: vec![(project_id, duration)],
686 github_login,
687 });
688 }
689
690 Ok(result)
691 }
692
693 async fn get_user_activity_timeline(
694 &self,
695 time_period: Range<OffsetDateTime>,
696 user_id: UserId,
697 ) -> Result<Vec<UserActivityPeriod>> {
698 const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
699
700 let query = "
701 SELECT
702 project_activity_periods.ended_at,
703 project_activity_periods.duration_millis,
704 project_activity_periods.project_id,
705 worktree_extensions.extension,
706 worktree_extensions.count
707 FROM project_activity_periods
708 LEFT OUTER JOIN
709 worktree_extensions
710 ON
711 project_activity_periods.project_id = worktree_extensions.project_id
712 WHERE
713 project_activity_periods.user_id = $1 AND
714 $2 < project_activity_periods.ended_at AND
715 project_activity_periods.ended_at <= $3
716 ORDER BY project_activity_periods.id ASC
717 ";
718
719 let mut rows = sqlx::query_as::<
720 _,
721 (
722 PrimitiveDateTime,
723 i32,
724 ProjectId,
725 Option<String>,
726 Option<i32>,
727 ),
728 >(query)
729 .bind(user_id)
730 .bind(time_period.start)
731 .bind(time_period.end)
732 .fetch(&self.pool);
733
734 let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
735 while let Some(row) = rows.next().await {
736 let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
737 let ended_at = ended_at.assume_utc();
738 let duration = Duration::from_millis(duration_millis as u64);
739 let started_at = ended_at - duration;
740 let project_time_periods = time_periods.entry(project_id).or_default();
741
742 if let Some(prev_duration) = project_time_periods.last_mut() {
743 if started_at <= prev_duration.end + COALESCE_THRESHOLD
744 && ended_at >= prev_duration.start
745 {
746 prev_duration.end = cmp::max(prev_duration.end, ended_at);
747 } else {
748 project_time_periods.push(UserActivityPeriod {
749 project_id,
750 start: started_at,
751 end: ended_at,
752 extensions: Default::default(),
753 });
754 }
755 } else {
756 project_time_periods.push(UserActivityPeriod {
757 project_id,
758 start: started_at,
759 end: ended_at,
760 extensions: Default::default(),
761 });
762 }
763
764 if let Some((extension, extension_count)) = extension.zip(extension_count) {
765 project_time_periods
766 .last_mut()
767 .unwrap()
768 .extensions
769 .insert(extension, extension_count as usize);
770 }
771 }
772
773 let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
774 durations.sort_unstable_by_key(|duration| duration.start);
775 Ok(durations)
776 }
777
778 // contacts
779
780 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
781 let query = "
782 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
783 FROM contacts
784 WHERE user_id_a = $1 OR user_id_b = $1;
785 ";
786
787 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
788 .bind(user_id)
789 .fetch(&self.pool);
790
791 let mut contacts = vec![Contact::Accepted {
792 user_id,
793 should_notify: false,
794 }];
795 while let Some(row) = rows.next().await {
796 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
797
798 if user_id_a == user_id {
799 if accepted {
800 contacts.push(Contact::Accepted {
801 user_id: user_id_b,
802 should_notify: should_notify && a_to_b,
803 });
804 } else if a_to_b {
805 contacts.push(Contact::Outgoing { user_id: user_id_b })
806 } else {
807 contacts.push(Contact::Incoming {
808 user_id: user_id_b,
809 should_notify,
810 });
811 }
812 } else {
813 if accepted {
814 contacts.push(Contact::Accepted {
815 user_id: user_id_a,
816 should_notify: should_notify && !a_to_b,
817 });
818 } else if a_to_b {
819 contacts.push(Contact::Incoming {
820 user_id: user_id_a,
821 should_notify,
822 });
823 } else {
824 contacts.push(Contact::Outgoing { user_id: user_id_a });
825 }
826 }
827 }
828
829 contacts.sort_unstable_by_key(|contact| contact.user_id());
830
831 Ok(contacts)
832 }
833
834 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
835 let (id_a, id_b) = if user_id_1 < user_id_2 {
836 (user_id_1, user_id_2)
837 } else {
838 (user_id_2, user_id_1)
839 };
840
841 let query = "
842 SELECT 1 FROM contacts
843 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
844 LIMIT 1
845 ";
846 Ok(sqlx::query_scalar::<_, i32>(query)
847 .bind(id_a.0)
848 .bind(id_b.0)
849 .fetch_optional(&self.pool)
850 .await?
851 .is_some())
852 }
853
854 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
855 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
856 (sender_id, receiver_id, true)
857 } else {
858 (receiver_id, sender_id, false)
859 };
860 let query = "
861 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
862 VALUES ($1, $2, $3, 'f', 't')
863 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
864 SET
865 accepted = 't',
866 should_notify = 'f'
867 WHERE
868 NOT contacts.accepted AND
869 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
870 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
871 ";
872 let result = sqlx::query(query)
873 .bind(id_a.0)
874 .bind(id_b.0)
875 .bind(a_to_b)
876 .execute(&self.pool)
877 .await?;
878
879 if result.rows_affected() == 1 {
880 Ok(())
881 } else {
882 Err(anyhow!("contact already requested"))?
883 }
884 }
885
886 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
887 let (id_a, id_b) = if responder_id < requester_id {
888 (responder_id, requester_id)
889 } else {
890 (requester_id, responder_id)
891 };
892 let query = "
893 DELETE FROM contacts
894 WHERE user_id_a = $1 AND user_id_b = $2;
895 ";
896 let result = sqlx::query(query)
897 .bind(id_a.0)
898 .bind(id_b.0)
899 .execute(&self.pool)
900 .await?;
901
902 if result.rows_affected() == 1 {
903 Ok(())
904 } else {
905 Err(anyhow!("no such contact"))?
906 }
907 }
908
909 async fn dismiss_contact_notification(
910 &self,
911 user_id: UserId,
912 contact_user_id: UserId,
913 ) -> Result<()> {
914 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
915 (user_id, contact_user_id, true)
916 } else {
917 (contact_user_id, user_id, false)
918 };
919
920 let query = "
921 UPDATE contacts
922 SET should_notify = 'f'
923 WHERE
924 user_id_a = $1 AND user_id_b = $2 AND
925 (
926 (a_to_b = $3 AND accepted) OR
927 (a_to_b != $3 AND NOT accepted)
928 );
929 ";
930
931 let result = sqlx::query(query)
932 .bind(id_a.0)
933 .bind(id_b.0)
934 .bind(a_to_b)
935 .execute(&self.pool)
936 .await?;
937
938 if result.rows_affected() == 0 {
939 Err(anyhow!("no such contact request"))?;
940 }
941
942 Ok(())
943 }
944
945 async fn respond_to_contact_request(
946 &self,
947 responder_id: UserId,
948 requester_id: UserId,
949 accept: bool,
950 ) -> Result<()> {
951 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
952 (responder_id, requester_id, false)
953 } else {
954 (requester_id, responder_id, true)
955 };
956 let result = if accept {
957 let query = "
958 UPDATE contacts
959 SET accepted = 't', should_notify = 't'
960 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
961 ";
962 sqlx::query(query)
963 .bind(id_a.0)
964 .bind(id_b.0)
965 .bind(a_to_b)
966 .execute(&self.pool)
967 .await?
968 } else {
969 let query = "
970 DELETE FROM contacts
971 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
972 ";
973 sqlx::query(query)
974 .bind(id_a.0)
975 .bind(id_b.0)
976 .bind(a_to_b)
977 .execute(&self.pool)
978 .await?
979 };
980 if result.rows_affected() == 1 {
981 Ok(())
982 } else {
983 Err(anyhow!("no such contact request"))?
984 }
985 }
986
987 // access tokens
988
989 async fn create_access_token_hash(
990 &self,
991 user_id: UserId,
992 access_token_hash: &str,
993 max_access_token_count: usize,
994 ) -> Result<()> {
995 let insert_query = "
996 INSERT INTO access_tokens (user_id, hash)
997 VALUES ($1, $2);
998 ";
999 let cleanup_query = "
1000 DELETE FROM access_tokens
1001 WHERE id IN (
1002 SELECT id from access_tokens
1003 WHERE user_id = $1
1004 ORDER BY id DESC
1005 OFFSET $3
1006 )
1007 ";
1008
1009 let mut tx = self.pool.begin().await?;
1010 sqlx::query(insert_query)
1011 .bind(user_id.0)
1012 .bind(access_token_hash)
1013 .execute(&mut tx)
1014 .await?;
1015 sqlx::query(cleanup_query)
1016 .bind(user_id.0)
1017 .bind(access_token_hash)
1018 .bind(max_access_token_count as i32)
1019 .execute(&mut tx)
1020 .await?;
1021 Ok(tx.commit().await?)
1022 }
1023
1024 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1025 let query = "
1026 SELECT hash
1027 FROM access_tokens
1028 WHERE user_id = $1
1029 ORDER BY id DESC
1030 ";
1031 Ok(sqlx::query_scalar(query)
1032 .bind(user_id.0)
1033 .fetch_all(&self.pool)
1034 .await?)
1035 }
1036
1037 // orgs
1038
/// Looks up the org with the given slug, if one exists.
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
    let org = sqlx::query_as(
        "
        SELECT *
        FROM orgs
        WHERE slug = $1
        ",
    )
    .bind(slug)
    .fetch_optional(&self.pool)
    .await?;
    Ok(org)
}
1052
/// Inserts an org with the given name and slug and returns its id.
#[cfg(any(test, feature = "seed-support"))]
async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
    let org_id = sqlx::query_scalar(
        "
        INSERT INTO orgs (name, slug)
        VALUES ($1, $2)
        RETURNING id
        ",
    )
    .bind(name)
    .bind(slug)
    .fetch_one(&self.pool)
    .await
    .map(OrgId)?;
    Ok(org_id)
}
1067
/// Adds the user to the org; an already existing membership is left untouched.
#[cfg(any(test, feature = "seed-support"))]
async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
    sqlx::query(
        "
        INSERT INTO org_memberships (org_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
        ",
    )
    .bind(org_id.0)
    .bind(user_id.0)
    .bind(is_admin)
    .execute(&self.pool)
    .await?;
    Ok(())
}
1083
1084 // channels
1085
/// Inserts a channel owned by the given org (not by a user) and returns its id.
#[cfg(any(test, feature = "seed-support"))]
async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
    let channel_id = sqlx::query_scalar(
        "
        INSERT INTO channels (owner_id, owner_is_user, name)
        VALUES ($1, false, $2)
        RETURNING id
        ",
    )
    .bind(org_id.0)
    .bind(name)
    .fetch_one(&self.pool)
    .await
    .map(ChannelId)?;
    Ok(channel_id)
}
1100
/// Returns the channels whose owner is the given org.
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
    let channels = sqlx::query_as(
        "
        SELECT *
        FROM channels
        WHERE
            channels.owner_is_user = false AND
            channels.owner_id = $1
        ",
    )
    .bind(org_id.0)
    .fetch_all(&self.pool)
    .await?;
    Ok(channels)
}
1116
1117 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1118 let query = "
1119 SELECT
1120 channels.*
1121 FROM
1122 channel_memberships, channels
1123 WHERE
1124 channel_memberships.user_id = $1 AND
1125 channel_memberships.channel_id = channels.id
1126 ";
1127 Ok(sqlx::query_as(query)
1128 .bind(user_id.0)
1129 .fetch_all(&self.pool)
1130 .await?)
1131 }
1132
1133 async fn can_user_access_channel(
1134 &self,
1135 user_id: UserId,
1136 channel_id: ChannelId,
1137 ) -> Result<bool> {
1138 let query = "
1139 SELECT id
1140 FROM channel_memberships
1141 WHERE user_id = $1 AND channel_id = $2
1142 LIMIT 1
1143 ";
1144 Ok(sqlx::query_scalar::<_, i32>(query)
1145 .bind(user_id.0)
1146 .bind(channel_id.0)
1147 .fetch_optional(&self.pool)
1148 .await
1149 .map(|e| e.is_some())?)
1150 }
1151
/// Adds the user to the channel; an already existing membership is left
/// untouched.
#[cfg(any(test, feature = "seed-support"))]
async fn add_channel_member(
    &self,
    channel_id: ChannelId,
    user_id: UserId,
    is_admin: bool,
) -> Result<()> {
    sqlx::query(
        "
        INSERT INTO channel_memberships (channel_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
        ",
    )
    .bind(channel_id.0)
    .bind(user_id.0)
    .bind(is_admin)
    .execute(&self.pool)
    .await?;
    Ok(())
}
1172
1173 // messages
1174
1175 async fn create_channel_message(
1176 &self,
1177 channel_id: ChannelId,
1178 sender_id: UserId,
1179 body: &str,
1180 timestamp: OffsetDateTime,
1181 nonce: u128,
1182 ) -> Result<MessageId> {
1183 let query = "
1184 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1185 VALUES ($1, $2, $3, $4, $5)
1186 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1187 RETURNING id
1188 ";
1189 Ok(sqlx::query_scalar(query)
1190 .bind(channel_id.0)
1191 .bind(sender_id.0)
1192 .bind(body)
1193 .bind(timestamp)
1194 .bind(Uuid::from_u128(nonce))
1195 .fetch_one(&self.pool)
1196 .await
1197 .map(MessageId)?)
1198 }
1199
1200 async fn get_channel_messages(
1201 &self,
1202 channel_id: ChannelId,
1203 count: usize,
1204 before_id: Option<MessageId>,
1205 ) -> Result<Vec<ChannelMessage>> {
1206 let query = r#"
1207 SELECT * FROM (
1208 SELECT
1209 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1210 FROM
1211 channel_messages
1212 WHERE
1213 channel_id = $1 AND
1214 id < $2
1215 ORDER BY id DESC
1216 LIMIT $3
1217 ) as recent_messages
1218 ORDER BY id ASC
1219 "#;
1220 Ok(sqlx::query_as(query)
1221 .bind(channel_id.0)
1222 .bind(before_id.unwrap_or(MessageId::MAX))
1223 .bind(count as i64)
1224 .fetch_all(&self.pool)
1225 .await?)
1226 }
1227
/// Test-only cleanup: disconnects other sessions, closes the pool, and drops the
/// database located at `url`.
///
/// Failures of the terminate query and of the drop are passed to `log_err` rather
/// than returned.
#[cfg(test)]
async fn teardown(&self, url: &str) {
    use util::ResultExt;

    // Terminate every other backend connected to this database.
    let terminate_other_backends = "
        SELECT pg_terminate_backend(pg_stat_activity.pid)
        FROM pg_stat_activity
        WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
    ";
    let terminated = sqlx::query(terminate_other_backends)
        .execute(&self.pool)
        .await;
    terminated.log_err();

    self.pool.close().await;

    let dropped = <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url).await;
    dropped.log_err();
}
1243
/// Test-only downcast hook to the in-memory `FakeDb`.
///
/// This implementation always returns `None`.
#[cfg(test)]
fn as_fake(&self) -> Option<&tests::FakeDb> {
    None
}
1248}
1249
/// Declares a strongly-typed wrapper around an `i32` identifier.
///
/// The generated tuple struct is `Copy`, ordered, hashable, and marked transparent
/// for both sqlx and serde. It also receives a `MAX` constant, `from_proto` /
/// `to_proto` conversions to and from `u64`, and a `Display` impl that prints the
/// inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from a `u64`.
            ///
            /// Uses an `as` cast, so values that do not fit in an `i32` are
            /// truncated to their low 32 bits rather than rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Converts the id to a `u64`.
            ///
            /// Uses an `as` cast, so a negative inner value is sign-extended.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            /// Formats the id exactly like its inner `i32`.
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1292
// Typed identifier for users.
id_type!(UserId);
/// A user record.
///
/// Derives `FromRow`, so it can be loaded directly from a sqlx query row, and
/// `Serialize` for output.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    /// Identifier of the user.
    pub id: UserId,
    /// The user's GitHub login.
    pub github_login: String,
    /// Email address, if one was provided.
    pub email_address: Option<String>,
    /// Whether the user is an administrator.
    pub admin: bool,
    /// The user's invite code, if one has been assigned.
    pub invite_code: Option<String>,
    /// Invite count associated with the user's invite code.
    pub invite_count: i32,
    /// Flag set through `set_user_connected_once`.
    pub connected_once: bool,
}
1304
// Typed identifier for projects.
id_type!(ProjectId);
/// A project record, loadable from a sqlx row via the derived `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    /// Identifier of the project.
    pub id: ProjectId,
    /// The user hosting the project.
    pub host_user_id: UserId,
    /// Whether the project has been unregistered.
    pub unregistered: bool,
}
1312
/// Per-user activity totals, broken down by project.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    /// The user the summary belongs to.
    pub id: UserId,
    /// The user's GitHub login.
    pub github_login: String,
    /// Pairs of project and the duration of activity attributed to that project.
    pub project_activity: Vec<(ProjectId, Duration)>,
}
1319
/// One period of a user's activity on a single project.
///
/// The timestamps are serialized in ISO 8601 form (see the `serde(with = ...)`
/// attributes).
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    /// The project the period refers to.
    project_id: ProjectId,
    /// Start of the period.
    #[serde(with = "time::serde::iso8601")]
    start: OffsetDateTime,
    /// End of the period.
    #[serde(with = "time::serde::iso8601")]
    end: OffsetDateTime,
    /// Counts keyed by file extension for the project.
    extensions: HashMap<String, usize>,
}
1329
// Typed identifier for organizations.
id_type!(OrgId);
/// An organization record, loadable from a sqlx row via the derived `FromRow`.
#[derive(FromRow)]
pub struct Org {
    /// Identifier of the organization.
    pub id: OrgId,
    /// Name of the organization.
    pub name: String,
    /// Slug of the organization.
    pub slug: String,
}
1337
// Typed identifier for channels.
id_type!(ChannelId);
/// A channel record, loadable from a sqlx row via the derived `FromRow`.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    /// Identifier of the channel.
    pub id: ChannelId,
    /// Name of the channel.
    pub name: String,
    /// Raw id of the owner; see `owner_is_user` for which kind of entity it names.
    pub owner_id: i32,
    /// Whether `owner_id` refers to a user.
    // NOTE(review): presumably `false` means the owner is an org — confirm against the schema.
    pub owner_is_user: bool,
}
1346
// Typed identifier for channel messages.
id_type!(MessageId);
/// A message posted to a channel, loadable from a sqlx row via the derived `FromRow`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    /// Identifier of the message.
    pub id: MessageId,
    /// The channel the message was posted to.
    pub channel_id: ChannelId,
    /// The user who sent the message.
    pub sender_id: UserId,
    /// Text of the message.
    pub body: String,
    /// When the message was sent.
    pub sent_at: OffsetDateTime,
    /// Nonce stored with the message.
    pub nonce: Uuid,
}
1357
/// The state of a contact relationship, as seen by one user.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The contact has been accepted.
    Accepted {
        /// The other user in the relationship.
        user_id: UserId,
        /// Whether a notification is still pending for this contact.
        should_notify: bool,
    },
    /// A request sent to `user_id` that has not been answered yet.
    Outgoing {
        /// The user the request was sent to.
        user_id: UserId,
    },
    /// A request received from `user_id` that has not been answered yet.
    Incoming {
        /// The user who sent the request.
        user_id: UserId,
        /// Whether a notification is still pending for this request.
        should_notify: bool,
    },
}
1372
1373impl Contact {
1374 pub fn user_id(&self) -> UserId {
1375 match self {
1376 Contact::Accepted { user_id, .. } => *user_id,
1377 Contact::Outgoing { user_id } => *user_id,
1378 Contact::Incoming { user_id, .. } => *user_id,
1379 }
1380 }
1381}
1382
/// A contact request received from another user.
///
/// The derived ordering compares `requester_id` first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// The user who sent the request.
    pub requester_id: UserId,
    /// Whether a notification is still pending for the request.
    pub should_notify: bool,
}
1388
/// Builds a SQL `LIKE` pattern that fuzzy-matches `string`.
///
/// Every alphanumeric character of the input is kept and preceded by `%`; all other
/// characters are dropped. A final `%` is appended, so `"x y"` becomes `"%x%y%"`
/// and an empty input becomes `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    // Worst case is one `%` per input byte plus the trailing `%`.
    let buffer = String::with_capacity(string.len() * 2 + 1);
    let mut pattern = string
        .chars()
        .filter(|ch| ch.is_alphanumeric())
        .fold(buffer, |mut acc, ch| {
            acc.push('%');
            acc.push(ch);
            acc
        });
    pattern.push('%');
    pattern
}
1400
1401#[cfg(test)]
1402pub mod tests {
1403 use super::*;
1404 use anyhow::anyhow;
1405 use collections::BTreeMap;
1406 use gpui::executor::{Background, Deterministic};
1407 use lazy_static::lazy_static;
1408 use parking_lot::Mutex;
1409 use rand::prelude::*;
1410 use sqlx::{
1411 migrate::{MigrateDatabase, Migrator},
1412 Postgres,
1413 };
1414 use std::{path::Path, sync::Arc};
1415 use util::post_inc;
1416
1417 #[tokio::test(flavor = "multi_thread")]
1418 async fn test_get_users_by_ids() {
1419 for test_db in [
1420 TestDb::postgres().await,
1421 TestDb::fake(build_background_executor()),
1422 ] {
1423 let db = test_db.db();
1424
1425 let user = db.create_user("user", None, false).await.unwrap();
1426 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1427 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1428 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1429
1430 assert_eq!(
1431 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1432 .await
1433 .unwrap(),
1434 vec![
1435 User {
1436 id: user,
1437 github_login: "user".to_string(),
1438 admin: false,
1439 ..Default::default()
1440 },
1441 User {
1442 id: friend1,
1443 github_login: "friend-1".to_string(),
1444 admin: false,
1445 ..Default::default()
1446 },
1447 User {
1448 id: friend2,
1449 github_login: "friend-2".to_string(),
1450 admin: false,
1451 ..Default::default()
1452 },
1453 User {
1454 id: friend3,
1455 github_login: "friend-3".to_string(),
1456 admin: false,
1457 ..Default::default()
1458 }
1459 ]
1460 );
1461 }
1462 }
1463
1464 #[tokio::test(flavor = "multi_thread")]
1465 async fn test_create_users() {
1466 let db = TestDb::postgres().await;
1467 let db = db.db();
1468
1469 // Create the first batch of users, ensuring invite counts are assigned
1470 // correctly and the respective invite codes are unique.
1471 let user_ids_batch_1 = db
1472 .create_users(vec![
1473 ("user1".to_string(), "hi@user1.com".to_string(), 5),
1474 ("user2".to_string(), "hi@user2.com".to_string(), 4),
1475 ("user3".to_string(), "hi@user3.com".to_string(), 3),
1476 ])
1477 .await
1478 .unwrap();
1479 assert_eq!(user_ids_batch_1.len(), 3);
1480
1481 let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1482 assert_eq!(users.len(), 3);
1483 assert_eq!(users[0].github_login, "user1");
1484 assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1485 assert_eq!(users[0].invite_count, 5);
1486 assert_eq!(users[1].github_login, "user2");
1487 assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1488 assert_eq!(users[1].invite_count, 4);
1489 assert_eq!(users[2].github_login, "user3");
1490 assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1491 assert_eq!(users[2].invite_count, 3);
1492
1493 let invite_code_1 = users[0].invite_code.clone().unwrap();
1494 let invite_code_2 = users[1].invite_code.clone().unwrap();
1495 let invite_code_3 = users[2].invite_code.clone().unwrap();
1496 assert_ne!(invite_code_1, invite_code_2);
1497 assert_ne!(invite_code_1, invite_code_3);
1498 assert_ne!(invite_code_2, invite_code_3);
1499
1500 // Create the second batch of users and include a user that is already in the database, ensuring
1501 // the invite count for the existing user is updated without changing their invite code.
1502 let user_ids_batch_2 = db
1503 .create_users(vec![
1504 ("user2".to_string(), "hi@user2.com".to_string(), 10),
1505 ("user4".to_string(), "hi@user4.com".to_string(), 2),
1506 ])
1507 .await
1508 .unwrap();
1509 assert_eq!(user_ids_batch_2.len(), 2);
1510 assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1511
1512 let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1513 assert_eq!(users.len(), 2);
1514 assert_eq!(users[0].github_login, "user2");
1515 assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1516 assert_eq!(users[0].invite_count, 10);
1517 assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1518 assert_eq!(users[1].github_login, "user4");
1519 assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1520 assert_eq!(users[1].invite_count, 2);
1521
1522 let invite_code_4 = users[1].invite_code.clone().unwrap();
1523 assert_ne!(invite_code_4, invite_code_1);
1524 assert_ne!(invite_code_4, invite_code_2);
1525 assert_ne!(invite_code_4, invite_code_3);
1526 }
1527
1528 #[tokio::test(flavor = "multi_thread")]
1529 async fn test_worktree_extensions() {
1530 let test_db = TestDb::postgres().await;
1531 let db = test_db.db();
1532
1533 let user = db.create_user("user_1", None, false).await.unwrap();
1534 let project = db.register_project(user).await.unwrap();
1535
1536 db.update_worktree_extensions(project, 100, Default::default())
1537 .await
1538 .unwrap();
1539 db.update_worktree_extensions(
1540 project,
1541 100,
1542 [("rs".to_string(), 5), ("md".to_string(), 3)]
1543 .into_iter()
1544 .collect(),
1545 )
1546 .await
1547 .unwrap();
1548 db.update_worktree_extensions(
1549 project,
1550 100,
1551 [("rs".to_string(), 6), ("md".to_string(), 5)]
1552 .into_iter()
1553 .collect(),
1554 )
1555 .await
1556 .unwrap();
1557 db.update_worktree_extensions(
1558 project,
1559 101,
1560 [("ts".to_string(), 2), ("md".to_string(), 1)]
1561 .into_iter()
1562 .collect(),
1563 )
1564 .await
1565 .unwrap();
1566
1567 assert_eq!(
1568 db.get_project_extensions(project).await.unwrap(),
1569 [
1570 (
1571 100,
1572 [("rs".into(), 6), ("md".into(), 5),]
1573 .into_iter()
1574 .collect::<HashMap<_, _>>()
1575 ),
1576 (
1577 101,
1578 [("ts".into(), 2), ("md".into(), 1),]
1579 .into_iter()
1580 .collect::<HashMap<_, _>>()
1581 )
1582 ]
1583 .into_iter()
1584 .collect()
1585 );
1586 }
1587
1588 #[tokio::test(flavor = "multi_thread")]
1589 async fn test_user_activity() {
1590 let test_db = TestDb::postgres().await;
1591 let db = test_db.db();
1592
1593 let user_1 = db.create_user("user_1", None, false).await.unwrap();
1594 let user_2 = db.create_user("user_2", None, false).await.unwrap();
1595 let user_3 = db.create_user("user_3", None, false).await.unwrap();
1596 let project_1 = db.register_project(user_1).await.unwrap();
1597 db.update_worktree_extensions(
1598 project_1,
1599 1,
1600 HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
1601 )
1602 .await
1603 .unwrap();
1604 let project_2 = db.register_project(user_2).await.unwrap();
1605 let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
1606
1607 // User 2 opens a project
1608 let t1 = t0 + Duration::from_secs(10);
1609 db.record_user_activity(t0..t1, &[(user_2, project_2)])
1610 .await
1611 .unwrap();
1612
1613 let t2 = t1 + Duration::from_secs(10);
1614 db.record_user_activity(t1..t2, &[(user_2, project_2)])
1615 .await
1616 .unwrap();
1617
1618 // User 1 joins the project
1619 let t3 = t2 + Duration::from_secs(10);
1620 db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
1621 .await
1622 .unwrap();
1623
1624 // User 1 opens another project
1625 let t4 = t3 + Duration::from_secs(10);
1626 db.record_user_activity(
1627 t3..t4,
1628 &[
1629 (user_2, project_2),
1630 (user_1, project_2),
1631 (user_1, project_1),
1632 ],
1633 )
1634 .await
1635 .unwrap();
1636
1637 // User 3 joins that project
1638 let t5 = t4 + Duration::from_secs(10);
1639 db.record_user_activity(
1640 t4..t5,
1641 &[
1642 (user_2, project_2),
1643 (user_1, project_2),
1644 (user_1, project_1),
1645 (user_3, project_1),
1646 ],
1647 )
1648 .await
1649 .unwrap();
1650
1651 // User 2 leaves
1652 let t6 = t5 + Duration::from_secs(5);
1653 db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
1654 .await
1655 .unwrap();
1656
1657 let t7 = t6 + Duration::from_secs(60);
1658 let t8 = t7 + Duration::from_secs(10);
1659 db.record_user_activity(t7..t8, &[(user_1, project_1)])
1660 .await
1661 .unwrap();
1662
1663 assert_eq!(
1664 db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
1665 &[
1666 UserActivitySummary {
1667 id: user_1,
1668 github_login: "user_1".to_string(),
1669 project_activity: vec![
1670 (project_1, Duration::from_secs(25)),
1671 (project_2, Duration::from_secs(30)),
1672 ]
1673 },
1674 UserActivitySummary {
1675 id: user_2,
1676 github_login: "user_2".to_string(),
1677 project_activity: vec![(project_2, Duration::from_secs(50))]
1678 },
1679 UserActivitySummary {
1680 id: user_3,
1681 github_login: "user_3".to_string(),
1682 project_activity: vec![(project_1, Duration::from_secs(15))]
1683 },
1684 ]
1685 );
1686
1687 assert_eq!(
1688 db.get_active_user_count(t0..t6, Duration::from_secs(56))
1689 .await
1690 .unwrap(),
1691 0
1692 );
1693 assert_eq!(
1694 db.get_active_user_count(t0..t6, Duration::from_secs(54))
1695 .await
1696 .unwrap(),
1697 1
1698 );
1699 assert_eq!(
1700 db.get_active_user_count(t0..t6, Duration::from_secs(30))
1701 .await
1702 .unwrap(),
1703 2
1704 );
1705 assert_eq!(
1706 db.get_active_user_count(t0..t6, Duration::from_secs(10))
1707 .await
1708 .unwrap(),
1709 3
1710 );
1711
1712 assert_eq!(
1713 db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
1714 &[
1715 UserActivityPeriod {
1716 project_id: project_1,
1717 start: t3,
1718 end: t6,
1719 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1720 },
1721 UserActivityPeriod {
1722 project_id: project_2,
1723 start: t3,
1724 end: t5,
1725 extensions: Default::default(),
1726 },
1727 ]
1728 );
1729 assert_eq!(
1730 db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
1731 &[
1732 UserActivityPeriod {
1733 project_id: project_2,
1734 start: t2,
1735 end: t5,
1736 extensions: Default::default(),
1737 },
1738 UserActivityPeriod {
1739 project_id: project_1,
1740 start: t3,
1741 end: t6,
1742 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1743 },
1744 UserActivityPeriod {
1745 project_id: project_1,
1746 start: t7,
1747 end: t8,
1748 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1749 },
1750 ]
1751 );
1752 }
1753
1754 #[tokio::test(flavor = "multi_thread")]
1755 async fn test_recent_channel_messages() {
1756 for test_db in [
1757 TestDb::postgres().await,
1758 TestDb::fake(build_background_executor()),
1759 ] {
1760 let db = test_db.db();
1761 let user = db.create_user("user", None, false).await.unwrap();
1762 let org = db.create_org("org", "org").await.unwrap();
1763 let channel = db.create_org_channel(org, "channel").await.unwrap();
1764 for i in 0..10 {
1765 db.create_channel_message(
1766 channel,
1767 user,
1768 &i.to_string(),
1769 OffsetDateTime::now_utc(),
1770 i,
1771 )
1772 .await
1773 .unwrap();
1774 }
1775
1776 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1777 assert_eq!(
1778 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1779 ["5", "6", "7", "8", "9"]
1780 );
1781
1782 let prev_messages = db
1783 .get_channel_messages(channel, 4, Some(messages[0].id))
1784 .await
1785 .unwrap();
1786 assert_eq!(
1787 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1788 ["1", "2", "3", "4"]
1789 );
1790 }
1791 }
1792
1793 #[tokio::test(flavor = "multi_thread")]
1794 async fn test_channel_message_nonces() {
1795 for test_db in [
1796 TestDb::postgres().await,
1797 TestDb::fake(build_background_executor()),
1798 ] {
1799 let db = test_db.db();
1800 let user = db.create_user("user", None, false).await.unwrap();
1801 let org = db.create_org("org", "org").await.unwrap();
1802 let channel = db.create_org_channel(org, "channel").await.unwrap();
1803
1804 let msg1_id = db
1805 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1806 .await
1807 .unwrap();
1808 let msg2_id = db
1809 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1810 .await
1811 .unwrap();
1812 let msg3_id = db
1813 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1814 .await
1815 .unwrap();
1816 let msg4_id = db
1817 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1818 .await
1819 .unwrap();
1820
1821 assert_ne!(msg1_id, msg2_id);
1822 assert_eq!(msg1_id, msg3_id);
1823 assert_eq!(msg2_id, msg4_id);
1824 }
1825 }
1826
1827 #[tokio::test(flavor = "multi_thread")]
1828 async fn test_create_access_tokens() {
1829 let test_db = TestDb::postgres().await;
1830 let db = test_db.db();
1831 let user = db.create_user("the-user", None, false).await.unwrap();
1832
1833 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1834 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1835 assert_eq!(
1836 db.get_access_token_hashes(user).await.unwrap(),
1837 &["h2".to_string(), "h1".to_string()]
1838 );
1839
1840 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1841 assert_eq!(
1842 db.get_access_token_hashes(user).await.unwrap(),
1843 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1844 );
1845
1846 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1847 assert_eq!(
1848 db.get_access_token_hashes(user).await.unwrap(),
1849 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1850 );
1851
1852 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1853 assert_eq!(
1854 db.get_access_token_hashes(user).await.unwrap(),
1855 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1856 );
1857 }
1858
/// Checks the `LIKE` patterns produced by `fuzzy_like_string`.
#[test]
fn test_fuzzy_like_string() {
    let cases = [("abcd", "%a%b%c%d%"), ("x y", "%x%y%"), (" z ", "%z%")];
    for (input, expected) in cases {
        assert_eq!(fuzzy_like_string(input), expected);
    }
}
1865
1866 #[tokio::test(flavor = "multi_thread")]
1867 async fn test_fuzzy_search_users() {
1868 let test_db = TestDb::postgres().await;
1869 let db = test_db.db();
1870 for github_login in [
1871 "California",
1872 "colorado",
1873 "oregon",
1874 "washington",
1875 "florida",
1876 "delaware",
1877 "rhode-island",
1878 ] {
1879 db.create_user(github_login, None, false).await.unwrap();
1880 }
1881
1882 assert_eq!(
1883 fuzzy_search_user_names(db, "clr").await,
1884 &["colorado", "California"]
1885 );
1886 assert_eq!(
1887 fuzzy_search_user_names(db, "ro").await,
1888 &["rhode-island", "colorado", "oregon"],
1889 );
1890
1891 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1892 db.fuzzy_search_users(query, 10)
1893 .await
1894 .unwrap()
1895 .into_iter()
1896 .map(|user| user.github_login)
1897 .collect::<Vec<_>>()
1898 }
1899 }
1900
1901 #[tokio::test(flavor = "multi_thread")]
1902 async fn test_add_contacts() {
1903 for test_db in [
1904 TestDb::postgres().await,
1905 TestDb::fake(build_background_executor()),
1906 ] {
1907 let db = test_db.db();
1908
1909 let user_1 = db.create_user("user1", None, false).await.unwrap();
1910 let user_2 = db.create_user("user2", None, false).await.unwrap();
1911 let user_3 = db.create_user("user3", None, false).await.unwrap();
1912
1913 // User starts with no contacts
1914 assert_eq!(
1915 db.get_contacts(user_1).await.unwrap(),
1916 vec![Contact::Accepted {
1917 user_id: user_1,
1918 should_notify: false
1919 }],
1920 );
1921
1922 // User requests a contact. Both users see the pending request.
1923 db.send_contact_request(user_1, user_2).await.unwrap();
1924 assert!(!db.has_contact(user_1, user_2).await.unwrap());
1925 assert!(!db.has_contact(user_2, user_1).await.unwrap());
1926 assert_eq!(
1927 db.get_contacts(user_1).await.unwrap(),
1928 &[
1929 Contact::Accepted {
1930 user_id: user_1,
1931 should_notify: false
1932 },
1933 Contact::Outgoing { user_id: user_2 }
1934 ],
1935 );
1936 assert_eq!(
1937 db.get_contacts(user_2).await.unwrap(),
1938 &[
1939 Contact::Incoming {
1940 user_id: user_1,
1941 should_notify: true
1942 },
1943 Contact::Accepted {
1944 user_id: user_2,
1945 should_notify: false
1946 },
1947 ]
1948 );
1949
1950 // User 2 dismisses the contact request notification without accepting or rejecting.
1951 // We shouldn't notify them again.
1952 db.dismiss_contact_notification(user_1, user_2)
1953 .await
1954 .unwrap_err();
1955 db.dismiss_contact_notification(user_2, user_1)
1956 .await
1957 .unwrap();
1958 assert_eq!(
1959 db.get_contacts(user_2).await.unwrap(),
1960 &[
1961 Contact::Incoming {
1962 user_id: user_1,
1963 should_notify: false
1964 },
1965 Contact::Accepted {
1966 user_id: user_2,
1967 should_notify: false
1968 },
1969 ]
1970 );
1971
1972 // User can't accept their own contact request
1973 db.respond_to_contact_request(user_1, user_2, true)
1974 .await
1975 .unwrap_err();
1976
1977 // User accepts a contact request. Both users see the contact.
1978 db.respond_to_contact_request(user_2, user_1, true)
1979 .await
1980 .unwrap();
1981 assert_eq!(
1982 db.get_contacts(user_1).await.unwrap(),
1983 &[
1984 Contact::Accepted {
1985 user_id: user_1,
1986 should_notify: false
1987 },
1988 Contact::Accepted {
1989 user_id: user_2,
1990 should_notify: true
1991 }
1992 ],
1993 );
1994 assert!(db.has_contact(user_1, user_2).await.unwrap());
1995 assert!(db.has_contact(user_2, user_1).await.unwrap());
1996 assert_eq!(
1997 db.get_contacts(user_2).await.unwrap(),
1998 &[
1999 Contact::Accepted {
2000 user_id: user_1,
2001 should_notify: false,
2002 },
2003 Contact::Accepted {
2004 user_id: user_2,
2005 should_notify: false,
2006 },
2007 ]
2008 );
2009
2010 // Users cannot re-request existing contacts.
2011 db.send_contact_request(user_1, user_2).await.unwrap_err();
2012 db.send_contact_request(user_2, user_1).await.unwrap_err();
2013
2014 // Users can't dismiss notifications of them accepting other users' requests.
2015 db.dismiss_contact_notification(user_2, user_1)
2016 .await
2017 .unwrap_err();
2018 assert_eq!(
2019 db.get_contacts(user_1).await.unwrap(),
2020 &[
2021 Contact::Accepted {
2022 user_id: user_1,
2023 should_notify: false
2024 },
2025 Contact::Accepted {
2026 user_id: user_2,
2027 should_notify: true,
2028 },
2029 ]
2030 );
2031
2032 // Users can dismiss notifications of other users accepting their requests.
2033 db.dismiss_contact_notification(user_1, user_2)
2034 .await
2035 .unwrap();
2036 assert_eq!(
2037 db.get_contacts(user_1).await.unwrap(),
2038 &[
2039 Contact::Accepted {
2040 user_id: user_1,
2041 should_notify: false
2042 },
2043 Contact::Accepted {
2044 user_id: user_2,
2045 should_notify: false,
2046 },
2047 ]
2048 );
2049
2050 // Users send each other concurrent contact requests and
2051 // see that they are immediately accepted.
2052 db.send_contact_request(user_1, user_3).await.unwrap();
2053 db.send_contact_request(user_3, user_1).await.unwrap();
2054 assert_eq!(
2055 db.get_contacts(user_1).await.unwrap(),
2056 &[
2057 Contact::Accepted {
2058 user_id: user_1,
2059 should_notify: false
2060 },
2061 Contact::Accepted {
2062 user_id: user_2,
2063 should_notify: false,
2064 },
2065 Contact::Accepted {
2066 user_id: user_3,
2067 should_notify: false
2068 },
2069 ]
2070 );
2071 assert_eq!(
2072 db.get_contacts(user_3).await.unwrap(),
2073 &[
2074 Contact::Accepted {
2075 user_id: user_1,
2076 should_notify: false
2077 },
2078 Contact::Accepted {
2079 user_id: user_3,
2080 should_notify: false
2081 }
2082 ],
2083 );
2084
2085 // User declines a contact request. Both users see that it is gone.
2086 db.send_contact_request(user_2, user_3).await.unwrap();
2087 db.respond_to_contact_request(user_3, user_2, false)
2088 .await
2089 .unwrap();
2090 assert!(!db.has_contact(user_2, user_3).await.unwrap());
2091 assert!(!db.has_contact(user_3, user_2).await.unwrap());
2092 assert_eq!(
2093 db.get_contacts(user_2).await.unwrap(),
2094 &[
2095 Contact::Accepted {
2096 user_id: user_1,
2097 should_notify: false
2098 },
2099 Contact::Accepted {
2100 user_id: user_2,
2101 should_notify: false
2102 }
2103 ]
2104 );
2105 assert_eq!(
2106 db.get_contacts(user_3).await.unwrap(),
2107 &[
2108 Contact::Accepted {
2109 user_id: user_1,
2110 should_notify: false
2111 },
2112 Contact::Accepted {
2113 user_id: user_3,
2114 should_notify: false
2115 }
2116 ],
2117 );
2118 }
2119 }
2120
2121 #[tokio::test(flavor = "multi_thread")]
2122 async fn test_invite_codes() {
2123 let postgres = TestDb::postgres().await;
2124 let db = postgres.db();
2125 let user1 = db.create_user("user-1", None, false).await.unwrap();
2126
2127 // Initially, user 1 has no invite code
2128 assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);
2129
2130 // Setting invite count to 0 when no code is assigned does not assign a new code
2131 db.set_invite_count(user1, 0).await.unwrap();
2132 assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());
2133
2134 // User 1 creates an invite code that can be used twice.
2135 db.set_invite_count(user1, 2).await.unwrap();
2136 let (invite_code, invite_count) =
2137 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2138 assert_eq!(invite_count, 2);
2139
2140 // User 2 redeems the invite code and becomes a contact of user 1.
2141 let user2 = db
2142 .redeem_invite_code(&invite_code, "user-2", None)
2143 .await
2144 .unwrap();
2145 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2146 assert_eq!(invite_count, 1);
2147 assert_eq!(
2148 db.get_contacts(user1).await.unwrap(),
2149 [
2150 Contact::Accepted {
2151 user_id: user1,
2152 should_notify: false
2153 },
2154 Contact::Accepted {
2155 user_id: user2,
2156 should_notify: true
2157 }
2158 ]
2159 );
2160 assert_eq!(
2161 db.get_contacts(user2).await.unwrap(),
2162 [
2163 Contact::Accepted {
2164 user_id: user1,
2165 should_notify: false
2166 },
2167 Contact::Accepted {
2168 user_id: user2,
2169 should_notify: false
2170 }
2171 ]
2172 );
2173
2174 // User 3 redeems the invite code and becomes a contact of user 1.
2175 let user3 = db
2176 .redeem_invite_code(&invite_code, "user-3", None)
2177 .await
2178 .unwrap();
2179 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2180 assert_eq!(invite_count, 0);
2181 assert_eq!(
2182 db.get_contacts(user1).await.unwrap(),
2183 [
2184 Contact::Accepted {
2185 user_id: user1,
2186 should_notify: false
2187 },
2188 Contact::Accepted {
2189 user_id: user2,
2190 should_notify: true
2191 },
2192 Contact::Accepted {
2193 user_id: user3,
2194 should_notify: true
2195 }
2196 ]
2197 );
2198 assert_eq!(
2199 db.get_contacts(user3).await.unwrap(),
2200 [
2201 Contact::Accepted {
2202 user_id: user1,
2203 should_notify: false
2204 },
2205 Contact::Accepted {
2206 user_id: user3,
2207 should_notify: false
2208 },
2209 ]
2210 );
2211
2212 // Trying to reedem the code for the third time results in an error.
2213 db.redeem_invite_code(&invite_code, "user-4", None)
2214 .await
2215 .unwrap_err();
2216
2217 // Invite count can be updated after the code has been created.
2218 db.set_invite_count(user1, 2).await.unwrap();
2219 let (latest_code, invite_count) =
2220 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2221 assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
2222 assert_eq!(invite_count, 2);
2223
2224 // User 4 can now redeem the invite code and becomes a contact of user 1.
2225 let user4 = db
2226 .redeem_invite_code(&invite_code, "user-4", None)
2227 .await
2228 .unwrap();
2229 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2230 assert_eq!(invite_count, 1);
2231 assert_eq!(
2232 db.get_contacts(user1).await.unwrap(),
2233 [
2234 Contact::Accepted {
2235 user_id: user1,
2236 should_notify: false
2237 },
2238 Contact::Accepted {
2239 user_id: user2,
2240 should_notify: true
2241 },
2242 Contact::Accepted {
2243 user_id: user3,
2244 should_notify: true
2245 },
2246 Contact::Accepted {
2247 user_id: user4,
2248 should_notify: true
2249 }
2250 ]
2251 );
2252 assert_eq!(
2253 db.get_contacts(user4).await.unwrap(),
2254 [
2255 Contact::Accepted {
2256 user_id: user1,
2257 should_notify: false
2258 },
2259 Contact::Accepted {
2260 user_id: user4,
2261 should_notify: false
2262 },
2263 ]
2264 );
2265
2266 // An existing user cannot redeem invite codes.
2267 db.redeem_invite_code(&invite_code, "user-2", None)
2268 .await
2269 .unwrap_err();
2270 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2271 assert_eq!(invite_count, 1);
2272 }
2273
/// A database handle for tests, backed either by a temporary Postgres database or
/// by the in-memory `FakeDb`.
///
/// Dropping it runs the database's `teardown` with `url`.
pub struct TestDb {
    /// The wrapped database; taken (set to `None`) when the `TestDb` is dropped.
    pub db: Option<Arc<dyn Db>>,
    /// Connection URL of the Postgres database; the default (empty) string for the fake.
    pub url: String,
}
2278
2279 impl TestDb {
2280 pub async fn postgres() -> Self {
2281 lazy_static! {
2282 static ref LOCK: Mutex<()> = Mutex::new(());
2283 }
2284
2285 let _guard = LOCK.lock();
2286 let mut rng = StdRng::from_entropy();
2287 let name = format!("zed-test-{}", rng.gen::<u128>());
2288 let url = format!("postgres://postgres@localhost/{}", name);
2289 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2290 Postgres::create_database(&url)
2291 .await
2292 .expect("failed to create test db");
2293 let db = PostgresDb::new(&url, 5).await.unwrap();
2294 let migrator = Migrator::new(migrations_path).await.unwrap();
2295 migrator.run(&db.pool).await.unwrap();
2296 Self {
2297 db: Some(Arc::new(db)),
2298 url,
2299 }
2300 }
2301
2302 pub fn fake(background: Arc<Background>) -> Self {
2303 Self {
2304 db: Some(Arc::new(FakeDb::new(background))),
2305 url: Default::default(),
2306 }
2307 }
2308
2309 pub fn db(&self) -> &Arc<dyn Db> {
2310 self.db.as_ref().unwrap()
2311 }
2312 }
2313
2314 impl Drop for TestDb {
2315 fn drop(&mut self) {
2316 if let Some(db) = self.db.take() {
2317 futures::executor::block_on(db.teardown(&self.url));
2318 }
2319 }
2320 }
2321
/// An in-memory implementation of `Db` for tests.
///
/// Each table is a mutex-protected collection, and ids are handed out from the
/// `next_*_id` counters.
pub struct FakeDb {
    /// Executor used to simulate random delays before operations.
    background: Arc<Background>,
    /// Users keyed by id.
    pub users: Mutex<BTreeMap<UserId, User>>,
    /// Projects keyed by id.
    pub projects: Mutex<BTreeMap<ProjectId, Project>>,
    /// Counts keyed by `(project, u64, extension)`.
    // NOTE(review): the `u64` is presumably the worktree id — confirm against the trait.
    pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
    /// Organizations keyed by id.
    pub orgs: Mutex<BTreeMap<OrgId, Org>>,
    /// Org memberships keyed by `(org, user)`.
    // NOTE(review): the `bool` is presumably the member's admin flag — confirm.
    pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
    /// Channels keyed by id.
    pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
    /// Channel memberships keyed by `(channel, user)`.
    // NOTE(review): the `bool` is presumably the member's admin flag — confirm.
    pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
    /// Channel messages keyed by id.
    pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
    /// Contact relationships.
    pub contacts: Mutex<Vec<FakeContact>>,
    /// Next id to assign to a channel message.
    next_channel_message_id: Mutex<i32>,
    /// Next id to assign to a user.
    next_user_id: Mutex<i32>,
    /// Next id to assign to an org.
    next_org_id: Mutex<i32>,
    /// Next id to assign to a channel.
    next_channel_id: Mutex<i32>,
    /// Next id to assign to a project.
    next_project_id: Mutex<i32>,
}
2339
/// A contact relationship as stored by `FakeDb`.
#[derive(Debug)]
pub struct FakeContact {
    /// The user who sent the contact request.
    pub requester_id: UserId,
    /// The user who received the contact request.
    pub responder_id: UserId,
    /// Whether the request has been accepted.
    pub accepted: bool,
    /// Whether a notification is still pending for this relationship.
    pub should_notify: bool,
}
2347
2348 impl FakeDb {
2349 pub fn new(background: Arc<Background>) -> Self {
2350 Self {
2351 background,
2352 users: Default::default(),
2353 next_user_id: Mutex::new(0),
2354 projects: Default::default(),
2355 worktree_extensions: Default::default(),
2356 next_project_id: Mutex::new(1),
2357 orgs: Default::default(),
2358 next_org_id: Mutex::new(1),
2359 org_memberships: Default::default(),
2360 channels: Default::default(),
2361 next_channel_id: Mutex::new(1),
2362 channel_memberships: Default::default(),
2363 channel_messages: Default::default(),
2364 next_channel_message_id: Mutex::new(1),
2365 contacts: Default::default(),
2366 }
2367 }
2368 }
2369
2370 #[async_trait]
2371 impl Db for FakeDb {
2372 async fn create_user(
2373 &self,
2374 github_login: &str,
2375 email_address: Option<&str>,
2376 admin: bool,
2377 ) -> Result<UserId> {
2378 self.background.simulate_random_delay().await;
2379
2380 let mut users = self.users.lock();
2381 if let Some(user) = users
2382 .values()
2383 .find(|user| user.github_login == github_login)
2384 {
2385 Ok(user.id)
2386 } else {
2387 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2388 users.insert(
2389 user_id,
2390 User {
2391 id: user_id,
2392 github_login: github_login.to_string(),
2393 email_address: email_address.map(str::to_string),
2394 admin,
2395 invite_code: None,
2396 invite_count: 0,
2397 connected_once: false,
2398 },
2399 );
2400 Ok(user_id)
2401 }
2402 }
2403
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
    unimplemented!()
}
2407
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
    unimplemented!()
}
2411
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
    unimplemented!()
}
2415
2416 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2417 self.background.simulate_random_delay().await;
2418 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2419 }
2420
2421 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2422 self.background.simulate_random_delay().await;
2423 let users = self.users.lock();
2424 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2425 }
2426
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
    unimplemented!()
}
2430
2431 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2432 self.background.simulate_random_delay().await;
2433 Ok(self
2434 .users
2435 .lock()
2436 .values()
2437 .find(|user| user.github_login == github_login)
2438 .cloned())
2439 }
2440
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
    unimplemented!()
}
2444
2445 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2446 self.background.simulate_random_delay().await;
2447 let mut users = self.users.lock();
2448 let mut user = users
2449 .get_mut(&id)
2450 .ok_or_else(|| anyhow!("user not found"))?;
2451 user.connected_once = connected_once;
2452 Ok(())
2453 }
2454
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn destroy_user(&self, _id: UserId) -> Result<()> {
    unimplemented!()
}
2458
2459 // invite codes
2460
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
    unimplemented!()
}
2464
/// Always reports that the user has no invite code: the fake does not track invite
/// codes, so after the simulated delay this returns `Ok(None)` for every id.
async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
    self.background.simulate_random_delay().await;
    Ok(None)
}
2469
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
    unimplemented!()
}
2473
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn redeem_invite_code(
    &self,
    _code: &str,
    _login: &str,
    _email_address: Option<&str>,
) -> Result<UserId> {
    unimplemented!()
}
2482
2483 // projects
2484
2485 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2486 self.background.simulate_random_delay().await;
2487 if !self.users.lock().contains_key(&host_user_id) {
2488 Err(anyhow!("no such user"))?;
2489 }
2490
2491 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2492 self.projects.lock().insert(
2493 project_id,
2494 Project {
2495 id: project_id,
2496 host_user_id,
2497 unregistered: false,
2498 },
2499 );
2500 Ok(project_id)
2501 }
2502
2503 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2504 self.background.simulate_random_delay().await;
2505 self.projects
2506 .lock()
2507 .get_mut(&project_id)
2508 .ok_or_else(|| anyhow!("no such project"))?
2509 .unregistered = true;
2510 Ok(())
2511 }
2512
2513 async fn update_worktree_extensions(
2514 &self,
2515 project_id: ProjectId,
2516 worktree_id: u64,
2517 extensions: HashMap<String, u32>,
2518 ) -> Result<()> {
2519 self.background.simulate_random_delay().await;
2520 if !self.projects.lock().contains_key(&project_id) {
2521 Err(anyhow!("no such project"))?;
2522 }
2523
2524 for (extension, count) in extensions {
2525 self.worktree_extensions
2526 .lock()
2527 .insert((project_id, worktree_id, extension), count);
2528 }
2529
2530 Ok(())
2531 }
2532
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn get_project_extensions(
    &self,
    _project_id: ProjectId,
) -> Result<HashMap<u64, HashMap<String, usize>>> {
    unimplemented!()
}
2539
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn record_user_activity(
    &self,
    _time_period: Range<OffsetDateTime>,
    _active_projects: &[(UserId, ProjectId)],
) -> Result<()> {
    unimplemented!()
}
2547
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn get_active_user_count(
    &self,
    _time_period: Range<OffsetDateTime>,
    _min_duration: Duration,
) -> Result<usize> {
    unimplemented!()
}
2555
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn get_top_users_activity_summary(
    &self,
    _time_period: Range<OffsetDateTime>,
    _limit: usize,
) -> Result<Vec<UserActivitySummary>> {
    unimplemented!()
}
2563
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn get_user_activity_timeline(
    &self,
    _time_period: Range<OffsetDateTime>,
    _user_id: UserId,
) -> Result<Vec<UserActivityPeriod>> {
    unimplemented!()
}
2571
2572 // contacts
2573
2574 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2575 self.background.simulate_random_delay().await;
2576 let mut contacts = vec![Contact::Accepted {
2577 user_id: id,
2578 should_notify: false,
2579 }];
2580
2581 for contact in self.contacts.lock().iter() {
2582 if contact.requester_id == id {
2583 if contact.accepted {
2584 contacts.push(Contact::Accepted {
2585 user_id: contact.responder_id,
2586 should_notify: contact.should_notify,
2587 });
2588 } else {
2589 contacts.push(Contact::Outgoing {
2590 user_id: contact.responder_id,
2591 });
2592 }
2593 } else if contact.responder_id == id {
2594 if contact.accepted {
2595 contacts.push(Contact::Accepted {
2596 user_id: contact.requester_id,
2597 should_notify: false,
2598 });
2599 } else {
2600 contacts.push(Contact::Incoming {
2601 user_id: contact.requester_id,
2602 should_notify: contact.should_notify,
2603 });
2604 }
2605 }
2606 }
2607
2608 contacts.sort_unstable_by_key(|contact| contact.user_id());
2609 Ok(contacts)
2610 }
2611
2612 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2613 self.background.simulate_random_delay().await;
2614 Ok(self.contacts.lock().iter().any(|contact| {
2615 contact.accepted
2616 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2617 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2618 }))
2619 }
2620
2621 async fn send_contact_request(
2622 &self,
2623 requester_id: UserId,
2624 responder_id: UserId,
2625 ) -> Result<()> {
2626 self.background.simulate_random_delay().await;
2627 let mut contacts = self.contacts.lock();
2628 for contact in contacts.iter_mut() {
2629 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2630 if contact.accepted {
2631 Err(anyhow!("contact already exists"))?;
2632 } else {
2633 Err(anyhow!("contact already requested"))?;
2634 }
2635 }
2636 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2637 if contact.accepted {
2638 Err(anyhow!("contact already exists"))?;
2639 } else {
2640 contact.accepted = true;
2641 contact.should_notify = false;
2642 return Ok(());
2643 }
2644 }
2645 }
2646 contacts.push(FakeContact {
2647 requester_id,
2648 responder_id,
2649 accepted: false,
2650 should_notify: true,
2651 });
2652 Ok(())
2653 }
2654
2655 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2656 self.background.simulate_random_delay().await;
2657 self.contacts.lock().retain(|contact| {
2658 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2659 });
2660 Ok(())
2661 }
2662
2663 async fn dismiss_contact_notification(
2664 &self,
2665 user_id: UserId,
2666 contact_user_id: UserId,
2667 ) -> Result<()> {
2668 self.background.simulate_random_delay().await;
2669 let mut contacts = self.contacts.lock();
2670 for contact in contacts.iter_mut() {
2671 if contact.requester_id == contact_user_id
2672 && contact.responder_id == user_id
2673 && !contact.accepted
2674 {
2675 contact.should_notify = false;
2676 return Ok(());
2677 }
2678 if contact.requester_id == user_id
2679 && contact.responder_id == contact_user_id
2680 && contact.accepted
2681 {
2682 contact.should_notify = false;
2683 return Ok(());
2684 }
2685 }
2686 Err(anyhow!("no such notification"))?
2687 }
2688
2689 async fn respond_to_contact_request(
2690 &self,
2691 responder_id: UserId,
2692 requester_id: UserId,
2693 accept: bool,
2694 ) -> Result<()> {
2695 self.background.simulate_random_delay().await;
2696 let mut contacts = self.contacts.lock();
2697 for (ix, contact) in contacts.iter_mut().enumerate() {
2698 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2699 if contact.accepted {
2700 Err(anyhow!("contact already confirmed"))?;
2701 }
2702 if accept {
2703 contact.accepted = true;
2704 contact.should_notify = true;
2705 } else {
2706 contacts.remove(ix);
2707 }
2708 return Ok(());
2709 }
2710 }
2711 Err(anyhow!("no such contact request"))?
2712 }
2713
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn create_access_token_hash(
    &self,
    _user_id: UserId,
    _access_token_hash: &str,
    _max_access_token_count: usize,
) -> Result<()> {
    unimplemented!()
}
2722
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
    unimplemented!()
}
2726
/// Not implemented by the fake; calling it panics via `unimplemented!()`.
async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
    unimplemented!()
}
2730
2731 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2732 self.background.simulate_random_delay().await;
2733 let mut orgs = self.orgs.lock();
2734 if orgs.values().any(|org| org.slug == slug) {
2735 Err(anyhow!("org already exists"))?
2736 } else {
2737 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2738 orgs.insert(
2739 org_id,
2740 Org {
2741 id: org_id,
2742 name: name.to_string(),
2743 slug: slug.to_string(),
2744 },
2745 );
2746 Ok(org_id)
2747 }
2748 }
2749
2750 async fn add_org_member(
2751 &self,
2752 org_id: OrgId,
2753 user_id: UserId,
2754 is_admin: bool,
2755 ) -> Result<()> {
2756 self.background.simulate_random_delay().await;
2757 if !self.orgs.lock().contains_key(&org_id) {
2758 Err(anyhow!("org does not exist"))?;
2759 }
2760 if !self.users.lock().contains_key(&user_id) {
2761 Err(anyhow!("user does not exist"))?;
2762 }
2763
2764 self.org_memberships
2765 .lock()
2766 .entry((org_id, user_id))
2767 .or_insert(is_admin);
2768 Ok(())
2769 }
2770
2771 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2772 self.background.simulate_random_delay().await;
2773 if !self.orgs.lock().contains_key(&org_id) {
2774 Err(anyhow!("org does not exist"))?;
2775 }
2776
2777 let mut channels = self.channels.lock();
2778 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2779 channels.insert(
2780 channel_id,
2781 Channel {
2782 id: channel_id,
2783 name: name.to_string(),
2784 owner_id: org_id.0,
2785 owner_is_user: false,
2786 },
2787 );
2788 Ok(channel_id)
2789 }
2790
2791 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2792 self.background.simulate_random_delay().await;
2793 Ok(self
2794 .channels
2795 .lock()
2796 .values()
2797 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2798 .cloned()
2799 .collect())
2800 }
2801
2802 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2803 self.background.simulate_random_delay().await;
2804 let channels = self.channels.lock();
2805 let memberships = self.channel_memberships.lock();
2806 Ok(channels
2807 .values()
2808 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2809 .cloned()
2810 .collect())
2811 }
2812
2813 async fn can_user_access_channel(
2814 &self,
2815 user_id: UserId,
2816 channel_id: ChannelId,
2817 ) -> Result<bool> {
2818 self.background.simulate_random_delay().await;
2819 Ok(self
2820 .channel_memberships
2821 .lock()
2822 .contains_key(&(channel_id, user_id)))
2823 }
2824
2825 async fn add_channel_member(
2826 &self,
2827 channel_id: ChannelId,
2828 user_id: UserId,
2829 is_admin: bool,
2830 ) -> Result<()> {
2831 self.background.simulate_random_delay().await;
2832 if !self.channels.lock().contains_key(&channel_id) {
2833 Err(anyhow!("channel does not exist"))?;
2834 }
2835 if !self.users.lock().contains_key(&user_id) {
2836 Err(anyhow!("user does not exist"))?;
2837 }
2838
2839 self.channel_memberships
2840 .lock()
2841 .entry((channel_id, user_id))
2842 .or_insert(is_admin);
2843 Ok(())
2844 }
2845
2846 async fn create_channel_message(
2847 &self,
2848 channel_id: ChannelId,
2849 sender_id: UserId,
2850 body: &str,
2851 timestamp: OffsetDateTime,
2852 nonce: u128,
2853 ) -> Result<MessageId> {
2854 self.background.simulate_random_delay().await;
2855 if !self.channels.lock().contains_key(&channel_id) {
2856 Err(anyhow!("channel does not exist"))?;
2857 }
2858 if !self.users.lock().contains_key(&sender_id) {
2859 Err(anyhow!("user does not exist"))?;
2860 }
2861
2862 let mut messages = self.channel_messages.lock();
2863 if let Some(message) = messages
2864 .values()
2865 .find(|message| message.nonce.as_u128() == nonce)
2866 {
2867 Ok(message.id)
2868 } else {
2869 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2870 messages.insert(
2871 message_id,
2872 ChannelMessage {
2873 id: message_id,
2874 channel_id,
2875 sender_id,
2876 body: body.to_string(),
2877 sent_at: timestamp,
2878 nonce: Uuid::from_u128(nonce),
2879 },
2880 );
2881 Ok(message_id)
2882 }
2883 }
2884
2885 async fn get_channel_messages(
2886 &self,
2887 channel_id: ChannelId,
2888 count: usize,
2889 before_id: Option<MessageId>,
2890 ) -> Result<Vec<ChannelMessage>> {
2891 self.background.simulate_random_delay().await;
2892 let mut messages = self
2893 .channel_messages
2894 .lock()
2895 .values()
2896 .rev()
2897 .filter(|message| {
2898 message.channel_id == channel_id
2899 && message.id < before_id.unwrap_or(MessageId::MAX)
2900 })
2901 .take(count)
2902 .cloned()
2903 .collect::<Vec<_>>();
2904 messages.sort_unstable_by_key(|message| message.id);
2905 Ok(messages)
2906 }
2907
/// Does nothing for the fake database; the argument is ignored.
async fn teardown(&self, _: &str) {}
2909
/// Returns this instance as a `FakeDb`. Only compiled in test builds.
#[cfg(test)]
fn as_fake(&self) -> Option<&FakeDb> {
    Some(self)
}
2914 }
2915
/// Builds the background executor from `Deterministic::new(0)`.
// NOTE(review): the `0` argument is presumably the executor's seed; confirm against
// `Deterministic::new`.
fn build_background_executor() -> Arc<Background> {
    Deterministic::new(0).build_background()
}
2919}