1use std::{cmp, ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use nanoid::nanoid;
10use serde::{Deserialize, Serialize};
11pub use sqlx::postgres::PgPoolOptions as DbOptions;
12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
13use time::{OffsetDateTime, PrimitiveDateTime};
14
/// Persistence operations for users, invite codes, projects, activity
/// tracking, contacts, access tokens, orgs, channels and channel messages.
///
/// `PostgresDb` implements this trait on top of a Postgres connection pool.
/// Remarks below that mention specific storage behavior describe that
/// implementation.
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user and returns its id.
    ///
    /// In `PostgresDb`, when a user with the same `github_login` already
    /// exists, that user's id is returned and its email/admin flag are left
    /// unchanged.
    async fn create_user(
        &self,
        github_login: &str,
        email_address: Option<&str>,
        admin: bool,
    ) -> Result<UserId>;
    /// Returns one page of users ordered by GitHub login. `page` is
    /// zero-based; `limit` is the page size.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Bulk-creates (or updates) users from
    /// `(github_login, email_address, invite_count)` tuples and returns their ids.
    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
    /// Returns at most `limit` users whose GitHub login loosely matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Looks up a single user by id; `None` when no such user exists.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the users whose id is contained in `ids`.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns users with an invite count of zero. When
    /// `invited_by_another_user` is true only users that have an inviter are
    /// returned; otherwise only users without one.
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    /// Looks up a single user by exact GitHub login.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets the admin flag of the given user.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets the `connected_once` flag of the given user.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes the given user together with their access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Sets the number of invites the user may still hand out. `PostgresDb`
    /// also assigns a fresh invite code when `count > 0` and the user has none.
    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, or `None`
    /// when the user has no invite code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user owning the given invite code. `PostgresDb` fails with
    /// an HTTP 404 error when the code is unknown.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Consumes one invite of the code's owner, creates a user for `login`
    /// and returns the new user's id. `PostgresDb` additionally records an
    /// accepted contact between inviter and invitee.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, u32>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_user_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn get_top_users_activity_summary(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Get the project activity for the given user and time period.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>>;

    /// Returns the contacts of the given user (accepted, incoming and
    /// outgoing). `PostgresDb` always includes the user themself as an
    /// accepted contact and sorts the result by user id.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users share an accepted contact.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Records a contact request from `requester_id` to `responder_id`.
    /// In `PostgresDb`, a pending request in the opposite direction turns the
    /// contact into an accepted one; any other existing contact is an error.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Removes the contact (or pending request) between the two users; fails
    /// when there is none.
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Clears the pending notification about the contact between the two
    /// users; fails when no matching contact entry exists.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    /// Accepts (`accept == true`) or declines the contact request that
    /// `requester_id` sent to `responder_id`; fails when there is no such request.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    /// Stores a new access token hash for the user and prunes older hashes so
    /// that at most `max_access_token_count` remain.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns the user's access token hashes, most recently created first.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    /// Looks up an org by its slug (tests and seed support only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an org and returns its id (tests and seed support only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds a user to an org (tests and seed support only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Creates a channel owned by the given org (tests and seed support only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Returns the channels owned by the given org (tests and seed support only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Returns the channels the given user is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Returns whether the given user is a member of the given channel.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds a user to a channel (tests and seed support only).
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a channel message and returns its id. In `PostgresDb`, a message
    /// whose `nonce` was already stored is not duplicated; the existing
    /// message's id is returned instead.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` of the most recent messages of the channel, in
    /// ascending id order. When `before_id` is given, only messages with a
    /// smaller id are considered.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only cleanup hook. `PostgresDb` terminates other connections,
    /// closes its pool and drops the database at `url`.
    #[cfg(test)]
    async fn teardown(&self, url: &str);
    /// Test-only downcast to the fake implementation; `PostgresDb` returns `None`.
    #[cfg(test)]
    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
}
151
/// `Db` implementation that runs every operation against a Postgres
/// connection pool.
pub struct PostgresDb {
    // Shared sqlx connection pool; created in `PostgresDb::new`.
    pool: sqlx::PgPool,
}
155
156impl PostgresDb {
157 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
158 let pool = DbOptions::new()
159 .max_connections(max_connections)
160 .connect(&url)
161 .await
162 .context("failed to connect to postgres database")?;
163 Ok(Self { pool })
164 }
165}
166
167#[async_trait]
168impl Db for PostgresDb {
169 // users
170
171 async fn create_user(
172 &self,
173 github_login: &str,
174 email_address: Option<&str>,
175 admin: bool,
176 ) -> Result<UserId> {
177 let query = "
178 INSERT INTO users (github_login, email_address, admin)
179 VALUES ($1, $2, $3)
180 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
181 RETURNING id
182 ";
183 Ok(sqlx::query_scalar(query)
184 .bind(github_login)
185 .bind(email_address)
186 .bind(admin)
187 .fetch_one(&self.pool)
188 .await
189 .map(UserId)?)
190 }
191
192 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
193 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
194 Ok(sqlx::query_as(query)
195 .bind(limit as i32)
196 .bind((page * limit) as i32)
197 .fetch_all(&self.pool)
198 .await?)
199 }
200
201 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
202 let mut query = QueryBuilder::new(
203 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
204 );
205 query.push_values(
206 users,
207 |mut query, (github_login, email_address, invite_count)| {
208 query
209 .push_bind(github_login)
210 .push_bind(email_address)
211 .push_bind(false)
212 .push_bind(nanoid!(16))
213 .push_bind(invite_count as i32);
214 },
215 );
216 query.push(
217 "
218 ON CONFLICT (github_login) DO UPDATE SET
219 github_login = excluded.github_login,
220 invite_count = excluded.invite_count,
221 invite_code = CASE WHEN users.invite_code IS NULL
222 THEN excluded.invite_code
223 ELSE users.invite_code
224 END
225 RETURNING id
226 ",
227 );
228
229 let rows = query.build().fetch_all(&self.pool).await?;
230 Ok(rows
231 .into_iter()
232 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
233 .collect())
234 }
235
236 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
237 let like_string = fuzzy_like_string(name_query);
238 let query = "
239 SELECT users.*
240 FROM users
241 WHERE github_login ILIKE $1
242 ORDER BY github_login <-> $2
243 LIMIT $3
244 ";
245 Ok(sqlx::query_as(query)
246 .bind(like_string)
247 .bind(name_query)
248 .bind(limit as i32)
249 .fetch_all(&self.pool)
250 .await?)
251 }
252
253 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
254 let users = self.get_users_by_ids(vec![id]).await?;
255 Ok(users.into_iter().next())
256 }
257
258 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
259 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
260 let query = "
261 SELECT users.*
262 FROM users
263 WHERE users.id = ANY ($1)
264 ";
265 Ok(sqlx::query_as(query)
266 .bind(&ids)
267 .fetch_all(&self.pool)
268 .await?)
269 }
270
271 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
272 let query = format!(
273 "
274 SELECT users.*
275 FROM users
276 WHERE invite_count = 0
277 AND inviter_id IS{} NULL
278 ",
279 if invited_by_another_user { " NOT" } else { "" }
280 );
281
282 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
283 }
284
285 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
286 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
287 Ok(sqlx::query_as(query)
288 .bind(github_login)
289 .fetch_optional(&self.pool)
290 .await?)
291 }
292
293 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
294 let query = "UPDATE users SET admin = $1 WHERE id = $2";
295 Ok(sqlx::query(query)
296 .bind(is_admin)
297 .bind(id.0)
298 .execute(&self.pool)
299 .await
300 .map(drop)?)
301 }
302
303 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
304 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
305 Ok(sqlx::query(query)
306 .bind(connected_once)
307 .bind(id.0)
308 .execute(&self.pool)
309 .await
310 .map(drop)?)
311 }
312
313 async fn destroy_user(&self, id: UserId) -> Result<()> {
314 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
315 sqlx::query(query)
316 .bind(id.0)
317 .execute(&self.pool)
318 .await
319 .map(drop)?;
320 let query = "DELETE FROM users WHERE id = $1;";
321 Ok(sqlx::query(query)
322 .bind(id.0)
323 .execute(&self.pool)
324 .await
325 .map(drop)?)
326 }
327
328 // invite codes
329
330 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
331 let mut tx = self.pool.begin().await?;
332 if count > 0 {
333 sqlx::query(
334 "
335 UPDATE users
336 SET invite_code = $1
337 WHERE id = $2 AND invite_code IS NULL
338 ",
339 )
340 .bind(nanoid!(16))
341 .bind(id)
342 .execute(&mut tx)
343 .await?;
344 }
345
346 sqlx::query(
347 "
348 UPDATE users
349 SET invite_count = $1
350 WHERE id = $2
351 ",
352 )
353 .bind(count as i32)
354 .bind(id)
355 .execute(&mut tx)
356 .await?;
357 tx.commit().await?;
358 Ok(())
359 }
360
361 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
362 let result: Option<(String, i32)> = sqlx::query_as(
363 "
364 SELECT invite_code, invite_count
365 FROM users
366 WHERE id = $1 AND invite_code IS NOT NULL
367 ",
368 )
369 .bind(id)
370 .fetch_optional(&self.pool)
371 .await?;
372 if let Some((code, count)) = result {
373 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
374 } else {
375 Ok(None)
376 }
377 }
378
379 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
380 sqlx::query_as(
381 "
382 SELECT *
383 FROM users
384 WHERE invite_code = $1
385 ",
386 )
387 .bind(code)
388 .fetch_optional(&self.pool)
389 .await?
390 .ok_or_else(|| {
391 Error::Http(
392 StatusCode::NOT_FOUND,
393 "that invite code does not exist".to_string(),
394 )
395 })
396 }
397
398 async fn redeem_invite_code(
399 &self,
400 code: &str,
401 login: &str,
402 email_address: Option<&str>,
403 ) -> Result<UserId> {
404 let mut tx = self.pool.begin().await?;
405
406 let inviter_id: Option<UserId> = sqlx::query_scalar(
407 "
408 UPDATE users
409 SET invite_count = invite_count - 1
410 WHERE
411 invite_code = $1 AND
412 invite_count > 0
413 RETURNING id
414 ",
415 )
416 .bind(code)
417 .fetch_optional(&mut tx)
418 .await?;
419
420 let inviter_id = match inviter_id {
421 Some(inviter_id) => inviter_id,
422 None => {
423 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
424 .bind(code)
425 .fetch_optional(&mut tx)
426 .await?
427 .is_some()
428 {
429 Err(Error::Http(
430 StatusCode::UNAUTHORIZED,
431 "no invites remaining".to_string(),
432 ))?
433 } else {
434 Err(Error::Http(
435 StatusCode::NOT_FOUND,
436 "invite code not found".to_string(),
437 ))?
438 }
439 }
440 };
441
442 let invitee_id = sqlx::query_scalar(
443 "
444 INSERT INTO users
445 (github_login, email_address, admin, inviter_id)
446 VALUES
447 ($1, $2, 'f', $3)
448 RETURNING id
449 ",
450 )
451 .bind(login)
452 .bind(email_address)
453 .bind(inviter_id)
454 .fetch_one(&mut tx)
455 .await
456 .map(UserId)?;
457
458 sqlx::query(
459 "
460 INSERT INTO contacts
461 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
462 VALUES
463 ($1, $2, 't', 't', 't')
464 ",
465 )
466 .bind(inviter_id)
467 .bind(invitee_id)
468 .execute(&mut tx)
469 .await?;
470
471 tx.commit().await?;
472 Ok(invitee_id)
473 }
474
475 // projects
476
477 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
478 Ok(sqlx::query_scalar(
479 "
480 INSERT INTO projects(host_user_id)
481 VALUES ($1)
482 RETURNING id
483 ",
484 )
485 .bind(host_user_id)
486 .fetch_one(&self.pool)
487 .await
488 .map(ProjectId)?)
489 }
490
491 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
492 sqlx::query(
493 "
494 UPDATE projects
495 SET unregistered = 't'
496 WHERE id = $1
497 ",
498 )
499 .bind(project_id)
500 .execute(&self.pool)
501 .await?;
502 Ok(())
503 }
504
505 async fn update_worktree_extensions(
506 &self,
507 project_id: ProjectId,
508 worktree_id: u64,
509 extensions: HashMap<String, u32>,
510 ) -> Result<()> {
511 if extensions.is_empty() {
512 return Ok(());
513 }
514
515 let mut query = QueryBuilder::new(
516 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
517 );
518 query.push_values(extensions, |mut query, (extension, count)| {
519 query
520 .push_bind(project_id)
521 .push_bind(worktree_id as i32)
522 .push_bind(extension)
523 .push_bind(count as i32);
524 });
525 query.push(
526 "
527 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
528 count = excluded.count
529 ",
530 );
531 query.build().execute(&self.pool).await?;
532
533 Ok(())
534 }
535
536 async fn get_project_extensions(
537 &self,
538 project_id: ProjectId,
539 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
540 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
541 struct WorktreeExtension {
542 worktree_id: i32,
543 extension: String,
544 count: i32,
545 }
546
547 let query = "
548 SELECT worktree_id, extension, count
549 FROM worktree_extensions
550 WHERE project_id = $1
551 ";
552 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
553 .bind(&project_id)
554 .fetch_all(&self.pool)
555 .await?;
556
557 let mut extension_counts = HashMap::default();
558 for count in counts {
559 extension_counts
560 .entry(count.worktree_id as u64)
561 .or_insert(HashMap::default())
562 .insert(count.extension, count.count as usize);
563 }
564 Ok(extension_counts)
565 }
566
567 async fn record_user_activity(
568 &self,
569 time_period: Range<OffsetDateTime>,
570 projects: &[(UserId, ProjectId)],
571 ) -> Result<()> {
572 let query = "
573 INSERT INTO project_activity_periods
574 (ended_at, duration_millis, user_id, project_id)
575 VALUES
576 ($1, $2, $3, $4);
577 ";
578
579 let mut tx = self.pool.begin().await?;
580 let duration_millis =
581 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
582 for (user_id, project_id) in projects {
583 sqlx::query(query)
584 .bind(time_period.end)
585 .bind(duration_millis)
586 .bind(user_id)
587 .bind(project_id)
588 .execute(&mut tx)
589 .await?;
590 }
591 tx.commit().await?;
592
593 Ok(())
594 }
595
596 async fn get_top_users_activity_summary(
597 &self,
598 time_period: Range<OffsetDateTime>,
599 max_user_count: usize,
600 ) -> Result<Vec<UserActivitySummary>> {
601 let query = "
602 WITH
603 project_durations AS (
604 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
605 FROM project_activity_periods
606 WHERE $1 < ended_at AND ended_at <= $2
607 GROUP BY user_id, project_id
608 ),
609 user_durations AS (
610 SELECT user_id, SUM(project_duration) as total_duration
611 FROM project_durations
612 GROUP BY user_id
613 ORDER BY total_duration DESC
614 LIMIT $3
615 )
616 SELECT user_durations.user_id, users.github_login, project_id, project_duration
617 FROM user_durations, project_durations, users
618 WHERE
619 user_durations.user_id = project_durations.user_id AND
620 user_durations.user_id = users.id
621 ORDER BY total_duration DESC, user_id ASC
622 ";
623
624 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64)>(query)
625 .bind(time_period.start)
626 .bind(time_period.end)
627 .bind(max_user_count as i32)
628 .fetch(&self.pool);
629
630 let mut result = Vec::<UserActivitySummary>::new();
631 while let Some(row) = rows.next().await {
632 let (user_id, github_login, project_id, duration_millis) = row?;
633 let project_id = project_id;
634 let duration = Duration::from_millis(duration_millis as u64);
635 if let Some(last_summary) = result.last_mut() {
636 if last_summary.id == user_id {
637 last_summary.project_activity.push((project_id, duration));
638 continue;
639 }
640 }
641 result.push(UserActivitySummary {
642 id: user_id,
643 project_activity: vec![(project_id, duration)],
644 github_login,
645 });
646 }
647
648 Ok(result)
649 }
650
651 async fn get_user_activity_timeline(
652 &self,
653 time_period: Range<OffsetDateTime>,
654 user_id: UserId,
655 ) -> Result<Vec<UserActivityPeriod>> {
656 const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
657
658 let query = "
659 SELECT
660 project_activity_periods.ended_at,
661 project_activity_periods.duration_millis,
662 project_activity_periods.project_id,
663 worktree_extensions.extension,
664 worktree_extensions.count
665 FROM project_activity_periods
666 LEFT OUTER JOIN
667 worktree_extensions
668 ON
669 project_activity_periods.project_id = worktree_extensions.project_id
670 WHERE
671 project_activity_periods.user_id = $1 AND
672 $2 < project_activity_periods.ended_at AND
673 project_activity_periods.ended_at <= $3
674 ORDER BY project_activity_periods.id ASC
675 ";
676
677 let mut rows = sqlx::query_as::<
678 _,
679 (
680 PrimitiveDateTime,
681 i32,
682 ProjectId,
683 Option<String>,
684 Option<i32>,
685 ),
686 >(query)
687 .bind(user_id)
688 .bind(time_period.start)
689 .bind(time_period.end)
690 .fetch(&self.pool);
691
692 let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
693 while let Some(row) = rows.next().await {
694 let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
695 let ended_at = ended_at.assume_utc();
696 let duration = Duration::from_millis(duration_millis as u64);
697 let started_at = ended_at - duration;
698 let project_time_periods = time_periods.entry(project_id).or_default();
699
700 if let Some(prev_duration) = project_time_periods.last_mut() {
701 if started_at <= prev_duration.end + COALESCE_THRESHOLD
702 && ended_at >= prev_duration.start
703 {
704 prev_duration.end = cmp::max(prev_duration.end, ended_at);
705 } else {
706 project_time_periods.push(UserActivityPeriod {
707 project_id,
708 start: started_at,
709 end: ended_at,
710 extensions: Default::default(),
711 });
712 }
713 } else {
714 project_time_periods.push(UserActivityPeriod {
715 project_id,
716 start: started_at,
717 end: ended_at,
718 extensions: Default::default(),
719 });
720 }
721
722 if let Some((extension, extension_count)) = extension.zip(extension_count) {
723 project_time_periods
724 .last_mut()
725 .unwrap()
726 .extensions
727 .insert(extension, extension_count as usize);
728 }
729 }
730
731 let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
732 durations.sort_unstable_by_key(|duration| duration.start);
733 Ok(durations)
734 }
735
736 // contacts
737
738 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
739 let query = "
740 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
741 FROM contacts
742 WHERE user_id_a = $1 OR user_id_b = $1;
743 ";
744
745 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
746 .bind(user_id)
747 .fetch(&self.pool);
748
749 let mut contacts = vec![Contact::Accepted {
750 user_id,
751 should_notify: false,
752 }];
753 while let Some(row) = rows.next().await {
754 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
755
756 if user_id_a == user_id {
757 if accepted {
758 contacts.push(Contact::Accepted {
759 user_id: user_id_b,
760 should_notify: should_notify && a_to_b,
761 });
762 } else if a_to_b {
763 contacts.push(Contact::Outgoing { user_id: user_id_b })
764 } else {
765 contacts.push(Contact::Incoming {
766 user_id: user_id_b,
767 should_notify,
768 });
769 }
770 } else {
771 if accepted {
772 contacts.push(Contact::Accepted {
773 user_id: user_id_a,
774 should_notify: should_notify && !a_to_b,
775 });
776 } else if a_to_b {
777 contacts.push(Contact::Incoming {
778 user_id: user_id_a,
779 should_notify,
780 });
781 } else {
782 contacts.push(Contact::Outgoing { user_id: user_id_a });
783 }
784 }
785 }
786
787 contacts.sort_unstable_by_key(|contact| contact.user_id());
788
789 Ok(contacts)
790 }
791
792 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
793 let (id_a, id_b) = if user_id_1 < user_id_2 {
794 (user_id_1, user_id_2)
795 } else {
796 (user_id_2, user_id_1)
797 };
798
799 let query = "
800 SELECT 1 FROM contacts
801 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
802 LIMIT 1
803 ";
804 Ok(sqlx::query_scalar::<_, i32>(query)
805 .bind(id_a.0)
806 .bind(id_b.0)
807 .fetch_optional(&self.pool)
808 .await?
809 .is_some())
810 }
811
812 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
813 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
814 (sender_id, receiver_id, true)
815 } else {
816 (receiver_id, sender_id, false)
817 };
818 let query = "
819 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
820 VALUES ($1, $2, $3, 'f', 't')
821 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
822 SET
823 accepted = 't',
824 should_notify = 'f'
825 WHERE
826 NOT contacts.accepted AND
827 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
828 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
829 ";
830 let result = sqlx::query(query)
831 .bind(id_a.0)
832 .bind(id_b.0)
833 .bind(a_to_b)
834 .execute(&self.pool)
835 .await?;
836
837 if result.rows_affected() == 1 {
838 Ok(())
839 } else {
840 Err(anyhow!("contact already requested"))?
841 }
842 }
843
844 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
845 let (id_a, id_b) = if responder_id < requester_id {
846 (responder_id, requester_id)
847 } else {
848 (requester_id, responder_id)
849 };
850 let query = "
851 DELETE FROM contacts
852 WHERE user_id_a = $1 AND user_id_b = $2;
853 ";
854 let result = sqlx::query(query)
855 .bind(id_a.0)
856 .bind(id_b.0)
857 .execute(&self.pool)
858 .await?;
859
860 if result.rows_affected() == 1 {
861 Ok(())
862 } else {
863 Err(anyhow!("no such contact"))?
864 }
865 }
866
867 async fn dismiss_contact_notification(
868 &self,
869 user_id: UserId,
870 contact_user_id: UserId,
871 ) -> Result<()> {
872 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
873 (user_id, contact_user_id, true)
874 } else {
875 (contact_user_id, user_id, false)
876 };
877
878 let query = "
879 UPDATE contacts
880 SET should_notify = 'f'
881 WHERE
882 user_id_a = $1 AND user_id_b = $2 AND
883 (
884 (a_to_b = $3 AND accepted) OR
885 (a_to_b != $3 AND NOT accepted)
886 );
887 ";
888
889 let result = sqlx::query(query)
890 .bind(id_a.0)
891 .bind(id_b.0)
892 .bind(a_to_b)
893 .execute(&self.pool)
894 .await?;
895
896 if result.rows_affected() == 0 {
897 Err(anyhow!("no such contact request"))?;
898 }
899
900 Ok(())
901 }
902
903 async fn respond_to_contact_request(
904 &self,
905 responder_id: UserId,
906 requester_id: UserId,
907 accept: bool,
908 ) -> Result<()> {
909 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
910 (responder_id, requester_id, false)
911 } else {
912 (requester_id, responder_id, true)
913 };
914 let result = if accept {
915 let query = "
916 UPDATE contacts
917 SET accepted = 't', should_notify = 't'
918 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
919 ";
920 sqlx::query(query)
921 .bind(id_a.0)
922 .bind(id_b.0)
923 .bind(a_to_b)
924 .execute(&self.pool)
925 .await?
926 } else {
927 let query = "
928 DELETE FROM contacts
929 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
930 ";
931 sqlx::query(query)
932 .bind(id_a.0)
933 .bind(id_b.0)
934 .bind(a_to_b)
935 .execute(&self.pool)
936 .await?
937 };
938 if result.rows_affected() == 1 {
939 Ok(())
940 } else {
941 Err(anyhow!("no such contact request"))?
942 }
943 }
944
945 // access tokens
946
947 async fn create_access_token_hash(
948 &self,
949 user_id: UserId,
950 access_token_hash: &str,
951 max_access_token_count: usize,
952 ) -> Result<()> {
953 let insert_query = "
954 INSERT INTO access_tokens (user_id, hash)
955 VALUES ($1, $2);
956 ";
957 let cleanup_query = "
958 DELETE FROM access_tokens
959 WHERE id IN (
960 SELECT id from access_tokens
961 WHERE user_id = $1
962 ORDER BY id DESC
963 OFFSET $3
964 )
965 ";
966
967 let mut tx = self.pool.begin().await?;
968 sqlx::query(insert_query)
969 .bind(user_id.0)
970 .bind(access_token_hash)
971 .execute(&mut tx)
972 .await?;
973 sqlx::query(cleanup_query)
974 .bind(user_id.0)
975 .bind(access_token_hash)
976 .bind(max_access_token_count as i32)
977 .execute(&mut tx)
978 .await?;
979 Ok(tx.commit().await?)
980 }
981
982 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
983 let query = "
984 SELECT hash
985 FROM access_tokens
986 WHERE user_id = $1
987 ORDER BY id DESC
988 ";
989 Ok(sqlx::query_scalar(query)
990 .bind(user_id.0)
991 .fetch_all(&self.pool)
992 .await?)
993 }
994
995 // orgs
996
/// Looks up the org with the given slug, if any.
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
    let org = sqlx::query_as(
        "
        SELECT *
        FROM orgs
        WHERE slug = $1
        ",
    )
    .bind(slug)
    .fetch_optional(&self.pool)
    .await?;
    Ok(org)
}
1010
/// Inserts an org with the given name and slug and returns its id.
#[cfg(any(test, feature = "seed-support"))]
async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
    let org_id: i32 = sqlx::query_scalar(
        "
        INSERT INTO orgs (name, slug)
        VALUES ($1, $2)
        RETURNING id
        ",
    )
    .bind(name)
    .bind(slug)
    .fetch_one(&self.pool)
    .await?;
    Ok(OrgId(org_id))
}
1025
/// Adds the user to the org with the given admin flag. A conflicting
/// existing membership row is left untouched (`ON CONFLICT DO NOTHING`).
#[cfg(any(test, feature = "seed-support"))]
async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
    let insert_membership = "
        INSERT INTO org_memberships (org_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
        ";
    sqlx::query(insert_membership)
        .bind(org_id.0)
        .bind(user_id.0)
        .bind(is_admin)
        .execute(&self.pool)
        .await?;
    Ok(())
}
1041
1042 // channels
1043
/// Inserts a channel owned by the given org (`owner_is_user = false`) and
/// returns the new channel's id.
#[cfg(any(test, feature = "seed-support"))]
async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
    let channel_id: i32 = sqlx::query_scalar(
        "
        INSERT INTO channels (owner_id, owner_is_user, name)
        VALUES ($1, false, $2)
        RETURNING id
        ",
    )
    .bind(org_id.0)
    .bind(name)
    .fetch_one(&self.pool)
    .await?;
    Ok(ChannelId(channel_id))
}
1058
/// Returns the channels that are owned by the given org rather than by a user.
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
    let channels = sqlx::query_as(
        "
        SELECT *
        FROM channels
        WHERE
            channels.owner_is_user = false AND
            channels.owner_id = $1
        ",
    )
    .bind(org_id.0)
    .fetch_all(&self.pool)
    .await?;
    Ok(channels)
}
1074
1075 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1076 let query = "
1077 SELECT
1078 channels.*
1079 FROM
1080 channel_memberships, channels
1081 WHERE
1082 channel_memberships.user_id = $1 AND
1083 channel_memberships.channel_id = channels.id
1084 ";
1085 Ok(sqlx::query_as(query)
1086 .bind(user_id.0)
1087 .fetch_all(&self.pool)
1088 .await?)
1089 }
1090
1091 async fn can_user_access_channel(
1092 &self,
1093 user_id: UserId,
1094 channel_id: ChannelId,
1095 ) -> Result<bool> {
1096 let query = "
1097 SELECT id
1098 FROM channel_memberships
1099 WHERE user_id = $1 AND channel_id = $2
1100 LIMIT 1
1101 ";
1102 Ok(sqlx::query_scalar::<_, i32>(query)
1103 .bind(user_id.0)
1104 .bind(channel_id.0)
1105 .fetch_optional(&self.pool)
1106 .await
1107 .map(|e| e.is_some())?)
1108 }
1109
/// Adds the user to the channel with the given admin flag. A conflicting
/// existing membership row is left untouched (`ON CONFLICT DO NOTHING`).
#[cfg(any(test, feature = "seed-support"))]
async fn add_channel_member(
    &self,
    channel_id: ChannelId,
    user_id: UserId,
    is_admin: bool,
) -> Result<()> {
    let insert_membership = "
        INSERT INTO channel_memberships (channel_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
        ";
    sqlx::query(insert_membership)
        .bind(channel_id.0)
        .bind(user_id.0)
        .bind(is_admin)
        .execute(&self.pool)
        .await?;
    Ok(())
}
1130
1131 // messages
1132
1133 async fn create_channel_message(
1134 &self,
1135 channel_id: ChannelId,
1136 sender_id: UserId,
1137 body: &str,
1138 timestamp: OffsetDateTime,
1139 nonce: u128,
1140 ) -> Result<MessageId> {
1141 let query = "
1142 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1143 VALUES ($1, $2, $3, $4, $5)
1144 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1145 RETURNING id
1146 ";
1147 Ok(sqlx::query_scalar(query)
1148 .bind(channel_id.0)
1149 .bind(sender_id.0)
1150 .bind(body)
1151 .bind(timestamp)
1152 .bind(Uuid::from_u128(nonce))
1153 .fetch_one(&self.pool)
1154 .await
1155 .map(MessageId)?)
1156 }
1157
1158 async fn get_channel_messages(
1159 &self,
1160 channel_id: ChannelId,
1161 count: usize,
1162 before_id: Option<MessageId>,
1163 ) -> Result<Vec<ChannelMessage>> {
1164 let query = r#"
1165 SELECT * FROM (
1166 SELECT
1167 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1168 FROM
1169 channel_messages
1170 WHERE
1171 channel_id = $1 AND
1172 id < $2
1173 ORDER BY id DESC
1174 LIMIT $3
1175 ) as recent_messages
1176 ORDER BY id ASC
1177 "#;
1178 Ok(sqlx::query_as(query)
1179 .bind(channel_id.0)
1180 .bind(before_id.unwrap_or(MessageId::MAX))
1181 .bind(count as i64)
1182 .fetch_all(&self.pool)
1183 .await?)
1184 }
1185
/// Test-only cleanup: terminates every other backend connected to the
/// current database, closes this pool and drops the database at `url`.
///
/// Failures of the terminate query and of the drop are passed to
/// `log_err` rather than returned.
#[cfg(test)]
async fn teardown(&self, url: &str) {
    use util::ResultExt;

    let terminate_other_backends = "
        SELECT pg_terminate_backend(pg_stat_activity.pid)
        FROM pg_stat_activity
        WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
        ";
    sqlx::query(terminate_other_backends)
        .execute(&self.pool)
        .await
        .log_err();

    self.pool.close().await;

    let drop_result =
        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url).await;
    drop_result.log_err();
}
1201
/// Test-only downcast hook: `PostgresDb` is not the fake implementation, so
/// this always returns `None`.
#[cfg(test)]
fn as_fake(&self) -> Option<&tests::FakeDb> {
    None
}
1206}
1207
// Declares a newtype id `$name` wrapping an `i32`.
//
// The generated type is `Copy`, ordered and hashable, is transparent for
// both sqlx and serde, and comes with a `MAX` constant, conversions to and
// from the `u64` used for protocol ids, and a `Display` impl that delegates
// to the wrapped integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Converts from the `u64` protocol representation. The `as` cast
            // keeps only the low 32 bits, so values above `i32::MAX` do not
            // round-trip.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Converts to the `u64` protocol representation. The `as` cast
            // sign-extends, so a negative id becomes a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Delegate to the wrapped integer's formatting.
                self.0.fmt(f)
            }
        }
    };
}
1250
1251id_type!(UserId);
1252#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1253pub struct User {
1254 pub id: UserId,
1255 pub github_login: String,
1256 pub email_address: Option<String>,
1257 pub admin: bool,
1258 pub invite_code: Option<String>,
1259 pub invite_count: i32,
1260 pub connected_once: bool,
1261}
1262
1263id_type!(ProjectId);
1264#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1265pub struct Project {
1266 pub id: ProjectId,
1267 pub host_user_id: UserId,
1268 pub unregistered: bool,
1269}
1270
1271#[derive(Clone, Debug, PartialEq, Serialize)]
1272pub struct UserActivitySummary {
1273 pub id: UserId,
1274 pub github_login: String,
1275 pub project_activity: Vec<(ProjectId, Duration)>,
1276}
1277
1278#[derive(Clone, Debug, PartialEq, Serialize)]
1279pub struct UserActivityPeriod {
1280 project_id: ProjectId,
1281 #[serde(with = "time::serde::iso8601")]
1282 start: OffsetDateTime,
1283 #[serde(with = "time::serde::iso8601")]
1284 end: OffsetDateTime,
1285 extensions: HashMap<String, usize>,
1286}
1287
1288id_type!(OrgId);
1289#[derive(FromRow)]
1290pub struct Org {
1291 pub id: OrgId,
1292 pub name: String,
1293 pub slug: String,
1294}
1295
1296id_type!(ChannelId);
1297#[derive(Clone, Debug, FromRow, Serialize)]
1298pub struct Channel {
1299 pub id: ChannelId,
1300 pub name: String,
1301 pub owner_id: i32,
1302 pub owner_is_user: bool,
1303}
1304
1305id_type!(MessageId);
1306#[derive(Clone, Debug, FromRow)]
1307pub struct ChannelMessage {
1308 pub id: MessageId,
1309 pub channel_id: ChannelId,
1310 pub sender_id: UserId,
1311 pub body: String,
1312 pub sent_at: OffsetDateTime,
1313 pub nonce: Uuid,
1314}
1315
1316#[derive(Clone, Debug, PartialEq, Eq)]
1317pub enum Contact {
1318 Accepted {
1319 user_id: UserId,
1320 should_notify: bool,
1321 },
1322 Outgoing {
1323 user_id: UserId,
1324 },
1325 Incoming {
1326 user_id: UserId,
1327 should_notify: bool,
1328 },
1329}
1330
1331impl Contact {
1332 pub fn user_id(&self) -> UserId {
1333 match self {
1334 Contact::Accepted { user_id, .. } => *user_id,
1335 Contact::Outgoing { user_id } => *user_id,
1336 Contact::Incoming { user_id, .. } => *user_id,
1337 }
1338 }
1339}
1340
1341#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
1342pub struct IncomingContactRequest {
1343 pub requester_id: UserId,
1344 pub should_notify: bool,
1345}
1346
/// Builds a SQL `LIKE` pattern that matches the alphanumeric characters of
/// `string` in order, with anything in between.
///
/// Every alphanumeric character is prefixed with `%`, all other characters
/// are dropped, and a final `%` is appended (so an empty input yields `"%"`).
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern: String = string
        .chars()
        .filter(|c| c.is_alphanumeric())
        .flat_map(|c| ['%', c])
        .collect();
    pattern.push('%');
    pattern
}
1358
1359#[cfg(test)]
1360pub mod tests {
1361 use super::*;
1362 use anyhow::anyhow;
1363 use collections::BTreeMap;
1364 use gpui::executor::Background;
1365 use lazy_static::lazy_static;
1366 use parking_lot::Mutex;
1367 use rand::prelude::*;
1368 use sqlx::{
1369 migrate::{MigrateDatabase, Migrator},
1370 Postgres,
1371 };
1372 use std::{path::Path, sync::Arc};
1373 use util::post_inc;
1374
1375 #[tokio::test(flavor = "multi_thread")]
1376 async fn test_get_users_by_ids() {
1377 for test_db in [
1378 TestDb::postgres().await,
1379 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1380 ] {
1381 let db = test_db.db();
1382
1383 let user = db.create_user("user", None, false).await.unwrap();
1384 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1385 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1386 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1387
1388 assert_eq!(
1389 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1390 .await
1391 .unwrap(),
1392 vec![
1393 User {
1394 id: user,
1395 github_login: "user".to_string(),
1396 admin: false,
1397 ..Default::default()
1398 },
1399 User {
1400 id: friend1,
1401 github_login: "friend-1".to_string(),
1402 admin: false,
1403 ..Default::default()
1404 },
1405 User {
1406 id: friend2,
1407 github_login: "friend-2".to_string(),
1408 admin: false,
1409 ..Default::default()
1410 },
1411 User {
1412 id: friend3,
1413 github_login: "friend-3".to_string(),
1414 admin: false,
1415 ..Default::default()
1416 }
1417 ]
1418 );
1419 }
1420 }
1421
1422 #[tokio::test(flavor = "multi_thread")]
1423 async fn test_create_users() {
1424 let db = TestDb::postgres().await;
1425 let db = db.db();
1426
1427 // Create the first batch of users, ensuring invite counts are assigned
1428 // correctly and the respective invite codes are unique.
1429 let user_ids_batch_1 = db
1430 .create_users(vec![
1431 ("user1".to_string(), "hi@user1.com".to_string(), 5),
1432 ("user2".to_string(), "hi@user2.com".to_string(), 4),
1433 ("user3".to_string(), "hi@user3.com".to_string(), 3),
1434 ])
1435 .await
1436 .unwrap();
1437 assert_eq!(user_ids_batch_1.len(), 3);
1438
1439 let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1440 assert_eq!(users.len(), 3);
1441 assert_eq!(users[0].github_login, "user1");
1442 assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1443 assert_eq!(users[0].invite_count, 5);
1444 assert_eq!(users[1].github_login, "user2");
1445 assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1446 assert_eq!(users[1].invite_count, 4);
1447 assert_eq!(users[2].github_login, "user3");
1448 assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1449 assert_eq!(users[2].invite_count, 3);
1450
1451 let invite_code_1 = users[0].invite_code.clone().unwrap();
1452 let invite_code_2 = users[1].invite_code.clone().unwrap();
1453 let invite_code_3 = users[2].invite_code.clone().unwrap();
1454 assert_ne!(invite_code_1, invite_code_2);
1455 assert_ne!(invite_code_1, invite_code_3);
1456 assert_ne!(invite_code_2, invite_code_3);
1457
1458 // Create the second batch of users and include a user that is already in the database, ensuring
1459 // the invite count for the existing user is updated without changing their invite code.
1460 let user_ids_batch_2 = db
1461 .create_users(vec![
1462 ("user2".to_string(), "hi@user2.com".to_string(), 10),
1463 ("user4".to_string(), "hi@user4.com".to_string(), 2),
1464 ])
1465 .await
1466 .unwrap();
1467 assert_eq!(user_ids_batch_2.len(), 2);
1468 assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1469
1470 let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1471 assert_eq!(users.len(), 2);
1472 assert_eq!(users[0].github_login, "user2");
1473 assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1474 assert_eq!(users[0].invite_count, 10);
1475 assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1476 assert_eq!(users[1].github_login, "user4");
1477 assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1478 assert_eq!(users[1].invite_count, 2);
1479
1480 let invite_code_4 = users[1].invite_code.clone().unwrap();
1481 assert_ne!(invite_code_4, invite_code_1);
1482 assert_ne!(invite_code_4, invite_code_2);
1483 assert_ne!(invite_code_4, invite_code_3);
1484 }
1485
1486 #[tokio::test(flavor = "multi_thread")]
1487 async fn test_worktree_extensions() {
1488 let test_db = TestDb::postgres().await;
1489 let db = test_db.db();
1490
1491 let user = db.create_user("user_1", None, false).await.unwrap();
1492 let project = db.register_project(user).await.unwrap();
1493
1494 db.update_worktree_extensions(project, 100, Default::default())
1495 .await
1496 .unwrap();
1497 db.update_worktree_extensions(
1498 project,
1499 100,
1500 [("rs".to_string(), 5), ("md".to_string(), 3)]
1501 .into_iter()
1502 .collect(),
1503 )
1504 .await
1505 .unwrap();
1506 db.update_worktree_extensions(
1507 project,
1508 100,
1509 [("rs".to_string(), 6), ("md".to_string(), 5)]
1510 .into_iter()
1511 .collect(),
1512 )
1513 .await
1514 .unwrap();
1515 db.update_worktree_extensions(
1516 project,
1517 101,
1518 [("ts".to_string(), 2), ("md".to_string(), 1)]
1519 .into_iter()
1520 .collect(),
1521 )
1522 .await
1523 .unwrap();
1524
1525 assert_eq!(
1526 db.get_project_extensions(project).await.unwrap(),
1527 [
1528 (
1529 100,
1530 [("rs".into(), 6), ("md".into(), 5),]
1531 .into_iter()
1532 .collect::<HashMap<_, _>>()
1533 ),
1534 (
1535 101,
1536 [("ts".into(), 2), ("md".into(), 1),]
1537 .into_iter()
1538 .collect::<HashMap<_, _>>()
1539 )
1540 ]
1541 .into_iter()
1542 .collect()
1543 );
1544 }
1545
1546 #[tokio::test(flavor = "multi_thread")]
1547 async fn test_project_activity() {
1548 let test_db = TestDb::postgres().await;
1549 let db = test_db.db();
1550
1551 let user_1 = db.create_user("user_1", None, false).await.unwrap();
1552 let user_2 = db.create_user("user_2", None, false).await.unwrap();
1553 let user_3 = db.create_user("user_3", None, false).await.unwrap();
1554 let project_1 = db.register_project(user_1).await.unwrap();
1555 db.update_worktree_extensions(
1556 project_1,
1557 1,
1558 HashMap::from_iter([("rs".into(), 5), ("md".into(), 7)]),
1559 )
1560 .await
1561 .unwrap();
1562 let project_2 = db.register_project(user_2).await.unwrap();
1563 let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
1564
1565 // User 2 opens a project
1566 let t1 = t0 + Duration::from_secs(10);
1567 db.record_user_activity(t0..t1, &[(user_2, project_2)])
1568 .await
1569 .unwrap();
1570
1571 let t2 = t1 + Duration::from_secs(10);
1572 db.record_user_activity(t1..t2, &[(user_2, project_2)])
1573 .await
1574 .unwrap();
1575
1576 // User 1 joins the project
1577 let t3 = t2 + Duration::from_secs(10);
1578 db.record_user_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
1579 .await
1580 .unwrap();
1581
1582 // User 1 opens another project
1583 let t4 = t3 + Duration::from_secs(10);
1584 db.record_user_activity(
1585 t3..t4,
1586 &[
1587 (user_2, project_2),
1588 (user_1, project_2),
1589 (user_1, project_1),
1590 ],
1591 )
1592 .await
1593 .unwrap();
1594
1595 // User 3 joins that project
1596 let t5 = t4 + Duration::from_secs(10);
1597 db.record_user_activity(
1598 t4..t5,
1599 &[
1600 (user_2, project_2),
1601 (user_1, project_2),
1602 (user_1, project_1),
1603 (user_3, project_1),
1604 ],
1605 )
1606 .await
1607 .unwrap();
1608
1609 // User 2 leaves
1610 let t6 = t5 + Duration::from_secs(5);
1611 db.record_user_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
1612 .await
1613 .unwrap();
1614
1615 let t7 = t6 + Duration::from_secs(60);
1616 let t8 = t7 + Duration::from_secs(10);
1617 db.record_user_activity(t7..t8, &[(user_1, project_1)])
1618 .await
1619 .unwrap();
1620
1621 assert_eq!(
1622 db.get_top_users_activity_summary(t0..t6, 10).await.unwrap(),
1623 &[
1624 UserActivitySummary {
1625 id: user_1,
1626 github_login: "user_1".to_string(),
1627 project_activity: vec![
1628 (project_1, Duration::from_secs(25)),
1629 (project_2, Duration::from_secs(30)),
1630 ]
1631 },
1632 UserActivitySummary {
1633 id: user_2,
1634 github_login: "user_2".to_string(),
1635 project_activity: vec![(project_2, Duration::from_secs(50))]
1636 },
1637 UserActivitySummary {
1638 id: user_3,
1639 github_login: "user_3".to_string(),
1640 project_activity: vec![(project_1, Duration::from_secs(15))]
1641 },
1642 ]
1643 );
1644 assert_eq!(
1645 db.get_user_activity_timeline(t3..t6, user_1).await.unwrap(),
1646 &[
1647 UserActivityPeriod {
1648 project_id: project_1,
1649 start: t3,
1650 end: t6,
1651 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1652 },
1653 UserActivityPeriod {
1654 project_id: project_2,
1655 start: t3,
1656 end: t5,
1657 extensions: Default::default(),
1658 },
1659 ]
1660 );
1661 assert_eq!(
1662 db.get_user_activity_timeline(t0..t8, user_1).await.unwrap(),
1663 &[
1664 UserActivityPeriod {
1665 project_id: project_2,
1666 start: t2,
1667 end: t5,
1668 extensions: Default::default(),
1669 },
1670 UserActivityPeriod {
1671 project_id: project_1,
1672 start: t3,
1673 end: t6,
1674 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1675 },
1676 UserActivityPeriod {
1677 project_id: project_1,
1678 start: t7,
1679 end: t8,
1680 extensions: HashMap::from_iter([("rs".to_string(), 5), ("md".to_string(), 7)]),
1681 },
1682 ]
1683 );
1684 }
1685
1686 #[tokio::test(flavor = "multi_thread")]
1687 async fn test_recent_channel_messages() {
1688 for test_db in [
1689 TestDb::postgres().await,
1690 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1691 ] {
1692 let db = test_db.db();
1693 let user = db.create_user("user", None, false).await.unwrap();
1694 let org = db.create_org("org", "org").await.unwrap();
1695 let channel = db.create_org_channel(org, "channel").await.unwrap();
1696 for i in 0..10 {
1697 db.create_channel_message(
1698 channel,
1699 user,
1700 &i.to_string(),
1701 OffsetDateTime::now_utc(),
1702 i,
1703 )
1704 .await
1705 .unwrap();
1706 }
1707
1708 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1709 assert_eq!(
1710 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1711 ["5", "6", "7", "8", "9"]
1712 );
1713
1714 let prev_messages = db
1715 .get_channel_messages(channel, 4, Some(messages[0].id))
1716 .await
1717 .unwrap();
1718 assert_eq!(
1719 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1720 ["1", "2", "3", "4"]
1721 );
1722 }
1723 }
1724
1725 #[tokio::test(flavor = "multi_thread")]
1726 async fn test_channel_message_nonces() {
1727 for test_db in [
1728 TestDb::postgres().await,
1729 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1730 ] {
1731 let db = test_db.db();
1732 let user = db.create_user("user", None, false).await.unwrap();
1733 let org = db.create_org("org", "org").await.unwrap();
1734 let channel = db.create_org_channel(org, "channel").await.unwrap();
1735
1736 let msg1_id = db
1737 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1738 .await
1739 .unwrap();
1740 let msg2_id = db
1741 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1742 .await
1743 .unwrap();
1744 let msg3_id = db
1745 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1746 .await
1747 .unwrap();
1748 let msg4_id = db
1749 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1750 .await
1751 .unwrap();
1752
1753 assert_ne!(msg1_id, msg2_id);
1754 assert_eq!(msg1_id, msg3_id);
1755 assert_eq!(msg2_id, msg4_id);
1756 }
1757 }
1758
1759 #[tokio::test(flavor = "multi_thread")]
1760 async fn test_create_access_tokens() {
1761 let test_db = TestDb::postgres().await;
1762 let db = test_db.db();
1763 let user = db.create_user("the-user", None, false).await.unwrap();
1764
1765 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1766 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1767 assert_eq!(
1768 db.get_access_token_hashes(user).await.unwrap(),
1769 &["h2".to_string(), "h1".to_string()]
1770 );
1771
1772 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1773 assert_eq!(
1774 db.get_access_token_hashes(user).await.unwrap(),
1775 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1776 );
1777
1778 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1779 assert_eq!(
1780 db.get_access_token_hashes(user).await.unwrap(),
1781 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1782 );
1783
1784 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1785 assert_eq!(
1786 db.get_access_token_hashes(user).await.unwrap(),
1787 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1788 );
1789 }
1790
/// Non-alphanumeric characters are dropped and `%` surrounds every kept one.
#[test]
fn test_fuzzy_like_string() {
    let cases = [
        ("abcd", "%a%b%c%d%"),
        ("x y", "%x%y%"),
        (" z ", "%z%"),
    ];
    for (input, expected) in cases {
        assert_eq!(fuzzy_like_string(input), expected);
    }
}
1797
1798 #[tokio::test(flavor = "multi_thread")]
1799 async fn test_fuzzy_search_users() {
1800 let test_db = TestDb::postgres().await;
1801 let db = test_db.db();
1802 for github_login in [
1803 "California",
1804 "colorado",
1805 "oregon",
1806 "washington",
1807 "florida",
1808 "delaware",
1809 "rhode-island",
1810 ] {
1811 db.create_user(github_login, None, false).await.unwrap();
1812 }
1813
1814 assert_eq!(
1815 fuzzy_search_user_names(db, "clr").await,
1816 &["colorado", "California"]
1817 );
1818 assert_eq!(
1819 fuzzy_search_user_names(db, "ro").await,
1820 &["rhode-island", "colorado", "oregon"],
1821 );
1822
1823 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1824 db.fuzzy_search_users(query, 10)
1825 .await
1826 .unwrap()
1827 .into_iter()
1828 .map(|user| user.github_login)
1829 .collect::<Vec<_>>()
1830 }
1831 }
1832
1833 #[tokio::test(flavor = "multi_thread")]
1834 async fn test_add_contacts() {
1835 for test_db in [
1836 TestDb::postgres().await,
1837 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1838 ] {
1839 let db = test_db.db();
1840
1841 let user_1 = db.create_user("user1", None, false).await.unwrap();
1842 let user_2 = db.create_user("user2", None, false).await.unwrap();
1843 let user_3 = db.create_user("user3", None, false).await.unwrap();
1844
1845 // User starts with no contacts
1846 assert_eq!(
1847 db.get_contacts(user_1).await.unwrap(),
1848 vec![Contact::Accepted {
1849 user_id: user_1,
1850 should_notify: false
1851 }],
1852 );
1853
1854 // User requests a contact. Both users see the pending request.
1855 db.send_contact_request(user_1, user_2).await.unwrap();
1856 assert!(!db.has_contact(user_1, user_2).await.unwrap());
1857 assert!(!db.has_contact(user_2, user_1).await.unwrap());
1858 assert_eq!(
1859 db.get_contacts(user_1).await.unwrap(),
1860 &[
1861 Contact::Accepted {
1862 user_id: user_1,
1863 should_notify: false
1864 },
1865 Contact::Outgoing { user_id: user_2 }
1866 ],
1867 );
1868 assert_eq!(
1869 db.get_contacts(user_2).await.unwrap(),
1870 &[
1871 Contact::Incoming {
1872 user_id: user_1,
1873 should_notify: true
1874 },
1875 Contact::Accepted {
1876 user_id: user_2,
1877 should_notify: false
1878 },
1879 ]
1880 );
1881
1882 // User 2 dismisses the contact request notification without accepting or rejecting.
1883 // We shouldn't notify them again.
1884 db.dismiss_contact_notification(user_1, user_2)
1885 .await
1886 .unwrap_err();
1887 db.dismiss_contact_notification(user_2, user_1)
1888 .await
1889 .unwrap();
1890 assert_eq!(
1891 db.get_contacts(user_2).await.unwrap(),
1892 &[
1893 Contact::Incoming {
1894 user_id: user_1,
1895 should_notify: false
1896 },
1897 Contact::Accepted {
1898 user_id: user_2,
1899 should_notify: false
1900 },
1901 ]
1902 );
1903
1904 // User can't accept their own contact request
1905 db.respond_to_contact_request(user_1, user_2, true)
1906 .await
1907 .unwrap_err();
1908
1909 // User accepts a contact request. Both users see the contact.
1910 db.respond_to_contact_request(user_2, user_1, true)
1911 .await
1912 .unwrap();
1913 assert_eq!(
1914 db.get_contacts(user_1).await.unwrap(),
1915 &[
1916 Contact::Accepted {
1917 user_id: user_1,
1918 should_notify: false
1919 },
1920 Contact::Accepted {
1921 user_id: user_2,
1922 should_notify: true
1923 }
1924 ],
1925 );
1926 assert!(db.has_contact(user_1, user_2).await.unwrap());
1927 assert!(db.has_contact(user_2, user_1).await.unwrap());
1928 assert_eq!(
1929 db.get_contacts(user_2).await.unwrap(),
1930 &[
1931 Contact::Accepted {
1932 user_id: user_1,
1933 should_notify: false,
1934 },
1935 Contact::Accepted {
1936 user_id: user_2,
1937 should_notify: false,
1938 },
1939 ]
1940 );
1941
1942 // Users cannot re-request existing contacts.
1943 db.send_contact_request(user_1, user_2).await.unwrap_err();
1944 db.send_contact_request(user_2, user_1).await.unwrap_err();
1945
1946 // Users can't dismiss notifications of them accepting other users' requests.
1947 db.dismiss_contact_notification(user_2, user_1)
1948 .await
1949 .unwrap_err();
1950 assert_eq!(
1951 db.get_contacts(user_1).await.unwrap(),
1952 &[
1953 Contact::Accepted {
1954 user_id: user_1,
1955 should_notify: false
1956 },
1957 Contact::Accepted {
1958 user_id: user_2,
1959 should_notify: true,
1960 },
1961 ]
1962 );
1963
1964 // Users can dismiss notifications of other users accepting their requests.
1965 db.dismiss_contact_notification(user_1, user_2)
1966 .await
1967 .unwrap();
1968 assert_eq!(
1969 db.get_contacts(user_1).await.unwrap(),
1970 &[
1971 Contact::Accepted {
1972 user_id: user_1,
1973 should_notify: false
1974 },
1975 Contact::Accepted {
1976 user_id: user_2,
1977 should_notify: false,
1978 },
1979 ]
1980 );
1981
1982 // Users send each other concurrent contact requests and
1983 // see that they are immediately accepted.
1984 db.send_contact_request(user_1, user_3).await.unwrap();
1985 db.send_contact_request(user_3, user_1).await.unwrap();
1986 assert_eq!(
1987 db.get_contacts(user_1).await.unwrap(),
1988 &[
1989 Contact::Accepted {
1990 user_id: user_1,
1991 should_notify: false
1992 },
1993 Contact::Accepted {
1994 user_id: user_2,
1995 should_notify: false,
1996 },
1997 Contact::Accepted {
1998 user_id: user_3,
1999 should_notify: false
2000 },
2001 ]
2002 );
2003 assert_eq!(
2004 db.get_contacts(user_3).await.unwrap(),
2005 &[
2006 Contact::Accepted {
2007 user_id: user_1,
2008 should_notify: false
2009 },
2010 Contact::Accepted {
2011 user_id: user_3,
2012 should_notify: false
2013 }
2014 ],
2015 );
2016
2017 // User declines a contact request. Both users see that it is gone.
2018 db.send_contact_request(user_2, user_3).await.unwrap();
2019 db.respond_to_contact_request(user_3, user_2, false)
2020 .await
2021 .unwrap();
2022 assert!(!db.has_contact(user_2, user_3).await.unwrap());
2023 assert!(!db.has_contact(user_3, user_2).await.unwrap());
2024 assert_eq!(
2025 db.get_contacts(user_2).await.unwrap(),
2026 &[
2027 Contact::Accepted {
2028 user_id: user_1,
2029 should_notify: false
2030 },
2031 Contact::Accepted {
2032 user_id: user_2,
2033 should_notify: false
2034 }
2035 ]
2036 );
2037 assert_eq!(
2038 db.get_contacts(user_3).await.unwrap(),
2039 &[
2040 Contact::Accepted {
2041 user_id: user_1,
2042 should_notify: false
2043 },
2044 Contact::Accepted {
2045 user_id: user_3,
2046 should_notify: false
2047 }
2048 ],
2049 );
2050 }
2051 }
2052
2053 #[tokio::test(flavor = "multi_thread")]
2054 async fn test_invite_codes() {
2055 let postgres = TestDb::postgres().await;
2056 let db = postgres.db();
2057 let user1 = db.create_user("user-1", None, false).await.unwrap();
2058
2059 // Initially, user 1 has no invite code
2060 assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);
2061
2062 // Setting invite count to 0 when no code is assigned does not assign a new code
2063 db.set_invite_count(user1, 0).await.unwrap();
2064 assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());
2065
2066 // User 1 creates an invite code that can be used twice.
2067 db.set_invite_count(user1, 2).await.unwrap();
2068 let (invite_code, invite_count) =
2069 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2070 assert_eq!(invite_count, 2);
2071
2072 // User 2 redeems the invite code and becomes a contact of user 1.
2073 let user2 = db
2074 .redeem_invite_code(&invite_code, "user-2", None)
2075 .await
2076 .unwrap();
2077 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2078 assert_eq!(invite_count, 1);
2079 assert_eq!(
2080 db.get_contacts(user1).await.unwrap(),
2081 [
2082 Contact::Accepted {
2083 user_id: user1,
2084 should_notify: false
2085 },
2086 Contact::Accepted {
2087 user_id: user2,
2088 should_notify: true
2089 }
2090 ]
2091 );
2092 assert_eq!(
2093 db.get_contacts(user2).await.unwrap(),
2094 [
2095 Contact::Accepted {
2096 user_id: user1,
2097 should_notify: false
2098 },
2099 Contact::Accepted {
2100 user_id: user2,
2101 should_notify: false
2102 }
2103 ]
2104 );
2105
2106 // User 3 redeems the invite code and becomes a contact of user 1.
2107 let user3 = db
2108 .redeem_invite_code(&invite_code, "user-3", None)
2109 .await
2110 .unwrap();
2111 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2112 assert_eq!(invite_count, 0);
2113 assert_eq!(
2114 db.get_contacts(user1).await.unwrap(),
2115 [
2116 Contact::Accepted {
2117 user_id: user1,
2118 should_notify: false
2119 },
2120 Contact::Accepted {
2121 user_id: user2,
2122 should_notify: true
2123 },
2124 Contact::Accepted {
2125 user_id: user3,
2126 should_notify: true
2127 }
2128 ]
2129 );
2130 assert_eq!(
2131 db.get_contacts(user3).await.unwrap(),
2132 [
2133 Contact::Accepted {
2134 user_id: user1,
2135 should_notify: false
2136 },
2137 Contact::Accepted {
2138 user_id: user3,
2139 should_notify: false
2140 },
2141 ]
2142 );
2143
2144 // Trying to reedem the code for the third time results in an error.
2145 db.redeem_invite_code(&invite_code, "user-4", None)
2146 .await
2147 .unwrap_err();
2148
2149 // Invite count can be updated after the code has been created.
2150 db.set_invite_count(user1, 2).await.unwrap();
2151 let (latest_code, invite_count) =
2152 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2153 assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
2154 assert_eq!(invite_count, 2);
2155
2156 // User 4 can now redeem the invite code and becomes a contact of user 1.
2157 let user4 = db
2158 .redeem_invite_code(&invite_code, "user-4", None)
2159 .await
2160 .unwrap();
2161 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2162 assert_eq!(invite_count, 1);
2163 assert_eq!(
2164 db.get_contacts(user1).await.unwrap(),
2165 [
2166 Contact::Accepted {
2167 user_id: user1,
2168 should_notify: false
2169 },
2170 Contact::Accepted {
2171 user_id: user2,
2172 should_notify: true
2173 },
2174 Contact::Accepted {
2175 user_id: user3,
2176 should_notify: true
2177 },
2178 Contact::Accepted {
2179 user_id: user4,
2180 should_notify: true
2181 }
2182 ]
2183 );
2184 assert_eq!(
2185 db.get_contacts(user4).await.unwrap(),
2186 [
2187 Contact::Accepted {
2188 user_id: user1,
2189 should_notify: false
2190 },
2191 Contact::Accepted {
2192 user_id: user4,
2193 should_notify: false
2194 },
2195 ]
2196 );
2197
2198 // An existing user cannot redeem invite codes.
2199 db.redeem_invite_code(&invite_code, "user-2", None)
2200 .await
2201 .unwrap_err();
2202 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2203 assert_eq!(invite_count, 1);
2204 }
2205
2206 pub struct TestDb {
2207 pub db: Option<Arc<dyn Db>>,
2208 pub url: String,
2209 }
2210
2211 impl TestDb {
2212 pub async fn postgres() -> Self {
2213 lazy_static! {
2214 static ref LOCK: Mutex<()> = Mutex::new(());
2215 }
2216
2217 let _guard = LOCK.lock();
2218 let mut rng = StdRng::from_entropy();
2219 let name = format!("zed-test-{}", rng.gen::<u128>());
2220 let url = format!("postgres://postgres@localhost/{}", name);
2221 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2222 Postgres::create_database(&url)
2223 .await
2224 .expect("failed to create test db");
2225 let db = PostgresDb::new(&url, 5).await.unwrap();
2226 let migrator = Migrator::new(migrations_path).await.unwrap();
2227 migrator.run(&db.pool).await.unwrap();
2228 Self {
2229 db: Some(Arc::new(db)),
2230 url,
2231 }
2232 }
2233
2234 pub fn fake(background: Arc<Background>) -> Self {
2235 Self {
2236 db: Some(Arc::new(FakeDb::new(background))),
2237 url: Default::default(),
2238 }
2239 }
2240
2241 pub fn db(&self) -> &Arc<dyn Db> {
2242 self.db.as_ref().unwrap()
2243 }
2244 }
2245
2246 impl Drop for TestDb {
2247 fn drop(&mut self) {
2248 if let Some(db) = self.db.take() {
2249 futures::executor::block_on(db.teardown(&self.url));
2250 }
2251 }
2252 }
2253
2254 pub struct FakeDb {
2255 background: Arc<Background>,
2256 pub users: Mutex<BTreeMap<UserId, User>>,
2257 pub projects: Mutex<BTreeMap<ProjectId, Project>>,
2258 pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
2259 pub orgs: Mutex<BTreeMap<OrgId, Org>>,
2260 pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
2261 pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
2262 pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
2263 pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
2264 pub contacts: Mutex<Vec<FakeContact>>,
2265 next_channel_message_id: Mutex<i32>,
2266 next_user_id: Mutex<i32>,
2267 next_org_id: Mutex<i32>,
2268 next_channel_id: Mutex<i32>,
2269 next_project_id: Mutex<i32>,
2270 }
2271
2272 #[derive(Debug)]
2273 pub struct FakeContact {
2274 pub requester_id: UserId,
2275 pub responder_id: UserId,
2276 pub accepted: bool,
2277 pub should_notify: bool,
2278 }
2279
2280 impl FakeDb {
2281 pub fn new(background: Arc<Background>) -> Self {
2282 Self {
2283 background,
2284 users: Default::default(),
2285 next_user_id: Mutex::new(0),
2286 projects: Default::default(),
2287 worktree_extensions: Default::default(),
2288 next_project_id: Mutex::new(1),
2289 orgs: Default::default(),
2290 next_org_id: Mutex::new(1),
2291 org_memberships: Default::default(),
2292 channels: Default::default(),
2293 next_channel_id: Mutex::new(1),
2294 channel_memberships: Default::default(),
2295 channel_messages: Default::default(),
2296 next_channel_message_id: Mutex::new(1),
2297 contacts: Default::default(),
2298 }
2299 }
2300 }
2301
2302 #[async_trait]
2303 impl Db for FakeDb {
2304 async fn create_user(
2305 &self,
2306 github_login: &str,
2307 email_address: Option<&str>,
2308 admin: bool,
2309 ) -> Result<UserId> {
2310 self.background.simulate_random_delay().await;
2311
2312 let mut users = self.users.lock();
2313 if let Some(user) = users
2314 .values()
2315 .find(|user| user.github_login == github_login)
2316 {
2317 Ok(user.id)
2318 } else {
2319 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2320 users.insert(
2321 user_id,
2322 User {
2323 id: user_id,
2324 github_login: github_login.to_string(),
2325 email_address: email_address.map(str::to_string),
2326 admin,
2327 invite_code: None,
2328 invite_count: 0,
2329 connected_once: false,
2330 },
2331 );
2332 Ok(user_id)
2333 }
2334 }
2335
2336 async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
2337 unimplemented!()
2338 }
2339
2340 async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
2341 unimplemented!()
2342 }
2343
2344 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
2345 unimplemented!()
2346 }
2347
2348 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2349 self.background.simulate_random_delay().await;
2350 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2351 }
2352
2353 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2354 self.background.simulate_random_delay().await;
2355 let users = self.users.lock();
2356 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2357 }
2358
2359 async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
2360 unimplemented!()
2361 }
2362
2363 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2364 self.background.simulate_random_delay().await;
2365 Ok(self
2366 .users
2367 .lock()
2368 .values()
2369 .find(|user| user.github_login == github_login)
2370 .cloned())
2371 }
2372
/// Changing the admin flag is not supported by this implementation: the body
/// is `unimplemented!()`, so any call panics. Both arguments are ignored.
async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
    unimplemented!()
}
2376
2377 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2378 self.background.simulate_random_delay().await;
2379 let mut users = self.users.lock();
2380 let mut user = users
2381 .get_mut(&id)
2382 .ok_or_else(|| anyhow!("user not found"))?;
2383 user.connected_once = connected_once;
2384 Ok(())
2385 }
2386
/// User deletion is not supported by this implementation: the body is
/// `unimplemented!()`, so any call panics. The argument is ignored.
async fn destroy_user(&self, _id: UserId) -> Result<()> {
    unimplemented!()
}
2390
2391 // invite codes
2392
/// Setting an invite count is not supported by this implementation: the body
/// is `unimplemented!()`, so any call panics. Both arguments are ignored.
async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
    unimplemented!()
}
2396
/// Always reports that the user has no invite code.
///
/// The id is ignored; after the simulated delay the result is
/// unconditionally `Ok(None)`.
async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
    self.background.simulate_random_delay().await;
    Ok(None)
}
2401
/// Invite-code lookup is not supported by this implementation: the body is
/// `unimplemented!()`, so any call panics. The argument is ignored.
async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
    unimplemented!()
}
2405
/// Invite-code redemption is not supported by this implementation: the body
/// is `unimplemented!()`, so any call panics. All arguments are ignored.
async fn redeem_invite_code(
    &self,
    _code: &str,
    _login: &str,
    _email_address: Option<&str>,
) -> Result<UserId> {
    unimplemented!()
}
2414
2415 // projects
2416
2417 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2418 self.background.simulate_random_delay().await;
2419 if !self.users.lock().contains_key(&host_user_id) {
2420 Err(anyhow!("no such user"))?;
2421 }
2422
2423 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2424 self.projects.lock().insert(
2425 project_id,
2426 Project {
2427 id: project_id,
2428 host_user_id,
2429 unregistered: false,
2430 },
2431 );
2432 Ok(project_id)
2433 }
2434
2435 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2436 self.background.simulate_random_delay().await;
2437 self.projects
2438 .lock()
2439 .get_mut(&project_id)
2440 .ok_or_else(|| anyhow!("no such project"))?
2441 .unregistered = true;
2442 Ok(())
2443 }
2444
2445 async fn update_worktree_extensions(
2446 &self,
2447 project_id: ProjectId,
2448 worktree_id: u64,
2449 extensions: HashMap<String, u32>,
2450 ) -> Result<()> {
2451 self.background.simulate_random_delay().await;
2452 if !self.projects.lock().contains_key(&project_id) {
2453 Err(anyhow!("no such project"))?;
2454 }
2455
2456 for (extension, count) in extensions {
2457 self.worktree_extensions
2458 .lock()
2459 .insert((project_id, worktree_id, extension), count);
2460 }
2461
2462 Ok(())
2463 }
2464
/// Reading back project extensions is not supported by this implementation:
/// the body is `unimplemented!()`, so any call panics. The argument is ignored.
async fn get_project_extensions(
    &self,
    _project_id: ProjectId,
) -> Result<HashMap<u64, HashMap<String, usize>>> {
    unimplemented!()
}
2471
/// Activity recording is not supported by this implementation: the body is
/// `unimplemented!()`, so any call panics. Both arguments are ignored.
async fn record_user_activity(
    &self,
    _time_period: Range<OffsetDateTime>,
    _active_projects: &[(UserId, ProjectId)],
) -> Result<()> {
    unimplemented!()
}
2479
/// Activity summaries are not supported by this implementation: the body is
/// `unimplemented!()`, so any call panics. Both arguments are ignored.
async fn get_top_users_activity_summary(
    &self,
    _time_period: Range<OffsetDateTime>,
    _limit: usize,
) -> Result<Vec<UserActivitySummary>> {
    unimplemented!()
}
2487
/// Activity timelines are not supported by this implementation: the body is
/// `unimplemented!()`, so any call panics. Both arguments are ignored.
async fn get_user_activity_timeline(
    &self,
    _time_period: Range<OffsetDateTime>,
    _user_id: UserId,
) -> Result<Vec<UserActivityPeriod>> {
    unimplemented!()
}
2495
2496 // contacts
2497
2498 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2499 self.background.simulate_random_delay().await;
2500 let mut contacts = vec![Contact::Accepted {
2501 user_id: id,
2502 should_notify: false,
2503 }];
2504
2505 for contact in self.contacts.lock().iter() {
2506 if contact.requester_id == id {
2507 if contact.accepted {
2508 contacts.push(Contact::Accepted {
2509 user_id: contact.responder_id,
2510 should_notify: contact.should_notify,
2511 });
2512 } else {
2513 contacts.push(Contact::Outgoing {
2514 user_id: contact.responder_id,
2515 });
2516 }
2517 } else if contact.responder_id == id {
2518 if contact.accepted {
2519 contacts.push(Contact::Accepted {
2520 user_id: contact.requester_id,
2521 should_notify: false,
2522 });
2523 } else {
2524 contacts.push(Contact::Incoming {
2525 user_id: contact.requester_id,
2526 should_notify: contact.should_notify,
2527 });
2528 }
2529 }
2530 }
2531
2532 contacts.sort_unstable_by_key(|contact| contact.user_id());
2533 Ok(contacts)
2534 }
2535
2536 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2537 self.background.simulate_random_delay().await;
2538 Ok(self.contacts.lock().iter().any(|contact| {
2539 contact.accepted
2540 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2541 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2542 }))
2543 }
2544
2545 async fn send_contact_request(
2546 &self,
2547 requester_id: UserId,
2548 responder_id: UserId,
2549 ) -> Result<()> {
2550 self.background.simulate_random_delay().await;
2551 let mut contacts = self.contacts.lock();
2552 for contact in contacts.iter_mut() {
2553 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2554 if contact.accepted {
2555 Err(anyhow!("contact already exists"))?;
2556 } else {
2557 Err(anyhow!("contact already requested"))?;
2558 }
2559 }
2560 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2561 if contact.accepted {
2562 Err(anyhow!("contact already exists"))?;
2563 } else {
2564 contact.accepted = true;
2565 contact.should_notify = false;
2566 return Ok(());
2567 }
2568 }
2569 }
2570 contacts.push(FakeContact {
2571 requester_id,
2572 responder_id,
2573 accepted: false,
2574 should_notify: true,
2575 });
2576 Ok(())
2577 }
2578
2579 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2580 self.background.simulate_random_delay().await;
2581 self.contacts.lock().retain(|contact| {
2582 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2583 });
2584 Ok(())
2585 }
2586
2587 async fn dismiss_contact_notification(
2588 &self,
2589 user_id: UserId,
2590 contact_user_id: UserId,
2591 ) -> Result<()> {
2592 self.background.simulate_random_delay().await;
2593 let mut contacts = self.contacts.lock();
2594 for contact in contacts.iter_mut() {
2595 if contact.requester_id == contact_user_id
2596 && contact.responder_id == user_id
2597 && !contact.accepted
2598 {
2599 contact.should_notify = false;
2600 return Ok(());
2601 }
2602 if contact.requester_id == user_id
2603 && contact.responder_id == contact_user_id
2604 && contact.accepted
2605 {
2606 contact.should_notify = false;
2607 return Ok(());
2608 }
2609 }
2610 Err(anyhow!("no such notification"))?
2611 }
2612
2613 async fn respond_to_contact_request(
2614 &self,
2615 responder_id: UserId,
2616 requester_id: UserId,
2617 accept: bool,
2618 ) -> Result<()> {
2619 self.background.simulate_random_delay().await;
2620 let mut contacts = self.contacts.lock();
2621 for (ix, contact) in contacts.iter_mut().enumerate() {
2622 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2623 if contact.accepted {
2624 Err(anyhow!("contact already confirmed"))?;
2625 }
2626 if accept {
2627 contact.accepted = true;
2628 contact.should_notify = true;
2629 } else {
2630 contacts.remove(ix);
2631 }
2632 return Ok(());
2633 }
2634 }
2635 Err(anyhow!("no such contact request"))?
2636 }
2637
/// Access-token storage is not supported by this implementation: the body is
/// `unimplemented!()`, so any call panics. All arguments are ignored.
async fn create_access_token_hash(
    &self,
    _user_id: UserId,
    _access_token_hash: &str,
    _max_access_token_count: usize,
) -> Result<()> {
    unimplemented!()
}
2646
/// Access-token lookup is not supported by this implementation: the body is
/// `unimplemented!()`, so any call panics. The argument is ignored.
async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
    unimplemented!()
}
2650
/// Org lookup by slug is not supported by this implementation: the body is
/// `unimplemented!()`, so any call panics. The argument is ignored.
async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
    unimplemented!()
}
2654
2655 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2656 self.background.simulate_random_delay().await;
2657 let mut orgs = self.orgs.lock();
2658 if orgs.values().any(|org| org.slug == slug) {
2659 Err(anyhow!("org already exists"))?
2660 } else {
2661 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2662 orgs.insert(
2663 org_id,
2664 Org {
2665 id: org_id,
2666 name: name.to_string(),
2667 slug: slug.to_string(),
2668 },
2669 );
2670 Ok(org_id)
2671 }
2672 }
2673
2674 async fn add_org_member(
2675 &self,
2676 org_id: OrgId,
2677 user_id: UserId,
2678 is_admin: bool,
2679 ) -> Result<()> {
2680 self.background.simulate_random_delay().await;
2681 if !self.orgs.lock().contains_key(&org_id) {
2682 Err(anyhow!("org does not exist"))?;
2683 }
2684 if !self.users.lock().contains_key(&user_id) {
2685 Err(anyhow!("user does not exist"))?;
2686 }
2687
2688 self.org_memberships
2689 .lock()
2690 .entry((org_id, user_id))
2691 .or_insert(is_admin);
2692 Ok(())
2693 }
2694
2695 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2696 self.background.simulate_random_delay().await;
2697 if !self.orgs.lock().contains_key(&org_id) {
2698 Err(anyhow!("org does not exist"))?;
2699 }
2700
2701 let mut channels = self.channels.lock();
2702 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2703 channels.insert(
2704 channel_id,
2705 Channel {
2706 id: channel_id,
2707 name: name.to_string(),
2708 owner_id: org_id.0,
2709 owner_is_user: false,
2710 },
2711 );
2712 Ok(channel_id)
2713 }
2714
2715 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2716 self.background.simulate_random_delay().await;
2717 Ok(self
2718 .channels
2719 .lock()
2720 .values()
2721 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2722 .cloned()
2723 .collect())
2724 }
2725
2726 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2727 self.background.simulate_random_delay().await;
2728 let channels = self.channels.lock();
2729 let memberships = self.channel_memberships.lock();
2730 Ok(channels
2731 .values()
2732 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2733 .cloned()
2734 .collect())
2735 }
2736
2737 async fn can_user_access_channel(
2738 &self,
2739 user_id: UserId,
2740 channel_id: ChannelId,
2741 ) -> Result<bool> {
2742 self.background.simulate_random_delay().await;
2743 Ok(self
2744 .channel_memberships
2745 .lock()
2746 .contains_key(&(channel_id, user_id)))
2747 }
2748
2749 async fn add_channel_member(
2750 &self,
2751 channel_id: ChannelId,
2752 user_id: UserId,
2753 is_admin: bool,
2754 ) -> Result<()> {
2755 self.background.simulate_random_delay().await;
2756 if !self.channels.lock().contains_key(&channel_id) {
2757 Err(anyhow!("channel does not exist"))?;
2758 }
2759 if !self.users.lock().contains_key(&user_id) {
2760 Err(anyhow!("user does not exist"))?;
2761 }
2762
2763 self.channel_memberships
2764 .lock()
2765 .entry((channel_id, user_id))
2766 .or_insert(is_admin);
2767 Ok(())
2768 }
2769
2770 async fn create_channel_message(
2771 &self,
2772 channel_id: ChannelId,
2773 sender_id: UserId,
2774 body: &str,
2775 timestamp: OffsetDateTime,
2776 nonce: u128,
2777 ) -> Result<MessageId> {
2778 self.background.simulate_random_delay().await;
2779 if !self.channels.lock().contains_key(&channel_id) {
2780 Err(anyhow!("channel does not exist"))?;
2781 }
2782 if !self.users.lock().contains_key(&sender_id) {
2783 Err(anyhow!("user does not exist"))?;
2784 }
2785
2786 let mut messages = self.channel_messages.lock();
2787 if let Some(message) = messages
2788 .values()
2789 .find(|message| message.nonce.as_u128() == nonce)
2790 {
2791 Ok(message.id)
2792 } else {
2793 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2794 messages.insert(
2795 message_id,
2796 ChannelMessage {
2797 id: message_id,
2798 channel_id,
2799 sender_id,
2800 body: body.to_string(),
2801 sent_at: timestamp,
2802 nonce: Uuid::from_u128(nonce),
2803 },
2804 );
2805 Ok(message_id)
2806 }
2807 }
2808
2809 async fn get_channel_messages(
2810 &self,
2811 channel_id: ChannelId,
2812 count: usize,
2813 before_id: Option<MessageId>,
2814 ) -> Result<Vec<ChannelMessage>> {
2815 self.background.simulate_random_delay().await;
2816 let mut messages = self
2817 .channel_messages
2818 .lock()
2819 .values()
2820 .rev()
2821 .filter(|message| {
2822 message.channel_id == channel_id
2823 && message.id < before_id.unwrap_or(MessageId::MAX)
2824 })
2825 .take(count)
2826 .cloned()
2827 .collect::<Vec<_>>();
2828 messages.sort_unstable_by_key(|message| message.id);
2829 Ok(messages)
2830 }
2831
/// No-op: the body is empty, so the argument is ignored and nothing is torn down.
async fn teardown(&self, _: &str) {}
2833
/// Exposes this value as a `FakeDb`; always returns `Some(self)`.
/// Only compiled when `cfg(test)` is active.
#[cfg(test)]
fn as_fake(&self) -> Option<&FakeDb> {
    Some(self)
}
2838 }
2839}