1use std::{ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use nanoid::nanoid;
10use serde::{Deserialize, Serialize};
11pub use sqlx::postgres::PgPoolOptions as DbOptions;
12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
13use time::{OffsetDateTime, PrimitiveDateTime};
14
15#[async_trait]
16pub trait Db: Send + Sync {
17 async fn create_user(
18 &self,
19 github_login: &str,
20 email_address: Option<&str>,
21 admin: bool,
22 ) -> Result<UserId>;
23 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
24 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
25 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
26 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
27 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
28 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
29 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
30 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
31 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
32 async fn destroy_user(&self, id: UserId) -> Result<()>;
33
34 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
35 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
36 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
37 async fn redeem_invite_code(
38 &self,
39 code: &str,
40 login: &str,
41 email_address: Option<&str>,
42 ) -> Result<UserId>;
43
44 /// Registers a new project for the given user.
45 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
46
47 /// Unregisters a project for the given project id.
48 async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
49
50 /// Update file counts by extension for the given project and worktree.
51 async fn update_worktree_extensions(
52 &self,
53 project_id: ProjectId,
54 worktree_id: u64,
55 extensions: HashMap<String, usize>,
56 ) -> Result<()>;
57
58 /// Get the file counts on the given project keyed by their worktree and extension.
59 async fn get_project_extensions(
60 &self,
61 project_id: ProjectId,
62 ) -> Result<HashMap<u64, HashMap<String, usize>>>;
63
64 /// Record which users have been active in which projects during
65 /// a given period of time.
66 async fn record_user_activity(
67 &self,
68 time_period: Range<OffsetDateTime>,
69 active_projects: &[(UserId, ProjectId)],
70 ) -> Result<()>;
71
72 /// Get the users that have been most active during the given time period,
73 /// along with the amount of time they have been active in each project.
74 async fn get_top_users_activity_summary(
75 &self,
76 time_period: Range<OffsetDateTime>,
77 max_user_count: usize,
78 ) -> Result<Vec<UserActivitySummary>>;
79
80 /// Get the project activity for the given user and time period.
81 async fn get_user_activity_timeline(
82 &self,
83 time_period: Range<OffsetDateTime>,
84 user_id: UserId,
85 ) -> Result<Vec<UserActivityPeriod>>;
86
87 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
88 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
89 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
90 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
91 async fn dismiss_contact_notification(
92 &self,
93 responder_id: UserId,
94 requester_id: UserId,
95 ) -> Result<()>;
96 async fn respond_to_contact_request(
97 &self,
98 responder_id: UserId,
99 requester_id: UserId,
100 accept: bool,
101 ) -> Result<()>;
102
103 async fn create_access_token_hash(
104 &self,
105 user_id: UserId,
106 access_token_hash: &str,
107 max_access_token_count: usize,
108 ) -> Result<()>;
109 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
110 #[cfg(any(test, feature = "seed-support"))]
111
112 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
113 #[cfg(any(test, feature = "seed-support"))]
114 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
115 #[cfg(any(test, feature = "seed-support"))]
116 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
117 #[cfg(any(test, feature = "seed-support"))]
118 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
119 #[cfg(any(test, feature = "seed-support"))]
120
121 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
122 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
123 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
124 -> Result<bool>;
125 #[cfg(any(test, feature = "seed-support"))]
126 async fn add_channel_member(
127 &self,
128 channel_id: ChannelId,
129 user_id: UserId,
130 is_admin: bool,
131 ) -> Result<()>;
132 async fn create_channel_message(
133 &self,
134 channel_id: ChannelId,
135 sender_id: UserId,
136 body: &str,
137 timestamp: OffsetDateTime,
138 nonce: u128,
139 ) -> Result<MessageId>;
140 async fn get_channel_messages(
141 &self,
142 channel_id: ChannelId,
143 count: usize,
144 before_id: Option<MessageId>,
145 ) -> Result<Vec<ChannelMessage>>;
146 #[cfg(test)]
147 async fn teardown(&self, url: &str);
148 #[cfg(test)]
149 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
150}
151
152pub struct PostgresDb {
153 pool: sqlx::PgPool,
154}
155
156impl PostgresDb {
157 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
158 let pool = DbOptions::new()
159 .max_connections(max_connections)
160 .connect(&url)
161 .await
162 .context("failed to connect to postgres database")?;
163 Ok(Self { pool })
164 }
165}
166
167#[async_trait]
168impl Db for PostgresDb {
169 // users
170
171 async fn create_user(
172 &self,
173 github_login: &str,
174 email_address: Option<&str>,
175 admin: bool,
176 ) -> Result<UserId> {
177 let query = "
178 INSERT INTO users (github_login, email_address, admin)
179 VALUES ($1, $2, $3)
180 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
181 RETURNING id
182 ";
183 Ok(sqlx::query_scalar(query)
184 .bind(github_login)
185 .bind(email_address)
186 .bind(admin)
187 .fetch_one(&self.pool)
188 .await
189 .map(UserId)?)
190 }
191
192 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
193 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
194 Ok(sqlx::query_as(query)
195 .bind(limit as i32)
196 .bind((page * limit) as i32)
197 .fetch_all(&self.pool)
198 .await?)
199 }
200
201 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
202 let mut query = QueryBuilder::new(
203 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
204 );
205 query.push_values(
206 users,
207 |mut query, (github_login, email_address, invite_count)| {
208 query
209 .push_bind(github_login)
210 .push_bind(email_address)
211 .push_bind(false)
212 .push_bind(nanoid!(16))
213 .push_bind(invite_count as i32);
214 },
215 );
216 query.push(
217 "
218 ON CONFLICT (github_login) DO UPDATE SET
219 github_login = excluded.github_login,
220 invite_count = excluded.invite_count,
221 invite_code = CASE WHEN users.invite_code IS NULL
222 THEN excluded.invite_code
223 ELSE users.invite_code
224 END
225 RETURNING id
226 ",
227 );
228
229 let rows = query.build().fetch_all(&self.pool).await?;
230 Ok(rows
231 .into_iter()
232 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
233 .collect())
234 }
235
236 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
237 let like_string = fuzzy_like_string(name_query);
238 let query = "
239 SELECT users.*
240 FROM users
241 WHERE github_login ILIKE $1
242 ORDER BY github_login <-> $2
243 LIMIT $3
244 ";
245 Ok(sqlx::query_as(query)
246 .bind(like_string)
247 .bind(name_query)
248 .bind(limit as i32)
249 .fetch_all(&self.pool)
250 .await?)
251 }
252
253 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
254 let users = self.get_users_by_ids(vec![id]).await?;
255 Ok(users.into_iter().next())
256 }
257
258 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
259 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
260 let query = "
261 SELECT users.*
262 FROM users
263 WHERE users.id = ANY ($1)
264 ";
265 Ok(sqlx::query_as(query)
266 .bind(&ids)
267 .fetch_all(&self.pool)
268 .await?)
269 }
270
271 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
272 let query = format!(
273 "
274 SELECT users.*
275 FROM users
276 WHERE invite_count = 0
277 AND inviter_id IS{} NULL
278 ",
279 if invited_by_another_user { " NOT" } else { "" }
280 );
281
282 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
283 }
284
285 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
286 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
287 Ok(sqlx::query_as(query)
288 .bind(github_login)
289 .fetch_optional(&self.pool)
290 .await?)
291 }
292
293 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
294 let query = "UPDATE users SET admin = $1 WHERE id = $2";
295 Ok(sqlx::query(query)
296 .bind(is_admin)
297 .bind(id.0)
298 .execute(&self.pool)
299 .await
300 .map(drop)?)
301 }
302
303 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
304 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
305 Ok(sqlx::query(query)
306 .bind(connected_once)
307 .bind(id.0)
308 .execute(&self.pool)
309 .await
310 .map(drop)?)
311 }
312
313 async fn destroy_user(&self, id: UserId) -> Result<()> {
314 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
315 sqlx::query(query)
316 .bind(id.0)
317 .execute(&self.pool)
318 .await
319 .map(drop)?;
320 let query = "DELETE FROM users WHERE id = $1;";
321 Ok(sqlx::query(query)
322 .bind(id.0)
323 .execute(&self.pool)
324 .await
325 .map(drop)?)
326 }
327
328 // invite codes
329
330 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
331 let mut tx = self.pool.begin().await?;
332 if count > 0 {
333 sqlx::query(
334 "
335 UPDATE users
336 SET invite_code = $1
337 WHERE id = $2 AND invite_code IS NULL
338 ",
339 )
340 .bind(nanoid!(16))
341 .bind(id)
342 .execute(&mut tx)
343 .await?;
344 }
345
346 sqlx::query(
347 "
348 UPDATE users
349 SET invite_count = $1
350 WHERE id = $2
351 ",
352 )
353 .bind(count as i32)
354 .bind(id)
355 .execute(&mut tx)
356 .await?;
357 tx.commit().await?;
358 Ok(())
359 }
360
361 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
362 let result: Option<(String, i32)> = sqlx::query_as(
363 "
364 SELECT invite_code, invite_count
365 FROM users
366 WHERE id = $1 AND invite_code IS NOT NULL
367 ",
368 )
369 .bind(id)
370 .fetch_optional(&self.pool)
371 .await?;
372 if let Some((code, count)) = result {
373 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
374 } else {
375 Ok(None)
376 }
377 }
378
379 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
380 sqlx::query_as(
381 "
382 SELECT *
383 FROM users
384 WHERE invite_code = $1
385 ",
386 )
387 .bind(code)
388 .fetch_optional(&self.pool)
389 .await?
390 .ok_or_else(|| {
391 Error::Http(
392 StatusCode::NOT_FOUND,
393 "that invite code does not exist".to_string(),
394 )
395 })
396 }
397
398 async fn redeem_invite_code(
399 &self,
400 code: &str,
401 login: &str,
402 email_address: Option<&str>,
403 ) -> Result<UserId> {
404 let mut tx = self.pool.begin().await?;
405
406 let inviter_id: Option<UserId> = sqlx::query_scalar(
407 "
408 UPDATE users
409 SET invite_count = invite_count - 1
410 WHERE
411 invite_code = $1 AND
412 invite_count > 0
413 RETURNING id
414 ",
415 )
416 .bind(code)
417 .fetch_optional(&mut tx)
418 .await?;
419
420 let inviter_id = match inviter_id {
421 Some(inviter_id) => inviter_id,
422 None => {
423 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
424 .bind(code)
425 .fetch_optional(&mut tx)
426 .await?
427 .is_some()
428 {
429 Err(Error::Http(
430 StatusCode::UNAUTHORIZED,
431 "no invites remaining".to_string(),
432 ))?
433 } else {
434 Err(Error::Http(
435 StatusCode::NOT_FOUND,
436 "invite code not found".to_string(),
437 ))?
438 }
439 }
440 };
441
442 let invitee_id = sqlx::query_scalar(
443 "
444 INSERT INTO users
445 (github_login, email_address, admin, inviter_id)
446 VALUES
447 ($1, $2, 'f', $3)
448 RETURNING id
449 ",
450 )
451 .bind(login)
452 .bind(email_address)
453 .bind(inviter_id)
454 .fetch_one(&mut tx)
455 .await
456 .map(UserId)?;
457
458 sqlx::query(
459 "
460 INSERT INTO contacts
461 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
462 VALUES
463 ($1, $2, 't', 't', 't')
464 ",
465 )
466 .bind(inviter_id)
467 .bind(invitee_id)
468 .execute(&mut tx)
469 .await?;
470
471 tx.commit().await?;
472 Ok(invitee_id)
473 }
474
475 // projects
476
477 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
478 Ok(sqlx::query_scalar(
479 "
480 INSERT INTO projects(host_user_id)
481 VALUES ($1)
482 RETURNING id
483 ",
484 )
485 .bind(host_user_id)
486 .fetch_one(&self.pool)
487 .await
488 .map(ProjectId)?)
489 }
490
491 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
492 sqlx::query(
493 "
494 UPDATE projects
495 SET unregistered = 't'
496 WHERE id = $1
497 ",
498 )
499 .bind(project_id)
500 .execute(&self.pool)
501 .await?;
502 Ok(())
503 }
504
505 async fn update_worktree_extensions(
506 &self,
507 project_id: ProjectId,
508 worktree_id: u64,
509 extensions: HashMap<String, usize>,
510 ) -> Result<()> {
511 if extensions.is_empty() {
512 return Ok(());
513 }
514
515 let mut query = QueryBuilder::new(
516 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
517 );
518 query.push_values(extensions, |mut query, (extension, count)| {
519 query
520 .push_bind(project_id)
521 .push_bind(worktree_id as i32)
522 .push_bind(extension)
523 .push_bind(count as i32);
524 });
525 query.push(
526 "
527 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
528 count = excluded.count
529 ",
530 );
531 query.build().execute(&self.pool).await?;
532
533 Ok(())
534 }
535
536 async fn get_project_extensions(
537 &self,
538 project_id: ProjectId,
539 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
540 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
541 struct WorktreeExtension {
542 worktree_id: i32,
543 extension: String,
544 count: i32,
545 }
546
547 let query = "
548 SELECT worktree_id, extension, count
549 FROM worktree_extensions
550 WHERE project_id = $1
551 ";
552 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
553 .bind(&project_id)
554 .fetch_all(&self.pool)
555 .await?;
556
557 let mut extension_counts = HashMap::default();
558 for count in counts {
559 extension_counts
560 .entry(count.worktree_id as u64)
561 .or_insert(HashMap::default())
562 .insert(count.extension, count.count as usize);
563 }
564 Ok(extension_counts)
565 }
566
567 async fn record_user_activity(
568 &self,
569 time_period: Range<OffsetDateTime>,
570 projects: &[(UserId, ProjectId)],
571 ) -> Result<()> {
572 let query = "
573 INSERT INTO project_activity_periods
574 (ended_at, duration_millis, user_id, project_id)
575 VALUES
576 ($1, $2, $3, $4);
577 ";
578
579 let mut tx = self.pool.begin().await?;
580 let duration_millis =
581 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
582 for (user_id, project_id) in projects {
583 sqlx::query(query)
584 .bind(time_period.end)
585 .bind(duration_millis)
586 .bind(user_id)
587 .bind(project_id)
588 .execute(&mut tx)
589 .await?;
590 }
591 tx.commit().await?;
592
593 Ok(())
594 }
595
596 async fn get_top_users_activity_summary(
597 &self,
598 time_period: Range<OffsetDateTime>,
599 max_user_count: usize,
600 ) -> Result<Vec<UserActivitySummary>> {
601 let query = "
602 WITH
603 project_durations AS (
604 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
605 FROM project_activity_periods
606 WHERE $1 < ended_at AND ended_at <= $2
607 GROUP BY user_id, project_id
608 ),
609 user_durations AS (
610 SELECT user_id, SUM(project_duration) as total_duration
611 FROM project_durations
612 GROUP BY user_id
613 ORDER BY total_duration DESC
614 LIMIT $3
615 )
616 SELECT user_durations.user_id, users.github_login, project_id, project_duration
617 FROM user_durations, project_durations, users
618 WHERE
619 user_durations.user_id = project_durations.user_id AND
620 user_durations.user_id = users.id
621 ORDER BY user_id ASC, project_duration DESC
622 ";
623
624 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64)>(query)
625 .bind(time_period.start)
626 .bind(time_period.end)
627 .bind(max_user_count as i32)
628 .fetch(&self.pool);
629
630 let mut result = Vec::<UserActivitySummary>::new();
631 while let Some(row) = rows.next().await {
632 let (user_id, github_login, project_id, duration_millis) = row?;
633 let project_id = project_id;
634 let duration = Duration::from_millis(duration_millis as u64);
635 if let Some(last_summary) = result.last_mut() {
636 if last_summary.id == user_id {
637 last_summary.project_activity.push((project_id, duration));
638 continue;
639 }
640 }
641 result.push(UserActivitySummary {
642 id: user_id,
643 project_activity: vec![(project_id, duration)],
644 github_login,
645 });
646 }
647
648 Ok(result)
649 }
650
651 async fn get_user_activity_timeline(
652 &self,
653 time_period: Range<OffsetDateTime>,
654 user_id: UserId,
655 ) -> Result<Vec<UserActivityPeriod>> {
656 const COALESCE_THRESHOLD: Duration = Duration::from_secs(5);
657
658 let query = "
659 SELECT
660 project_activity_periods.ended_at,
661 project_activity_periods.duration_millis,
662 project_activity_periods.project_id,
663 worktree_extensions.extension,
664 worktree_extensions.count
665 FROM project_activity_periods
666 LEFT OUTER JOIN
667 worktree_extensions
668 ON
669 project_activity_periods.project_id = worktree_extensions.project_id
670 WHERE
671 project_activity_periods.user_id = $1 AND
672 $2 < project_activity_periods.ended_at AND
673 project_activity_periods.ended_at <= $3
674 ORDER BY project_activity_periods.id ASC
675 ";
676
677 let mut rows = sqlx::query_as::<
678 _,
679 (
680 PrimitiveDateTime,
681 i32,
682 ProjectId,
683 Option<String>,
684 Option<i32>,
685 ),
686 >(query)
687 .bind(user_id)
688 .bind(time_period.start)
689 .bind(time_period.end)
690 .fetch(&self.pool);
691
692 let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
693 while let Some(row) = rows.next().await {
694 let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
695 let ended_at = ended_at.assume_utc();
696 let duration = Duration::from_millis(duration_millis as u64);
697 let started_at = ended_at - duration;
698 let project_time_periods = time_periods.entry(project_id).or_default();
699
700 if let Some(prev_duration) = project_time_periods.last_mut() {
701 if started_at - prev_duration.end <= COALESCE_THRESHOLD {
702 prev_duration.end = ended_at;
703 } else {
704 project_time_periods.push(UserActivityPeriod {
705 project_id,
706 start: started_at,
707 end: ended_at,
708 extensions: Default::default(),
709 });
710 }
711 } else {
712 project_time_periods.push(UserActivityPeriod {
713 project_id,
714 start: started_at,
715 end: ended_at,
716 extensions: Default::default(),
717 });
718 }
719
720 if let Some((extension, extension_count)) = extension.zip(extension_count) {
721 project_time_periods
722 .last_mut()
723 .unwrap()
724 .extensions
725 .insert(extension, extension_count as usize);
726 }
727 }
728
729 let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
730 durations.sort_unstable_by_key(|duration| duration.start);
731 Ok(durations)
732 }
733
734 // contacts
735
736 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
737 let query = "
738 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
739 FROM contacts
740 WHERE user_id_a = $1 OR user_id_b = $1;
741 ";
742
743 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
744 .bind(user_id)
745 .fetch(&self.pool);
746
747 let mut contacts = vec![Contact::Accepted {
748 user_id,
749 should_notify: false,
750 }];
751 while let Some(row) = rows.next().await {
752 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
753
754 if user_id_a == user_id {
755 if accepted {
756 contacts.push(Contact::Accepted {
757 user_id: user_id_b,
758 should_notify: should_notify && a_to_b,
759 });
760 } else if a_to_b {
761 contacts.push(Contact::Outgoing { user_id: user_id_b })
762 } else {
763 contacts.push(Contact::Incoming {
764 user_id: user_id_b,
765 should_notify,
766 });
767 }
768 } else {
769 if accepted {
770 contacts.push(Contact::Accepted {
771 user_id: user_id_a,
772 should_notify: should_notify && !a_to_b,
773 });
774 } else if a_to_b {
775 contacts.push(Contact::Incoming {
776 user_id: user_id_a,
777 should_notify,
778 });
779 } else {
780 contacts.push(Contact::Outgoing { user_id: user_id_a });
781 }
782 }
783 }
784
785 contacts.sort_unstable_by_key(|contact| contact.user_id());
786
787 Ok(contacts)
788 }
789
790 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
791 let (id_a, id_b) = if user_id_1 < user_id_2 {
792 (user_id_1, user_id_2)
793 } else {
794 (user_id_2, user_id_1)
795 };
796
797 let query = "
798 SELECT 1 FROM contacts
799 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
800 LIMIT 1
801 ";
802 Ok(sqlx::query_scalar::<_, i32>(query)
803 .bind(id_a.0)
804 .bind(id_b.0)
805 .fetch_optional(&self.pool)
806 .await?
807 .is_some())
808 }
809
810 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
811 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
812 (sender_id, receiver_id, true)
813 } else {
814 (receiver_id, sender_id, false)
815 };
816 let query = "
817 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
818 VALUES ($1, $2, $3, 'f', 't')
819 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
820 SET
821 accepted = 't',
822 should_notify = 'f'
823 WHERE
824 NOT contacts.accepted AND
825 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
826 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
827 ";
828 let result = sqlx::query(query)
829 .bind(id_a.0)
830 .bind(id_b.0)
831 .bind(a_to_b)
832 .execute(&self.pool)
833 .await?;
834
835 if result.rows_affected() == 1 {
836 Ok(())
837 } else {
838 Err(anyhow!("contact already requested"))?
839 }
840 }
841
842 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
843 let (id_a, id_b) = if responder_id < requester_id {
844 (responder_id, requester_id)
845 } else {
846 (requester_id, responder_id)
847 };
848 let query = "
849 DELETE FROM contacts
850 WHERE user_id_a = $1 AND user_id_b = $2;
851 ";
852 let result = sqlx::query(query)
853 .bind(id_a.0)
854 .bind(id_b.0)
855 .execute(&self.pool)
856 .await?;
857
858 if result.rows_affected() == 1 {
859 Ok(())
860 } else {
861 Err(anyhow!("no such contact"))?
862 }
863 }
864
865 async fn dismiss_contact_notification(
866 &self,
867 user_id: UserId,
868 contact_user_id: UserId,
869 ) -> Result<()> {
870 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
871 (user_id, contact_user_id, true)
872 } else {
873 (contact_user_id, user_id, false)
874 };
875
876 let query = "
877 UPDATE contacts
878 SET should_notify = 'f'
879 WHERE
880 user_id_a = $1 AND user_id_b = $2 AND
881 (
882 (a_to_b = $3 AND accepted) OR
883 (a_to_b != $3 AND NOT accepted)
884 );
885 ";
886
887 let result = sqlx::query(query)
888 .bind(id_a.0)
889 .bind(id_b.0)
890 .bind(a_to_b)
891 .execute(&self.pool)
892 .await?;
893
894 if result.rows_affected() == 0 {
895 Err(anyhow!("no such contact request"))?;
896 }
897
898 Ok(())
899 }
900
901 async fn respond_to_contact_request(
902 &self,
903 responder_id: UserId,
904 requester_id: UserId,
905 accept: bool,
906 ) -> Result<()> {
907 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
908 (responder_id, requester_id, false)
909 } else {
910 (requester_id, responder_id, true)
911 };
912 let result = if accept {
913 let query = "
914 UPDATE contacts
915 SET accepted = 't', should_notify = 't'
916 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
917 ";
918 sqlx::query(query)
919 .bind(id_a.0)
920 .bind(id_b.0)
921 .bind(a_to_b)
922 .execute(&self.pool)
923 .await?
924 } else {
925 let query = "
926 DELETE FROM contacts
927 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
928 ";
929 sqlx::query(query)
930 .bind(id_a.0)
931 .bind(id_b.0)
932 .bind(a_to_b)
933 .execute(&self.pool)
934 .await?
935 };
936 if result.rows_affected() == 1 {
937 Ok(())
938 } else {
939 Err(anyhow!("no such contact request"))?
940 }
941 }
942
943 // access tokens
944
945 async fn create_access_token_hash(
946 &self,
947 user_id: UserId,
948 access_token_hash: &str,
949 max_access_token_count: usize,
950 ) -> Result<()> {
951 let insert_query = "
952 INSERT INTO access_tokens (user_id, hash)
953 VALUES ($1, $2);
954 ";
955 let cleanup_query = "
956 DELETE FROM access_tokens
957 WHERE id IN (
958 SELECT id from access_tokens
959 WHERE user_id = $1
960 ORDER BY id DESC
961 OFFSET $3
962 )
963 ";
964
965 let mut tx = self.pool.begin().await?;
966 sqlx::query(insert_query)
967 .bind(user_id.0)
968 .bind(access_token_hash)
969 .execute(&mut tx)
970 .await?;
971 sqlx::query(cleanup_query)
972 .bind(user_id.0)
973 .bind(access_token_hash)
974 .bind(max_access_token_count as i32)
975 .execute(&mut tx)
976 .await?;
977 Ok(tx.commit().await?)
978 }
979
980 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
981 let query = "
982 SELECT hash
983 FROM access_tokens
984 WHERE user_id = $1
985 ORDER BY id DESC
986 ";
987 Ok(sqlx::query_scalar(query)
988 .bind(user_id.0)
989 .fetch_all(&self.pool)
990 .await?)
991 }
992
993 // orgs
994
995 #[allow(unused)] // Help rust-analyzer
996 #[cfg(any(test, feature = "seed-support"))]
997 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
998 let query = "
999 SELECT *
1000 FROM orgs
1001 WHERE slug = $1
1002 ";
1003 Ok(sqlx::query_as(query)
1004 .bind(slug)
1005 .fetch_optional(&self.pool)
1006 .await?)
1007 }
1008
1009 #[cfg(any(test, feature = "seed-support"))]
1010 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1011 let query = "
1012 INSERT INTO orgs (name, slug)
1013 VALUES ($1, $2)
1014 RETURNING id
1015 ";
1016 Ok(sqlx::query_scalar(query)
1017 .bind(name)
1018 .bind(slug)
1019 .fetch_one(&self.pool)
1020 .await
1021 .map(OrgId)?)
1022 }
1023
1024 #[cfg(any(test, feature = "seed-support"))]
1025 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1026 let query = "
1027 INSERT INTO org_memberships (org_id, user_id, admin)
1028 VALUES ($1, $2, $3)
1029 ON CONFLICT DO NOTHING
1030 ";
1031 Ok(sqlx::query(query)
1032 .bind(org_id.0)
1033 .bind(user_id.0)
1034 .bind(is_admin)
1035 .execute(&self.pool)
1036 .await
1037 .map(drop)?)
1038 }
1039
1040 // channels
1041
1042 #[cfg(any(test, feature = "seed-support"))]
1043 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1044 let query = "
1045 INSERT INTO channels (owner_id, owner_is_user, name)
1046 VALUES ($1, false, $2)
1047 RETURNING id
1048 ";
1049 Ok(sqlx::query_scalar(query)
1050 .bind(org_id.0)
1051 .bind(name)
1052 .fetch_one(&self.pool)
1053 .await
1054 .map(ChannelId)?)
1055 }
1056
1057 #[allow(unused)] // Help rust-analyzer
1058 #[cfg(any(test, feature = "seed-support"))]
1059 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1060 let query = "
1061 SELECT *
1062 FROM channels
1063 WHERE
1064 channels.owner_is_user = false AND
1065 channels.owner_id = $1
1066 ";
1067 Ok(sqlx::query_as(query)
1068 .bind(org_id.0)
1069 .fetch_all(&self.pool)
1070 .await?)
1071 }
1072
1073 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1074 let query = "
1075 SELECT
1076 channels.*
1077 FROM
1078 channel_memberships, channels
1079 WHERE
1080 channel_memberships.user_id = $1 AND
1081 channel_memberships.channel_id = channels.id
1082 ";
1083 Ok(sqlx::query_as(query)
1084 .bind(user_id.0)
1085 .fetch_all(&self.pool)
1086 .await?)
1087 }
1088
1089 async fn can_user_access_channel(
1090 &self,
1091 user_id: UserId,
1092 channel_id: ChannelId,
1093 ) -> Result<bool> {
1094 let query = "
1095 SELECT id
1096 FROM channel_memberships
1097 WHERE user_id = $1 AND channel_id = $2
1098 LIMIT 1
1099 ";
1100 Ok(sqlx::query_scalar::<_, i32>(query)
1101 .bind(user_id.0)
1102 .bind(channel_id.0)
1103 .fetch_optional(&self.pool)
1104 .await
1105 .map(|e| e.is_some())?)
1106 }
1107
1108 #[cfg(any(test, feature = "seed-support"))]
1109 async fn add_channel_member(
1110 &self,
1111 channel_id: ChannelId,
1112 user_id: UserId,
1113 is_admin: bool,
1114 ) -> Result<()> {
1115 let query = "
1116 INSERT INTO channel_memberships (channel_id, user_id, admin)
1117 VALUES ($1, $2, $3)
1118 ON CONFLICT DO NOTHING
1119 ";
1120 Ok(sqlx::query(query)
1121 .bind(channel_id.0)
1122 .bind(user_id.0)
1123 .bind(is_admin)
1124 .execute(&self.pool)
1125 .await
1126 .map(drop)?)
1127 }
1128
1129 // messages
1130
1131 async fn create_channel_message(
1132 &self,
1133 channel_id: ChannelId,
1134 sender_id: UserId,
1135 body: &str,
1136 timestamp: OffsetDateTime,
1137 nonce: u128,
1138 ) -> Result<MessageId> {
1139 let query = "
1140 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1141 VALUES ($1, $2, $3, $4, $5)
1142 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1143 RETURNING id
1144 ";
1145 Ok(sqlx::query_scalar(query)
1146 .bind(channel_id.0)
1147 .bind(sender_id.0)
1148 .bind(body)
1149 .bind(timestamp)
1150 .bind(Uuid::from_u128(nonce))
1151 .fetch_one(&self.pool)
1152 .await
1153 .map(MessageId)?)
1154 }
1155
1156 async fn get_channel_messages(
1157 &self,
1158 channel_id: ChannelId,
1159 count: usize,
1160 before_id: Option<MessageId>,
1161 ) -> Result<Vec<ChannelMessage>> {
1162 let query = r#"
1163 SELECT * FROM (
1164 SELECT
1165 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1166 FROM
1167 channel_messages
1168 WHERE
1169 channel_id = $1 AND
1170 id < $2
1171 ORDER BY id DESC
1172 LIMIT $3
1173 ) as recent_messages
1174 ORDER BY id ASC
1175 "#;
1176 Ok(sqlx::query_as(query)
1177 .bind(channel_id.0)
1178 .bind(before_id.unwrap_or(MessageId::MAX))
1179 .bind(count as i64)
1180 .fetch_all(&self.pool)
1181 .await?)
1182 }
1183
1184 #[cfg(test)]
1185 async fn teardown(&self, url: &str) {
1186 use util::ResultExt;
1187
1188 let query = "
1189 SELECT pg_terminate_backend(pg_stat_activity.pid)
1190 FROM pg_stat_activity
1191 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1192 ";
1193 sqlx::query(query).execute(&self.pool).await.log_err();
1194 self.pool.close().await;
1195 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1196 .await
1197 .log_err();
1198 }
1199
1200 #[cfg(test)]
1201 fn as_fake(&self) -> Option<&tests::FakeDb> {
1202 None
1203 }
1204}
1205
/// Declares `$name` as a newtype over an `i32` database id.
///
/// The generated type is `Copy`, ordered and hashable. Both `sqlx` and `serde`
/// treat it exactly like the wrapped `i32` (`transparent`), and it converts
/// to and from the `u64` ids used on the wire.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash,
            sqlx::Type, Serialize, Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from its wire (`u64`) form. Values above `i32::MAX`
            /// are truncated by the `as` cast.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Returns the wire (`u64`) form of this id.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let raw: i32 = self.0;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            /// Formats the id as its bare integer value.
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1248
id_type!(UserId);
/// A user record, decodable from a query row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub login name of the user.
    pub github_login: String,
    /// Email address, if one was provided.
    pub email_address: Option<String>,
    /// Admin flag, as passed to `create_user` / `set_user_is_admin`.
    pub admin: bool,
    /// Invite code assigned to the user, if any. Users made with `create_user`
    /// start without one (see `test_invite_codes`).
    pub invite_code: Option<String>,
    /// Remaining redemptions of `invite_code`; `redeem_invite_code` decrements
    /// it (see `test_invite_codes`).
    pub invite_count: i32,
    /// Flag written by `set_user_connected_once`; presumably whether the user
    /// has ever connected — confirm.
    pub connected_once: bool,
}
1260
id_type!(ProjectId);
/// A project registered by a host user.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// User who registered (hosts) the project.
    pub host_user_id: UserId,
    /// Set to `true` by `unregister_project`; the fake backend keeps the record
    /// instead of deleting it.
    pub unregistered: bool,
}
1268
/// Per-user activity aggregate, as returned by `get_top_users_activity_summary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// Time spent in each project within the queried range
    /// (see `test_project_activity`).
    pub project_activity: Vec<(ProjectId, Duration)>,
}
1275
/// A span of a user's activity in a single project, as returned by
/// `get_user_activity_timeline`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    project_id: ProjectId,
    start: OffsetDateTime,
    end: OffsetDateTime,
    /// File-extension counts of the project's worktrees, as recorded with
    /// `update_worktree_extensions` (see `test_project_activity`).
    extensions: HashMap<String, usize>,
}
1283
id_type!(OrgId);
/// An organization record, decodable from a query row via `FromRow`.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
1291
id_type!(ChannelId);
/// A chat channel record, decodable from a query row via `FromRow`.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Raw owner id; its meaning depends on `owner_is_user` (presumably a user
    /// id when true and an org id otherwise — confirm).
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1300
id_type!(MessageId);
/// A message posted to a channel, decodable from a query row via `FromRow`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; `get_channel_messages` selects it `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce (stored via `Uuid::from_u128`); re-sending a
    /// message with the same nonce returns the existing message's id.
    pub nonce: Uuid,
}
1311
/// One entry in a user's contact list, as returned by `get_contacts`.
///
/// A user's own id also appears in their list as an `Accepted` entry
/// (see `test_add_contacts`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The users are contacts. `should_notify` is cleared by
    /// `dismiss_contact_notification`.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A request this user sent to `user_id` that has not been answered yet.
    Outgoing {
        user_id: UserId,
    },
    /// A request `user_id` sent to this user that has not been answered yet.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1326
1327impl Contact {
1328 pub fn user_id(&self) -> UserId {
1329 match self {
1330 Contact::Accepted { user_id, .. } => *user_id,
1331 Contact::Outgoing { user_id } => *user_id,
1332 Contact::Incoming { user_id, .. } => *user_id,
1333 }
1334 }
1335}
1336
/// A pending contact request addressed to some user.
///
/// The derived `Ord` compares `requester_id` first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1342
/// Builds a SQL `LIKE` pattern matching any text that contains the
/// alphanumeric characters of `string` in order, with anything in between.
///
/// Non-alphanumeric characters (including the `LIKE` wildcards `%` and `_`)
/// are dropped, e.g. `"x y"` becomes `"%x%y%"` and `""` becomes `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = string
        .chars()
        .filter(|ch| ch.is_alphanumeric())
        .fold(String::with_capacity(string.len() * 2 + 1), |mut acc, ch| {
            acc.push('%');
            acc.push(ch);
            acc
        });
    pattern.push('%');
    pattern
}
1354
1355#[cfg(test)]
1356pub mod tests {
1357 use super::*;
1358 use anyhow::anyhow;
1359 use collections::BTreeMap;
1360 use gpui::executor::Background;
1361 use lazy_static::lazy_static;
1362 use parking_lot::Mutex;
1363 use rand::prelude::*;
1364 use sqlx::{
1365 migrate::{MigrateDatabase, Migrator},
1366 Postgres,
1367 };
1368 use std::{path::Path, sync::Arc};
1369 use util::post_inc;
1370
1371 #[tokio::test(flavor = "multi_thread")]
1372 async fn test_get_users_by_ids() {
1373 for test_db in [
1374 TestDb::postgres().await,
1375 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1376 ] {
1377 let db = test_db.db();
1378
1379 let user = db.create_user("user", None, false).await.unwrap();
1380 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1381 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1382 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1383
1384 assert_eq!(
1385 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1386 .await
1387 .unwrap(),
1388 vec![
1389 User {
1390 id: user,
1391 github_login: "user".to_string(),
1392 admin: false,
1393 ..Default::default()
1394 },
1395 User {
1396 id: friend1,
1397 github_login: "friend-1".to_string(),
1398 admin: false,
1399 ..Default::default()
1400 },
1401 User {
1402 id: friend2,
1403 github_login: "friend-2".to_string(),
1404 admin: false,
1405 ..Default::default()
1406 },
1407 User {
1408 id: friend3,
1409 github_login: "friend-3".to_string(),
1410 admin: false,
1411 ..Default::default()
1412 }
1413 ]
1414 );
1415 }
1416 }
1417
1418 #[tokio::test(flavor = "multi_thread")]
1419 async fn test_create_users() {
1420 let db = TestDb::postgres().await;
1421 let db = db.db();
1422
1423 // Create the first batch of users, ensuring invite counts are assigned
1424 // correctly and the respective invite codes are unique.
1425 let user_ids_batch_1 = db
1426 .create_users(vec![
1427 ("user1".to_string(), "hi@user1.com".to_string(), 5),
1428 ("user2".to_string(), "hi@user2.com".to_string(), 4),
1429 ("user3".to_string(), "hi@user3.com".to_string(), 3),
1430 ])
1431 .await
1432 .unwrap();
1433 assert_eq!(user_ids_batch_1.len(), 3);
1434
1435 let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1436 assert_eq!(users.len(), 3);
1437 assert_eq!(users[0].github_login, "user1");
1438 assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1439 assert_eq!(users[0].invite_count, 5);
1440 assert_eq!(users[1].github_login, "user2");
1441 assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1442 assert_eq!(users[1].invite_count, 4);
1443 assert_eq!(users[2].github_login, "user3");
1444 assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1445 assert_eq!(users[2].invite_count, 3);
1446
1447 let invite_code_1 = users[0].invite_code.clone().unwrap();
1448 let invite_code_2 = users[1].invite_code.clone().unwrap();
1449 let invite_code_3 = users[2].invite_code.clone().unwrap();
1450 assert_ne!(invite_code_1, invite_code_2);
1451 assert_ne!(invite_code_1, invite_code_3);
1452 assert_ne!(invite_code_2, invite_code_3);
1453
1454 // Create the second batch of users and include a user that is already in the database, ensuring
1455 // the invite count for the existing user is updated without changing their invite code.
1456 let user_ids_batch_2 = db
1457 .create_users(vec![
1458 ("user2".to_string(), "hi@user2.com".to_string(), 10),
1459 ("user4".to_string(), "hi@user4.com".to_string(), 2),
1460 ])
1461 .await
1462 .unwrap();
1463 assert_eq!(user_ids_batch_2.len(), 2);
1464 assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1465
1466 let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1467 assert_eq!(users.len(), 2);
1468 assert_eq!(users[0].github_login, "user2");
1469 assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1470 assert_eq!(users[0].invite_count, 10);
1471 assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1472 assert_eq!(users[1].github_login, "user4");
1473 assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1474 assert_eq!(users[1].invite_count, 2);
1475
1476 let invite_code_4 = users[1].invite_code.clone().unwrap();
1477 assert_ne!(invite_code_4, invite_code_1);
1478 assert_ne!(invite_code_4, invite_code_2);
1479 assert_ne!(invite_code_4, invite_code_3);
1480 }
1481
    /// Checks that a later `update_worktree_extensions` call for the same
    /// (project, worktree) supersedes earlier counts, and that
    /// `get_project_extensions` groups the latest counts by worktree id.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_worktree_extensions() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();

        let user = db.create_user("user_1", None, false).await.unwrap();
        let project = db.register_project(user).await.unwrap();

        // An empty update is accepted.
        db.update_worktree_extensions(project, 100, Default::default())
            .await
            .unwrap();
        db.update_worktree_extensions(
            project,
            100,
            [("rs".to_string(), 5), ("md".to_string(), 3)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();
        // Same worktree again: these counts are the ones expected below.
        db.update_worktree_extensions(
            project,
            100,
            [("rs".to_string(), 6), ("md".to_string(), 5)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();
        // A second worktree of the same project is reported separately.
        db.update_worktree_extensions(
            project,
            101,
            [("ts".to_string(), 2), ("md".to_string(), 1)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();

        assert_eq!(
            db.get_project_extensions(project).await.unwrap(),
            [
                (
                    100,
                    [("rs".into(), 6), ("md".into(), 5),]
                        .into_iter()
                        .collect::<HashMap<_, _>>()
                ),
                (
                    101,
                    [("ts".into(), 2), ("md".into(), 1),]
                        .into_iter()
                        .collect::<HashMap<_, _>>()
                )
            ]
            .into_iter()
            .collect()
        );
    }
1541
    /// Records overlapping activity for three users across two projects and
    /// checks both `get_top_users_activity_summary` and
    /// `get_user_activity_timeline`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_project_activity() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();

        let user_1 = db.create_user("user_1", None, false).await.unwrap();
        let user_2 = db.create_user("user_2", None, false).await.unwrap();
        let user_3 = db.create_user("user_3", None, false).await.unwrap();
        let project_1 = db.register_project(user_1).await.unwrap();
        db.update_worktree_extensions(
            project_1,
            1,
            HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
        )
        .await
        .unwrap();
        let project_2 = db.register_project(user_2).await.unwrap();
        // Anchor the timeline one hour in the past.
        let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);

        // User 2 opens a project
        let t1 = t0 + Duration::from_secs(10);
        db.record_user_activity(t0..t1, &[(user_2, project_2)])
            .await
            .unwrap();

        let t2 = t1 + Duration::from_secs(10);
        db.record_user_activity(t1..t2, &[(user_2, project_2)])
            .await
            .unwrap();

        // User 1 joins the project
        let t3 = t2 + Duration::from_secs(10);
        db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
            .await
            .unwrap();

        // User 1 opens another project
        let t4 = t3 + Duration::from_secs(10);
        db.record_user_activity(
            t3..t4,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
            ],
        )
        .await
        .unwrap();

        // User 3 joins that project
        let t5 = t4 + Duration::from_secs(10);
        db.record_user_activity(
            t4..t5,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
                (user_3, project_1),
            ],
        )
        .await
        .unwrap();

        // User 2 leaves
        let t6 = t5 + Duration::from_secs(5);
        db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
            .await
            .unwrap();

        // After a 60s gap, user 1 is active in project 1 again.
        let t7 = t6 + Duration::from_secs(60);
        let t8 = t7 + Duration::from_secs(10);
        db.record_user_activity(t7..t8, &[(user_1, project_1)])
            .await
            .unwrap();

        // Within t0..t6: user_1 spent t2..t5 (30s) in project_2 and t3..t6 (25s)
        // in project_1; user_2 spent t0..t5 (50s) in project_2; user_3 spent
        // t4..t6 (15s) in project_1.
        assert_eq!(
            db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
            &[
                UserActivitySummary {
                    id: user_1,
                    github_login: "user_1".to_string(),
                    project_activity: vec![
                        (project_2, Duration::from_secs(30)),
                        (project_1, Duration::from_secs(25))
                    ]
                },
                UserActivitySummary {
                    id: user_2,
                    github_login: "user_2".to_string(),
                    project_activity: vec![(project_2, Duration::from_secs(50))]
                },
                UserActivitySummary {
                    id: user_3,
                    github_login: "user_3".to_string(),
                    project_activity: vec![(project_1, Duration::from_secs(15))]
                },
            ]
        );
        // Periods are clipped to the queried range: project_2 starts at t3 here
        // even though user_1 joined it at t2.
        assert_eq!(
            db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
            &[
                UserActivityPeriod {
                    project_id: project_1,
                    start: t3,
                    end: t6,
                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
                },
                UserActivityPeriod {
                    project_id: project_2,
                    start: t3,
                    end: t5,
                    extensions: Default::default(),
                },
            ]
        );
        // The gap between t6 and t7 splits project_1 activity into two periods.
        assert_eq!(
            db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
            &[
                UserActivityPeriod {
                    project_id: project_2,
                    start: t2,
                    end: t5,
                    extensions: Default::default(),
                },
                UserActivityPeriod {
                    project_id: project_1,
                    start: t3,
                    end: t6,
                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
                },
                UserActivityPeriod {
                    project_id: project_1,
                    start: t7,
                    end: t8,
                    extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
                },
            ]
        );
    }
1681
1682 #[tokio::test(flavor = "multi_thread")]
1683 async fn test_recent_channel_messages() {
1684 for test_db in [
1685 TestDb::postgres().await,
1686 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1687 ] {
1688 let db = test_db.db();
1689 let user = db.create_user("user", None, false).await.unwrap();
1690 let org = db.create_org("org", "org").await.unwrap();
1691 let channel = db.create_org_channel(org, "channel").await.unwrap();
1692 for i in 0..10 {
1693 db.create_channel_message(
1694 channel,
1695 user,
1696 &i.to_string(),
1697 OffsetDateTime::now_utc(),
1698 i,
1699 )
1700 .await
1701 .unwrap();
1702 }
1703
1704 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1705 assert_eq!(
1706 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1707 ["5", "6", "7", "8", "9"]
1708 );
1709
1710 let prev_messages = db
1711 .get_channel_messages(channel, 4, Some(messages[0].id))
1712 .await
1713 .unwrap();
1714 assert_eq!(
1715 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1716 ["1", "2", "3", "4"]
1717 );
1718 }
1719 }
1720
1721 #[tokio::test(flavor = "multi_thread")]
1722 async fn test_channel_message_nonces() {
1723 for test_db in [
1724 TestDb::postgres().await,
1725 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1726 ] {
1727 let db = test_db.db();
1728 let user = db.create_user("user", None, false).await.unwrap();
1729 let org = db.create_org("org", "org").await.unwrap();
1730 let channel = db.create_org_channel(org, "channel").await.unwrap();
1731
1732 let msg1_id = db
1733 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1734 .await
1735 .unwrap();
1736 let msg2_id = db
1737 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1738 .await
1739 .unwrap();
1740 let msg3_id = db
1741 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1742 .await
1743 .unwrap();
1744 let msg4_id = db
1745 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1746 .await
1747 .unwrap();
1748
1749 assert_ne!(msg1_id, msg2_id);
1750 assert_eq!(msg1_id, msg3_id);
1751 assert_eq!(msg2_id, msg4_id);
1752 }
1753 }
1754
1755 #[tokio::test(flavor = "multi_thread")]
1756 async fn test_create_access_tokens() {
1757 let test_db = TestDb::postgres().await;
1758 let db = test_db.db();
1759 let user = db.create_user("the-user", None, false).await.unwrap();
1760
1761 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1762 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1763 assert_eq!(
1764 db.get_access_token_hashes(user).await.unwrap(),
1765 &["h2".to_string(), "h1".to_string()]
1766 );
1767
1768 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1769 assert_eq!(
1770 db.get_access_token_hashes(user).await.unwrap(),
1771 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1772 );
1773
1774 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1775 assert_eq!(
1776 db.get_access_token_hashes(user).await.unwrap(),
1777 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1778 );
1779
1780 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1781 assert_eq!(
1782 db.get_access_token_hashes(user).await.unwrap(),
1783 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1784 );
1785 }
1786
1787 #[test]
1788 fn test_fuzzy_like_string() {
1789 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1790 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1791 assert_eq!(fuzzy_like_string(" z "), "%z%");
1792 }
1793
1794 #[tokio::test(flavor = "multi_thread")]
1795 async fn test_fuzzy_search_users() {
1796 let test_db = TestDb::postgres().await;
1797 let db = test_db.db();
1798 for github_login in [
1799 "California",
1800 "colorado",
1801 "oregon",
1802 "washington",
1803 "florida",
1804 "delaware",
1805 "rhode-island",
1806 ] {
1807 db.create_user(github_login, None, false).await.unwrap();
1808 }
1809
1810 assert_eq!(
1811 fuzzy_search_user_names(db, "clr").await,
1812 &["colorado", "California"]
1813 );
1814 assert_eq!(
1815 fuzzy_search_user_names(db, "ro").await,
1816 &["rhode-island", "colorado", "oregon"],
1817 );
1818
1819 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1820 db.fuzzy_search_users(query, 10)
1821 .await
1822 .unwrap()
1823 .into_iter()
1824 .map(|user| user.github_login)
1825 .collect::<Vec<_>>()
1826 }
1827 }
1828
    /// Walks three users through the contact lifecycle (request, dismiss,
    /// accept, simultaneous requests, decline) against both backends.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // User starts with no contacts other than an accepted entry for themselves.
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
2048
    /// Covers invite-code assignment, redemption (which makes the redeemer and
    /// the code owner mutual contacts), exhaustion, and re-arming via
    /// `set_invite_count`.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        // The failed redemption must not consume a use of the code.
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
2201
    /// A database handle owned by a single test; torn down when dropped.
    pub struct TestDb {
        // `Option` so `Drop` can take the handle out before tearing it down.
        pub db: Option<Arc<dyn Db>>,
        // URL of the per-test Postgres database; empty for the fake backend.
        pub url: String,
    }
2206
2207 impl TestDb {
2208 pub async fn postgres() -> Self {
2209 lazy_static! {
2210 static ref LOCK: Mutex<()> = Mutex::new(());
2211 }
2212
2213 let _guard = LOCK.lock();
2214 let mut rng = StdRng::from_entropy();
2215 let name = format!("zed-test-{}", rng.gen::<u128>());
2216 let url = format!("postgres://postgres@localhost/{}", name);
2217 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2218 Postgres::create_database(&url)
2219 .await
2220 .expect("failed to create test db");
2221 let db = PostgresDb::new(&url, 5).await.unwrap();
2222 let migrator = Migrator::new(migrations_path).await.unwrap();
2223 migrator.run(&db.pool).await.unwrap();
2224 Self {
2225 db: Some(Arc::new(db)),
2226 url,
2227 }
2228 }
2229
2230 pub fn fake(background: Arc<Background>) -> Self {
2231 Self {
2232 db: Some(Arc::new(FakeDb::new(background))),
2233 url: Default::default(),
2234 }
2235 }
2236
2237 pub fn db(&self) -> &Arc<dyn Db> {
2238 self.db.as_ref().unwrap()
2239 }
2240 }
2241
2242 impl Drop for TestDb {
2243 fn drop(&mut self) {
2244 if let Some(db) = self.db.take() {
2245 futures::executor::block_on(db.teardown(&self.url));
2246 }
2247 }
2248 }
2249
    /// In-memory implementation of `Db` for tests that don't need Postgres.
    ///
    /// Each table is a `Mutex`-guarded collection; ids are handed out from the
    /// `next_*` counters, which `FakeDb::new` starts at 1.
    pub struct FakeDb {
        // Used to inject `simulate_random_delay` calls into async methods.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        // (project, worktree id, file extension) -> count, as written by
        // `update_worktree_extensions`.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), usize>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // NOTE(review): meaning of the `bool` value isn't visible here
        // (presumably an admin flag) — confirm.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        // NOTE(review): meaning of the `bool` value isn't visible here
        // (presumably an admin flag) — confirm.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        pub contacts: Mutex<Vec<FakeContact>>,
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2267
    /// A contact relationship between two users stored by `FakeDb`.
    #[derive(Debug)]
    pub struct FakeContact {
        // User who sent the contact request.
        pub requester_id: UserId,
        // User the request was sent to.
        pub responder_id: UserId,
        // Presumably set once the responder accepts — confirm against the
        // fake's contact methods.
        pub accepted: bool,
        // Presumably backs `Contact::should_notify` — confirm.
        pub should_notify: bool,
    }
2275
2276 impl FakeDb {
2277 pub fn new(background: Arc<Background>) -> Self {
2278 Self {
2279 background,
2280 users: Default::default(),
2281 next_user_id: Mutex::new(1),
2282 projects: Default::default(),
2283 worktree_extensions: Default::default(),
2284 next_project_id: Mutex::new(1),
2285 orgs: Default::default(),
2286 next_org_id: Mutex::new(1),
2287 org_memberships: Default::default(),
2288 channels: Default::default(),
2289 next_channel_id: Mutex::new(1),
2290 channel_memberships: Default::default(),
2291 channel_messages: Default::default(),
2292 next_channel_message_id: Mutex::new(1),
2293 contacts: Default::default(),
2294 }
2295 }
2296 }
2297
2298 #[async_trait]
2299 impl Db for FakeDb {
2300 async fn create_user(
2301 &self,
2302 github_login: &str,
2303 email_address: Option<&str>,
2304 admin: bool,
2305 ) -> Result<UserId> {
2306 self.background.simulate_random_delay().await;
2307
2308 let mut users = self.users.lock();
2309 if let Some(user) = users
2310 .values()
2311 .find(|user| user.github_login == github_login)
2312 {
2313 Ok(user.id)
2314 } else {
2315 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2316 users.insert(
2317 user_id,
2318 User {
2319 id: user_id,
2320 github_login: github_login.to_string(),
2321 email_address: email_address.map(str::to_string),
2322 admin,
2323 invite_code: None,
2324 invite_count: 0,
2325 connected_once: false,
2326 },
2327 );
2328 Ok(user_id)
2329 }
2330 }
2331
        // Not supported by the in-memory fake; calling it panics.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2335
        // Not supported by the in-memory fake; calling it panics.
        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
            unimplemented!()
        }
2339
        // Not supported by the in-memory fake; calling it panics.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2343
2344 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2345 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2346 }
2347
2348 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2349 self.background.simulate_random_delay().await;
2350 let users = self.users.lock();
2351 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2352 }
2353
        // Not supported by the in-memory fake; calling it panics.
        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }
2357
2358 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2359 Ok(self
2360 .users
2361 .lock()
2362 .values()
2363 .find(|user| user.github_login == github_login)
2364 .cloned())
2365 }
2366
        // Not supported by the in-memory fake; calling it panics.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
2370
2371 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2372 self.background.simulate_random_delay().await;
2373 let mut users = self.users.lock();
2374 let mut user = users
2375 .get_mut(&id)
2376 .ok_or_else(|| anyhow!("user not found"))?;
2377 user.connected_once = connected_once;
2378 Ok(())
2379 }
2380
2381 async fn destroy_user(&self, _id: UserId) -> Result<()> {
2382 unimplemented!()
2383 }
2384
2385 // invite codes
2386
2387 async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
2388 unimplemented!()
2389 }
2390
2391 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
2392 Ok(None)
2393 }
2394
2395 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
2396 unimplemented!()
2397 }
2398
2399 async fn redeem_invite_code(
2400 &self,
2401 _code: &str,
2402 _login: &str,
2403 _email_address: Option<&str>,
2404 ) -> Result<UserId> {
2405 unimplemented!()
2406 }
2407
2408 // projects
2409
2410 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2411 self.background.simulate_random_delay().await;
2412 if !self.users.lock().contains_key(&host_user_id) {
2413 Err(anyhow!("no such user"))?;
2414 }
2415
2416 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2417 self.projects.lock().insert(
2418 project_id,
2419 Project {
2420 id: project_id,
2421 host_user_id,
2422 unregistered: false,
2423 },
2424 );
2425 Ok(project_id)
2426 }
2427
2428 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2429 self.projects
2430 .lock()
2431 .get_mut(&project_id)
2432 .ok_or_else(|| anyhow!("no such project"))?
2433 .unregistered = true;
2434 Ok(())
2435 }
2436
2437 async fn update_worktree_extensions(
2438 &self,
2439 project_id: ProjectId,
2440 worktree_id: u64,
2441 extensions: HashMap<String, usize>,
2442 ) -> Result<()> {
2443 self.background.simulate_random_delay().await;
2444 if !self.projects.lock().contains_key(&project_id) {
2445 Err(anyhow!("no such project"))?;
2446 }
2447
2448 for (extension, count) in extensions {
2449 self.worktree_extensions
2450 .lock()
2451 .insert((project_id, worktree_id, extension), count);
2452 }
2453
2454 Ok(())
2455 }
2456
2457 async fn get_project_extensions(
2458 &self,
2459 _project_id: ProjectId,
2460 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
2461 unimplemented!()
2462 }
2463
2464 async fn record_user_activity(
2465 &self,
2466 _time_period: Range<OffsetDateTime>,
2467 _active_projects: &[(UserId, ProjectId)],
2468 ) -> Result<()> {
2469 unimplemented!()
2470 }
2471
2472 async fn get_top_users_activity_summary(
2473 &self,
2474 _time_period: Range<OffsetDateTime>,
2475 _limit: usize,
2476 ) -> Result<Vec<UserActivitySummary>> {
2477 unimplemented!()
2478 }
2479
2480 async fn get_user_activity_timeline(
2481 &self,
2482 _time_period: Range<OffsetDateTime>,
2483 _user_id: UserId,
2484 ) -> Result<Vec<UserActivityPeriod>> {
2485 unimplemented!()
2486 }
2487
2488 // contacts
2489
2490 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2491 self.background.simulate_random_delay().await;
2492 let mut contacts = vec![Contact::Accepted {
2493 user_id: id,
2494 should_notify: false,
2495 }];
2496
2497 for contact in self.contacts.lock().iter() {
2498 if contact.requester_id == id {
2499 if contact.accepted {
2500 contacts.push(Contact::Accepted {
2501 user_id: contact.responder_id,
2502 should_notify: contact.should_notify,
2503 });
2504 } else {
2505 contacts.push(Contact::Outgoing {
2506 user_id: contact.responder_id,
2507 });
2508 }
2509 } else if contact.responder_id == id {
2510 if contact.accepted {
2511 contacts.push(Contact::Accepted {
2512 user_id: contact.requester_id,
2513 should_notify: false,
2514 });
2515 } else {
2516 contacts.push(Contact::Incoming {
2517 user_id: contact.requester_id,
2518 should_notify: contact.should_notify,
2519 });
2520 }
2521 }
2522 }
2523
2524 contacts.sort_unstable_by_key(|contact| contact.user_id());
2525 Ok(contacts)
2526 }
2527
2528 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2529 self.background.simulate_random_delay().await;
2530 Ok(self.contacts.lock().iter().any(|contact| {
2531 contact.accepted
2532 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2533 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2534 }))
2535 }
2536
2537 async fn send_contact_request(
2538 &self,
2539 requester_id: UserId,
2540 responder_id: UserId,
2541 ) -> Result<()> {
2542 let mut contacts = self.contacts.lock();
2543 for contact in contacts.iter_mut() {
2544 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2545 if contact.accepted {
2546 Err(anyhow!("contact already exists"))?;
2547 } else {
2548 Err(anyhow!("contact already requested"))?;
2549 }
2550 }
2551 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2552 if contact.accepted {
2553 Err(anyhow!("contact already exists"))?;
2554 } else {
2555 contact.accepted = true;
2556 contact.should_notify = false;
2557 return Ok(());
2558 }
2559 }
2560 }
2561 contacts.push(FakeContact {
2562 requester_id,
2563 responder_id,
2564 accepted: false,
2565 should_notify: true,
2566 });
2567 Ok(())
2568 }
2569
2570 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2571 self.contacts.lock().retain(|contact| {
2572 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2573 });
2574 Ok(())
2575 }
2576
2577 async fn dismiss_contact_notification(
2578 &self,
2579 user_id: UserId,
2580 contact_user_id: UserId,
2581 ) -> Result<()> {
2582 let mut contacts = self.contacts.lock();
2583 for contact in contacts.iter_mut() {
2584 if contact.requester_id == contact_user_id
2585 && contact.responder_id == user_id
2586 && !contact.accepted
2587 {
2588 contact.should_notify = false;
2589 return Ok(());
2590 }
2591 if contact.requester_id == user_id
2592 && contact.responder_id == contact_user_id
2593 && contact.accepted
2594 {
2595 contact.should_notify = false;
2596 return Ok(());
2597 }
2598 }
2599 Err(anyhow!("no such notification"))?
2600 }
2601
2602 async fn respond_to_contact_request(
2603 &self,
2604 responder_id: UserId,
2605 requester_id: UserId,
2606 accept: bool,
2607 ) -> Result<()> {
2608 let mut contacts = self.contacts.lock();
2609 for (ix, contact) in contacts.iter_mut().enumerate() {
2610 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2611 if contact.accepted {
2612 Err(anyhow!("contact already confirmed"))?;
2613 }
2614 if accept {
2615 contact.accepted = true;
2616 contact.should_notify = true;
2617 } else {
2618 contacts.remove(ix);
2619 }
2620 return Ok(());
2621 }
2622 }
2623 Err(anyhow!("no such contact request"))?
2624 }
2625
2626 async fn create_access_token_hash(
2627 &self,
2628 _user_id: UserId,
2629 _access_token_hash: &str,
2630 _max_access_token_count: usize,
2631 ) -> Result<()> {
2632 unimplemented!()
2633 }
2634
2635 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
2636 unimplemented!()
2637 }
2638
2639 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2640 unimplemented!()
2641 }
2642
2643 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2644 self.background.simulate_random_delay().await;
2645 let mut orgs = self.orgs.lock();
2646 if orgs.values().any(|org| org.slug == slug) {
2647 Err(anyhow!("org already exists"))?
2648 } else {
2649 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2650 orgs.insert(
2651 org_id,
2652 Org {
2653 id: org_id,
2654 name: name.to_string(),
2655 slug: slug.to_string(),
2656 },
2657 );
2658 Ok(org_id)
2659 }
2660 }
2661
2662 async fn add_org_member(
2663 &self,
2664 org_id: OrgId,
2665 user_id: UserId,
2666 is_admin: bool,
2667 ) -> Result<()> {
2668 self.background.simulate_random_delay().await;
2669 if !self.orgs.lock().contains_key(&org_id) {
2670 Err(anyhow!("org does not exist"))?;
2671 }
2672 if !self.users.lock().contains_key(&user_id) {
2673 Err(anyhow!("user does not exist"))?;
2674 }
2675
2676 self.org_memberships
2677 .lock()
2678 .entry((org_id, user_id))
2679 .or_insert(is_admin);
2680 Ok(())
2681 }
2682
2683 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2684 self.background.simulate_random_delay().await;
2685 if !self.orgs.lock().contains_key(&org_id) {
2686 Err(anyhow!("org does not exist"))?;
2687 }
2688
2689 let mut channels = self.channels.lock();
2690 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2691 channels.insert(
2692 channel_id,
2693 Channel {
2694 id: channel_id,
2695 name: name.to_string(),
2696 owner_id: org_id.0,
2697 owner_is_user: false,
2698 },
2699 );
2700 Ok(channel_id)
2701 }
2702
2703 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2704 self.background.simulate_random_delay().await;
2705 Ok(self
2706 .channels
2707 .lock()
2708 .values()
2709 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2710 .cloned()
2711 .collect())
2712 }
2713
2714 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2715 self.background.simulate_random_delay().await;
2716 let channels = self.channels.lock();
2717 let memberships = self.channel_memberships.lock();
2718 Ok(channels
2719 .values()
2720 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2721 .cloned()
2722 .collect())
2723 }
2724
2725 async fn can_user_access_channel(
2726 &self,
2727 user_id: UserId,
2728 channel_id: ChannelId,
2729 ) -> Result<bool> {
2730 self.background.simulate_random_delay().await;
2731 Ok(self
2732 .channel_memberships
2733 .lock()
2734 .contains_key(&(channel_id, user_id)))
2735 }
2736
2737 async fn add_channel_member(
2738 &self,
2739 channel_id: ChannelId,
2740 user_id: UserId,
2741 is_admin: bool,
2742 ) -> Result<()> {
2743 self.background.simulate_random_delay().await;
2744 if !self.channels.lock().contains_key(&channel_id) {
2745 Err(anyhow!("channel does not exist"))?;
2746 }
2747 if !self.users.lock().contains_key(&user_id) {
2748 Err(anyhow!("user does not exist"))?;
2749 }
2750
2751 self.channel_memberships
2752 .lock()
2753 .entry((channel_id, user_id))
2754 .or_insert(is_admin);
2755 Ok(())
2756 }
2757
2758 async fn create_channel_message(
2759 &self,
2760 channel_id: ChannelId,
2761 sender_id: UserId,
2762 body: &str,
2763 timestamp: OffsetDateTime,
2764 nonce: u128,
2765 ) -> Result<MessageId> {
2766 self.background.simulate_random_delay().await;
2767 if !self.channels.lock().contains_key(&channel_id) {
2768 Err(anyhow!("channel does not exist"))?;
2769 }
2770 if !self.users.lock().contains_key(&sender_id) {
2771 Err(anyhow!("user does not exist"))?;
2772 }
2773
2774 let mut messages = self.channel_messages.lock();
2775 if let Some(message) = messages
2776 .values()
2777 .find(|message| message.nonce.as_u128() == nonce)
2778 {
2779 Ok(message.id)
2780 } else {
2781 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2782 messages.insert(
2783 message_id,
2784 ChannelMessage {
2785 id: message_id,
2786 channel_id,
2787 sender_id,
2788 body: body.to_string(),
2789 sent_at: timestamp,
2790 nonce: Uuid::from_u128(nonce),
2791 },
2792 );
2793 Ok(message_id)
2794 }
2795 }
2796
2797 async fn get_channel_messages(
2798 &self,
2799 channel_id: ChannelId,
2800 count: usize,
2801 before_id: Option<MessageId>,
2802 ) -> Result<Vec<ChannelMessage>> {
2803 let mut messages = self
2804 .channel_messages
2805 .lock()
2806 .values()
2807 .rev()
2808 .filter(|message| {
2809 message.channel_id == channel_id
2810 && message.id < before_id.unwrap_or(MessageId::MAX)
2811 })
2812 .take(count)
2813 .cloned()
2814 .collect::<Vec<_>>();
2815 messages.sort_unstable_by_key(|message| message.id);
2816 Ok(messages)
2817 }
2818
2819 async fn teardown(&self, _: &str) {}
2820
2821 #[cfg(test)]
2822 fn as_fake(&self) -> Option<&FakeDb> {
2823 Some(self)
2824 }
2825 }
2826}