1use std::{ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use nanoid::nanoid;
10use serde::Serialize;
11pub use sqlx::postgres::PgPoolOptions as DbOptions;
12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
13use time::OffsetDateTime;
14
/// Storage interface used by the collaboration server: users, invite codes,
/// projects, contacts, access tokens, orgs, channels and channel messages.
///
/// [`PostgresDb`] implements it on top of a Postgres connection pool; test
/// builds can additionally provide a fake implementation (see `as_fake`).
#[async_trait]
pub trait Db: Send + Sync {
    /// Inserts a user, or returns the id of the existing user that has the
    /// same GitHub login.
    async fn create_user(
        &self,
        github_login: &str,
        email_address: Option<&str>,
        admin: bool,
    ) -> Result<UserId>;
    /// Returns page `page` (zero-based) of users ordered by GitHub login,
    /// with at most `limit` users per page.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Bulk-upserts users given as `(github_login, email_address, invite_count)`
    /// tuples and returns the affected ids.
    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
    /// Returns up to `limit` users whose GitHub login fuzzily matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Looks up one user by id; `None` if it does not exist.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Looks up every user whose id is in `ids`; unknown ids are skipped.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns users with no invites left, filtered by whether they were
    /// themselves invited by another user.
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    /// Looks up a user by exact GitHub login.
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    /// Sets or clears the admin flag of a user.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets the `connected_once` flag of a user.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes a user and the access tokens belonging to them.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Sets how many invites a user may still hand out, assigning an invite
    /// code first if the count is positive and the user has none.
    async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, if they have a code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user owning `code`, or an HTTP 404 error if none does.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Consumes one invite from the owner of `code`, creates the invited user
    /// and makes the two users contacts. Returns the new user's id.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, usize>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_project_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn summarize_project_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Returns the user's contacts (including the user themself) sorted by user id.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Whether the two users are accepted contacts, in either order.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Sends a contact request, or accepts a pending request in the opposite direction.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Removes the contact (accepted or pending) between the two users.
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Clears the pending notification `responder_id` has about `requester_id`.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    /// Accepts or declines the contact request sent by `requester_id` to `responder_id`.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    /// Stores a new access token hash for the user, keeping only the newest
    /// `max_access_token_count` hashes.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    /// Returns all stored access token hashes for the user.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
    /// Looks up an organization by slug.
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    /// Creates an organization and returns its id.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    /// Adds a user to an organization (no-op if already a member).
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    /// Creates a channel owned by an organization.
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    /// Lists the channels owned by an organization. Like the methods above,
    /// this one is only compiled for tests or with the `seed-support` feature.
    #[cfg(any(test, feature = "seed-support"))]
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    /// Lists the channels the user is a member of.
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    /// Whether the user is a member of the channel.
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;
    /// Adds a user to a channel (no-op if already a member).
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    /// Stores a channel message; a repeated `nonce` yields the existing message id.
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Returns up to `count` of the latest messages older than `before_id`, oldest first.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;
    /// Test-only: releases the database behind `url`.
    #[cfg(test)]
    async fn teardown(&self, url: &str);
    /// Test-only: downcasts to the fake implementation, if this is one.
    #[cfg(test)]
    fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
}
144
145pub struct PostgresDb {
146 pool: sqlx::PgPool,
147}
148
149impl PostgresDb {
150 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
151 let pool = DbOptions::new()
152 .max_connections(max_connections)
153 .connect(&url)
154 .await
155 .context("failed to connect to postgres database")?;
156 Ok(Self { pool })
157 }
158}
159
160#[async_trait]
161impl Db for PostgresDb {
162 // users
163
164 async fn create_user(
165 &self,
166 github_login: &str,
167 email_address: Option<&str>,
168 admin: bool,
169 ) -> Result<UserId> {
170 let query = "
171 INSERT INTO users (github_login, email_address, admin)
172 VALUES ($1, $2, $3)
173 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
174 RETURNING id
175 ";
176 Ok(sqlx::query_scalar(query)
177 .bind(github_login)
178 .bind(email_address)
179 .bind(admin)
180 .fetch_one(&self.pool)
181 .await
182 .map(UserId)?)
183 }
184
185 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
186 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
187 Ok(sqlx::query_as(query)
188 .bind(limit as i32)
189 .bind((page * limit) as i32)
190 .fetch_all(&self.pool)
191 .await?)
192 }
193
194 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
195 let mut query = QueryBuilder::new(
196 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
197 );
198 query.push_values(
199 users,
200 |mut query, (github_login, email_address, invite_count)| {
201 query
202 .push_bind(github_login)
203 .push_bind(email_address)
204 .push_bind(false)
205 .push_bind(nanoid!(16))
206 .push_bind(invite_count as i32);
207 },
208 );
209 query.push(
210 "
211 ON CONFLICT (github_login) DO UPDATE SET
212 github_login = excluded.github_login,
213 invite_count = excluded.invite_count,
214 invite_code = CASE WHEN users.invite_code IS NULL
215 THEN excluded.invite_code
216 ELSE users.invite_code
217 END
218 RETURNING id
219 ",
220 );
221
222 let rows = query.build().fetch_all(&self.pool).await?;
223 Ok(rows
224 .into_iter()
225 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
226 .collect())
227 }
228
229 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
230 let like_string = fuzzy_like_string(name_query);
231 let query = "
232 SELECT users.*
233 FROM users
234 WHERE github_login ILIKE $1
235 ORDER BY github_login <-> $2
236 LIMIT $3
237 ";
238 Ok(sqlx::query_as(query)
239 .bind(like_string)
240 .bind(name_query)
241 .bind(limit as i32)
242 .fetch_all(&self.pool)
243 .await?)
244 }
245
246 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
247 let users = self.get_users_by_ids(vec![id]).await?;
248 Ok(users.into_iter().next())
249 }
250
251 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
252 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
253 let query = "
254 SELECT users.*
255 FROM users
256 WHERE users.id = ANY ($1)
257 ";
258 Ok(sqlx::query_as(query)
259 .bind(&ids)
260 .fetch_all(&self.pool)
261 .await?)
262 }
263
264 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
265 let query = format!(
266 "
267 SELECT users.*
268 FROM users
269 WHERE invite_count = 0
270 AND inviter_id IS{} NULL
271 ",
272 if invited_by_another_user { " NOT" } else { "" }
273 );
274
275 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
276 }
277
278 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
279 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
280 Ok(sqlx::query_as(query)
281 .bind(github_login)
282 .fetch_optional(&self.pool)
283 .await?)
284 }
285
286 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
287 let query = "UPDATE users SET admin = $1 WHERE id = $2";
288 Ok(sqlx::query(query)
289 .bind(is_admin)
290 .bind(id.0)
291 .execute(&self.pool)
292 .await
293 .map(drop)?)
294 }
295
296 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
297 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
298 Ok(sqlx::query(query)
299 .bind(connected_once)
300 .bind(id.0)
301 .execute(&self.pool)
302 .await
303 .map(drop)?)
304 }
305
306 async fn destroy_user(&self, id: UserId) -> Result<()> {
307 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
308 sqlx::query(query)
309 .bind(id.0)
310 .execute(&self.pool)
311 .await
312 .map(drop)?;
313 let query = "DELETE FROM users WHERE id = $1;";
314 Ok(sqlx::query(query)
315 .bind(id.0)
316 .execute(&self.pool)
317 .await
318 .map(drop)?)
319 }
320
321 // invite codes
322
323 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
324 let mut tx = self.pool.begin().await?;
325 if count > 0 {
326 sqlx::query(
327 "
328 UPDATE users
329 SET invite_code = $1
330 WHERE id = $2 AND invite_code IS NULL
331 ",
332 )
333 .bind(nanoid!(16))
334 .bind(id)
335 .execute(&mut tx)
336 .await?;
337 }
338
339 sqlx::query(
340 "
341 UPDATE users
342 SET invite_count = $1
343 WHERE id = $2
344 ",
345 )
346 .bind(count as i32)
347 .bind(id)
348 .execute(&mut tx)
349 .await?;
350 tx.commit().await?;
351 Ok(())
352 }
353
354 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
355 let result: Option<(String, i32)> = sqlx::query_as(
356 "
357 SELECT invite_code, invite_count
358 FROM users
359 WHERE id = $1 AND invite_code IS NOT NULL
360 ",
361 )
362 .bind(id)
363 .fetch_optional(&self.pool)
364 .await?;
365 if let Some((code, count)) = result {
366 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
367 } else {
368 Ok(None)
369 }
370 }
371
372 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
373 sqlx::query_as(
374 "
375 SELECT *
376 FROM users
377 WHERE invite_code = $1
378 ",
379 )
380 .bind(code)
381 .fetch_optional(&self.pool)
382 .await?
383 .ok_or_else(|| {
384 Error::Http(
385 StatusCode::NOT_FOUND,
386 "that invite code does not exist".to_string(),
387 )
388 })
389 }
390
    /// Redeems invite `code` to sign up a new user called `login`.
    ///
    /// Within one transaction this:
    /// 1. decrements the inviter's `invite_count` (only while it is positive);
    /// 2. inserts the new, non-admin user with `inviter_id` set;
    /// 3. records an accepted contact between inviter and invitee.
    ///
    /// Returns the new user's id. Any error drops the transaction uncommitted.
    ///
    /// # Errors
    /// - `Error::Http(UNAUTHORIZED, "no invites remaining")` if the code exists
    ///   but its owner has no invites left.
    /// - `Error::Http(NOT_FOUND, "invite code not found")` if no user has the code.
    /// - Database errors, e.g. if inserting the user fails.
    async fn redeem_invite_code(
        &self,
        code: &str,
        login: &str,
        email_address: Option<&str>,
    ) -> Result<UserId> {
        let mut tx = self.pool.begin().await?;

        // Claim one invite atomically; the `invite_count > 0` guard means no
        // row is returned when the code is unknown or exhausted.
        let inviter_id: Option<UserId> = sqlx::query_scalar(
            "
            UPDATE users
            SET invite_count = invite_count - 1
            WHERE
                invite_code = $1 AND
                invite_count > 0
            RETURNING id
            ",
        )
        .bind(code)
        .fetch_optional(&mut tx)
        .await?;

        let inviter_id = match inviter_id {
            Some(inviter_id) => inviter_id,
            None => {
                // Distinguish "code exists but exhausted" from "unknown code".
                if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
                    .bind(code)
                    .fetch_optional(&mut tx)
                    .await?
                    .is_some()
                {
                    Err(Error::Http(
                        StatusCode::UNAUTHORIZED,
                        "no invites remaining".to_string(),
                    ))?
                } else {
                    Err(Error::Http(
                        StatusCode::NOT_FOUND,
                        "invite code not found".to_string(),
                    ))?
                }
            }
        };

        let invitee_id = sqlx::query_scalar(
            "
            INSERT INTO users
                (github_login, email_address, admin, inviter_id)
            VALUES
                ($1, $2, 'f', $3)
            RETURNING id
            ",
        )
        .bind(login)
        .bind(email_address)
        .bind(inviter_id)
        .fetch_one(&mut tx)
        .await
        .map(UserId)?;

        // Inviter is stored as `user_id_a`, with `a_to_b` set and the contact
        // already accepted.
        // NOTE(review): other contact queries order pairs so that
        // user_id_a < user_id_b. This relies on the freshly inserted invitee
        // getting a larger id than the inviter (presumably a serial id); confirm.
        sqlx::query(
            "
            INSERT INTO contacts
                (user_id_a, user_id_b, a_to_b, should_notify, accepted)
            VALUES
                ($1, $2, 't', 't', 't')
            ",
        )
        .bind(inviter_id)
        .bind(invitee_id)
        .execute(&mut tx)
        .await?;

        tx.commit().await?;
        Ok(invitee_id)
    }
467
468 // projects
469
470 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
471 Ok(sqlx::query_scalar(
472 "
473 INSERT INTO projects(host_user_id)
474 VALUES ($1)
475 RETURNING id
476 ",
477 )
478 .bind(host_user_id)
479 .fetch_one(&self.pool)
480 .await
481 .map(ProjectId)?)
482 }
483
484 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
485 sqlx::query(
486 "
487 UPDATE projects
488 SET unregistered = 't'
489 WHERE id = $1
490 ",
491 )
492 .bind(project_id)
493 .execute(&self.pool)
494 .await?;
495 Ok(())
496 }
497
498 async fn update_worktree_extensions(
499 &self,
500 project_id: ProjectId,
501 worktree_id: u64,
502 extensions: HashMap<String, usize>,
503 ) -> Result<()> {
504 if extensions.is_empty() {
505 return Ok(());
506 }
507
508 let mut query = QueryBuilder::new(
509 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
510 );
511 query.push_values(extensions, |mut query, (extension, count)| {
512 query
513 .push_bind(project_id)
514 .push_bind(worktree_id as i32)
515 .push_bind(extension)
516 .push_bind(count as i32);
517 });
518 query.push(
519 "
520 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
521 count = excluded.count
522 ",
523 );
524 query.build().execute(&self.pool).await?;
525
526 Ok(())
527 }
528
529 async fn get_project_extensions(
530 &self,
531 project_id: ProjectId,
532 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
533 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
534 struct WorktreeExtension {
535 worktree_id: i32,
536 extension: String,
537 count: i32,
538 }
539
540 let query = "
541 SELECT worktree_id, extension, count
542 FROM worktree_extensions
543 WHERE project_id = $1
544 ";
545 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
546 .bind(&project_id)
547 .fetch_all(&self.pool)
548 .await?;
549
550 let mut extension_counts = HashMap::default();
551 for count in counts {
552 extension_counts
553 .entry(count.worktree_id as u64)
554 .or_insert(HashMap::default())
555 .insert(count.extension, count.count as usize);
556 }
557 Ok(extension_counts)
558 }
559
560 async fn record_project_activity(
561 &self,
562 time_period: Range<OffsetDateTime>,
563 projects: &[(UserId, ProjectId)],
564 ) -> Result<()> {
565 let query = "
566 INSERT INTO project_activity_periods
567 (ended_at, duration_millis, user_id, project_id)
568 VALUES
569 ($1, $2, $3, $4);
570 ";
571
572 let mut tx = self.pool.begin().await?;
573 let duration_millis =
574 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
575 for (user_id, project_id) in projects {
576 sqlx::query(query)
577 .bind(time_period.end)
578 .bind(duration_millis)
579 .bind(user_id)
580 .bind(project_id)
581 .execute(&mut tx)
582 .await?;
583 }
584 tx.commit().await?;
585
586 Ok(())
587 }
588
589 async fn summarize_project_activity(
590 &self,
591 time_period: Range<OffsetDateTime>,
592 max_user_count: usize,
593 ) -> Result<Vec<UserActivitySummary>> {
594 let query = "
595 WITH
596 project_durations AS (
597 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
598 FROM project_activity_periods
599 WHERE $1 <= ended_at AND ended_at <= $2
600 GROUP BY user_id, project_id
601 ),
602 user_durations AS (
603 SELECT user_id, SUM(project_duration) as total_duration
604 FROM project_durations
605 GROUP BY user_id
606 ORDER BY total_duration DESC
607 LIMIT $3
608 )
609 SELECT user_durations.user_id, users.github_login, project_id, project_duration
610 FROM user_durations, project_durations, users
611 WHERE
612 user_durations.user_id = project_durations.user_id AND
613 user_durations.user_id = users.id
614 ORDER BY user_id ASC, project_duration DESC
615 ";
616
617 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64)>(query)
618 .bind(time_period.start)
619 .bind(time_period.end)
620 .bind(max_user_count as i32)
621 .fetch(&self.pool);
622
623 let mut result = Vec::<UserActivitySummary>::new();
624 while let Some(row) = rows.next().await {
625 let (user_id, github_login, project_id, duration_millis) = row?;
626 let project_id = project_id;
627 let duration = Duration::from_millis(duration_millis as u64);
628 if let Some(last_summary) = result.last_mut() {
629 if last_summary.id == user_id {
630 last_summary.project_activity.push((project_id, duration));
631 continue;
632 }
633 }
634 result.push(UserActivitySummary {
635 id: user_id,
636 project_activity: vec![(project_id, duration)],
637 github_login,
638 });
639 }
640
641 Ok(result)
642 }
643
644 // contacts
645
646 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
647 let query = "
648 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
649 FROM contacts
650 WHERE user_id_a = $1 OR user_id_b = $1;
651 ";
652
653 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
654 .bind(user_id)
655 .fetch(&self.pool);
656
657 let mut contacts = vec![Contact::Accepted {
658 user_id,
659 should_notify: false,
660 }];
661 while let Some(row) = rows.next().await {
662 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
663
664 if user_id_a == user_id {
665 if accepted {
666 contacts.push(Contact::Accepted {
667 user_id: user_id_b,
668 should_notify: should_notify && a_to_b,
669 });
670 } else if a_to_b {
671 contacts.push(Contact::Outgoing { user_id: user_id_b })
672 } else {
673 contacts.push(Contact::Incoming {
674 user_id: user_id_b,
675 should_notify,
676 });
677 }
678 } else {
679 if accepted {
680 contacts.push(Contact::Accepted {
681 user_id: user_id_a,
682 should_notify: should_notify && !a_to_b,
683 });
684 } else if a_to_b {
685 contacts.push(Contact::Incoming {
686 user_id: user_id_a,
687 should_notify,
688 });
689 } else {
690 contacts.push(Contact::Outgoing { user_id: user_id_a });
691 }
692 }
693 }
694
695 contacts.sort_unstable_by_key(|contact| contact.user_id());
696
697 Ok(contacts)
698 }
699
700 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
701 let (id_a, id_b) = if user_id_1 < user_id_2 {
702 (user_id_1, user_id_2)
703 } else {
704 (user_id_2, user_id_1)
705 };
706
707 let query = "
708 SELECT 1 FROM contacts
709 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
710 LIMIT 1
711 ";
712 Ok(sqlx::query_scalar::<_, i32>(query)
713 .bind(id_a.0)
714 .bind(id_b.0)
715 .fetch_optional(&self.pool)
716 .await?
717 .is_some())
718 }
719
    /// Records a contact request from `sender_id` to `receiver_id`.
    ///
    /// Rows are keyed by `(user_id_a, user_id_b)` with `user_id_a <= user_id_b`,
    /// and `a_to_b` says which side sent the request. If the receiver already
    /// has a pending request to the sender, that row is instead marked accepted
    /// and its notification cleared.
    ///
    /// # Errors
    /// Returns "contact already requested" when no row was inserted or updated.
    /// That happens when the sender already has a pending request to the
    /// receiver, or when the users are already contacts.
    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
            (sender_id, receiver_id, true)
        } else {
            (receiver_id, sender_id, false)
        };
        // NOTE(review): on conflict, contacts.user_id_a already equals
        // excluded.user_id_a. So the first disjunct of the WHERE clause can only
        // hold when sender == receiver. The second disjunct handles a pending
        // request in the opposite direction.
        let query = "
            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
            VALUES ($1, $2, $3, 'f', 't')
            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
            SET
                accepted = 't',
                should_notify = 'f'
            WHERE
                NOT contacts.accepted AND
                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
        ";
        let result = sqlx::query(query)
            .bind(id_a.0)
            .bind(id_b.0)
            .bind(a_to_b)
            .execute(&self.pool)
            .await?;

        // Exactly one row is affected on a fresh insert or an accepting update.
        if result.rows_affected() == 1 {
            Ok(())
        } else {
            Err(anyhow!("contact already requested"))?
        }
    }
751
752 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
753 let (id_a, id_b) = if responder_id < requester_id {
754 (responder_id, requester_id)
755 } else {
756 (requester_id, responder_id)
757 };
758 let query = "
759 DELETE FROM contacts
760 WHERE user_id_a = $1 AND user_id_b = $2;
761 ";
762 let result = sqlx::query(query)
763 .bind(id_a.0)
764 .bind(id_b.0)
765 .execute(&self.pool)
766 .await?;
767
768 if result.rows_affected() == 1 {
769 Ok(())
770 } else {
771 Err(anyhow!("no such contact"))?
772 }
773 }
774
775 async fn dismiss_contact_notification(
776 &self,
777 user_id: UserId,
778 contact_user_id: UserId,
779 ) -> Result<()> {
780 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
781 (user_id, contact_user_id, true)
782 } else {
783 (contact_user_id, user_id, false)
784 };
785
786 let query = "
787 UPDATE contacts
788 SET should_notify = 'f'
789 WHERE
790 user_id_a = $1 AND user_id_b = $2 AND
791 (
792 (a_to_b = $3 AND accepted) OR
793 (a_to_b != $3 AND NOT accepted)
794 );
795 ";
796
797 let result = sqlx::query(query)
798 .bind(id_a.0)
799 .bind(id_b.0)
800 .bind(a_to_b)
801 .execute(&self.pool)
802 .await?;
803
804 if result.rows_affected() == 0 {
805 Err(anyhow!("no such contact request"))?;
806 }
807
808 Ok(())
809 }
810
811 async fn respond_to_contact_request(
812 &self,
813 responder_id: UserId,
814 requester_id: UserId,
815 accept: bool,
816 ) -> Result<()> {
817 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
818 (responder_id, requester_id, false)
819 } else {
820 (requester_id, responder_id, true)
821 };
822 let result = if accept {
823 let query = "
824 UPDATE contacts
825 SET accepted = 't', should_notify = 't'
826 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
827 ";
828 sqlx::query(query)
829 .bind(id_a.0)
830 .bind(id_b.0)
831 .bind(a_to_b)
832 .execute(&self.pool)
833 .await?
834 } else {
835 let query = "
836 DELETE FROM contacts
837 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
838 ";
839 sqlx::query(query)
840 .bind(id_a.0)
841 .bind(id_b.0)
842 .bind(a_to_b)
843 .execute(&self.pool)
844 .await?
845 };
846 if result.rows_affected() == 1 {
847 Ok(())
848 } else {
849 Err(anyhow!("no such contact request"))?
850 }
851 }
852
853 // access tokens
854
855 async fn create_access_token_hash(
856 &self,
857 user_id: UserId,
858 access_token_hash: &str,
859 max_access_token_count: usize,
860 ) -> Result<()> {
861 let insert_query = "
862 INSERT INTO access_tokens (user_id, hash)
863 VALUES ($1, $2);
864 ";
865 let cleanup_query = "
866 DELETE FROM access_tokens
867 WHERE id IN (
868 SELECT id from access_tokens
869 WHERE user_id = $1
870 ORDER BY id DESC
871 OFFSET $3
872 )
873 ";
874
875 let mut tx = self.pool.begin().await?;
876 sqlx::query(insert_query)
877 .bind(user_id.0)
878 .bind(access_token_hash)
879 .execute(&mut tx)
880 .await?;
881 sqlx::query(cleanup_query)
882 .bind(user_id.0)
883 .bind(access_token_hash)
884 .bind(max_access_token_count as i32)
885 .execute(&mut tx)
886 .await?;
887 Ok(tx.commit().await?)
888 }
889
890 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
891 let query = "
892 SELECT hash
893 FROM access_tokens
894 WHERE user_id = $1
895 ORDER BY id DESC
896 ";
897 Ok(sqlx::query_scalar(query)
898 .bind(user_id.0)
899 .fetch_all(&self.pool)
900 .await?)
901 }
902
903 // orgs
904
905 #[allow(unused)] // Help rust-analyzer
906 #[cfg(any(test, feature = "seed-support"))]
907 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
908 let query = "
909 SELECT *
910 FROM orgs
911 WHERE slug = $1
912 ";
913 Ok(sqlx::query_as(query)
914 .bind(slug)
915 .fetch_optional(&self.pool)
916 .await?)
917 }
918
919 #[cfg(any(test, feature = "seed-support"))]
920 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
921 let query = "
922 INSERT INTO orgs (name, slug)
923 VALUES ($1, $2)
924 RETURNING id
925 ";
926 Ok(sqlx::query_scalar(query)
927 .bind(name)
928 .bind(slug)
929 .fetch_one(&self.pool)
930 .await
931 .map(OrgId)?)
932 }
933
934 #[cfg(any(test, feature = "seed-support"))]
935 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
936 let query = "
937 INSERT INTO org_memberships (org_id, user_id, admin)
938 VALUES ($1, $2, $3)
939 ON CONFLICT DO NOTHING
940 ";
941 Ok(sqlx::query(query)
942 .bind(org_id.0)
943 .bind(user_id.0)
944 .bind(is_admin)
945 .execute(&self.pool)
946 .await
947 .map(drop)?)
948 }
949
950 // channels
951
952 #[cfg(any(test, feature = "seed-support"))]
953 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
954 let query = "
955 INSERT INTO channels (owner_id, owner_is_user, name)
956 VALUES ($1, false, $2)
957 RETURNING id
958 ";
959 Ok(sqlx::query_scalar(query)
960 .bind(org_id.0)
961 .bind(name)
962 .fetch_one(&self.pool)
963 .await
964 .map(ChannelId)?)
965 }
966
967 #[allow(unused)] // Help rust-analyzer
968 #[cfg(any(test, feature = "seed-support"))]
969 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
970 let query = "
971 SELECT *
972 FROM channels
973 WHERE
974 channels.owner_is_user = false AND
975 channels.owner_id = $1
976 ";
977 Ok(sqlx::query_as(query)
978 .bind(org_id.0)
979 .fetch_all(&self.pool)
980 .await?)
981 }
982
983 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
984 let query = "
985 SELECT
986 channels.*
987 FROM
988 channel_memberships, channels
989 WHERE
990 channel_memberships.user_id = $1 AND
991 channel_memberships.channel_id = channels.id
992 ";
993 Ok(sqlx::query_as(query)
994 .bind(user_id.0)
995 .fetch_all(&self.pool)
996 .await?)
997 }
998
999 async fn can_user_access_channel(
1000 &self,
1001 user_id: UserId,
1002 channel_id: ChannelId,
1003 ) -> Result<bool> {
1004 let query = "
1005 SELECT id
1006 FROM channel_memberships
1007 WHERE user_id = $1 AND channel_id = $2
1008 LIMIT 1
1009 ";
1010 Ok(sqlx::query_scalar::<_, i32>(query)
1011 .bind(user_id.0)
1012 .bind(channel_id.0)
1013 .fetch_optional(&self.pool)
1014 .await
1015 .map(|e| e.is_some())?)
1016 }
1017
1018 #[cfg(any(test, feature = "seed-support"))]
1019 async fn add_channel_member(
1020 &self,
1021 channel_id: ChannelId,
1022 user_id: UserId,
1023 is_admin: bool,
1024 ) -> Result<()> {
1025 let query = "
1026 INSERT INTO channel_memberships (channel_id, user_id, admin)
1027 VALUES ($1, $2, $3)
1028 ON CONFLICT DO NOTHING
1029 ";
1030 Ok(sqlx::query(query)
1031 .bind(channel_id.0)
1032 .bind(user_id.0)
1033 .bind(is_admin)
1034 .execute(&self.pool)
1035 .await
1036 .map(drop)?)
1037 }
1038
1039 // messages
1040
1041 async fn create_channel_message(
1042 &self,
1043 channel_id: ChannelId,
1044 sender_id: UserId,
1045 body: &str,
1046 timestamp: OffsetDateTime,
1047 nonce: u128,
1048 ) -> Result<MessageId> {
1049 let query = "
1050 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1051 VALUES ($1, $2, $3, $4, $5)
1052 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1053 RETURNING id
1054 ";
1055 Ok(sqlx::query_scalar(query)
1056 .bind(channel_id.0)
1057 .bind(sender_id.0)
1058 .bind(body)
1059 .bind(timestamp)
1060 .bind(Uuid::from_u128(nonce))
1061 .fetch_one(&self.pool)
1062 .await
1063 .map(MessageId)?)
1064 }
1065
1066 async fn get_channel_messages(
1067 &self,
1068 channel_id: ChannelId,
1069 count: usize,
1070 before_id: Option<MessageId>,
1071 ) -> Result<Vec<ChannelMessage>> {
1072 let query = r#"
1073 SELECT * FROM (
1074 SELECT
1075 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1076 FROM
1077 channel_messages
1078 WHERE
1079 channel_id = $1 AND
1080 id < $2
1081 ORDER BY id DESC
1082 LIMIT $3
1083 ) as recent_messages
1084 ORDER BY id ASC
1085 "#;
1086 Ok(sqlx::query_as(query)
1087 .bind(channel_id.0)
1088 .bind(before_id.unwrap_or(MessageId::MAX))
1089 .bind(count as i64)
1090 .fetch_all(&self.pool)
1091 .await?)
1092 }
1093
1094 #[cfg(test)]
1095 async fn teardown(&self, url: &str) {
1096 use util::ResultExt;
1097
1098 let query = "
1099 SELECT pg_terminate_backend(pg_stat_activity.pid)
1100 FROM pg_stat_activity
1101 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1102 ";
1103 sqlx::query(query).execute(&self.pool).await.log_err();
1104 self.pool.close().await;
1105 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1106 .await
1107 .log_err();
1108 }
1109
1110 #[cfg(test)]
1111 fn as_fake(&self) -> Option<&tests::FakeDb> {
1112 None
1113 }
1114}
1115
/// Declares a newtype id wrapping an `i32`.
///
/// The type is transparent to both sqlx and serde. It also gets a `MAX`
/// constant, conversions to/from protocol `u64` ids, and a `Display` impl that
/// forwards to the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from a protocol value.
            /// NOTE(review): `as` truncates to the low 32 bits for large values.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Converts to a protocol value (negative ids sign-extend).
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let raw = self.0;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1147
id_type!(UserId);
/// A row of the `users` table (as loaded by the `SELECT * FROM users` queries).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub email_address: Option<String>,
    pub admin: bool,
    /// Code other people redeem to sign up; `None` until one is assigned.
    pub invite_code: Option<String>,
    /// Invites this user may still hand out; decremented on each redemption.
    pub invite_count: i32,
    /// Set via `set_user_connected_once`; presumably whether the user has
    /// ever connected (confirm with callers).
    pub connected_once: bool,
}
1159
id_type!(ProjectId);
/// A project hosted by a user. The fields mirror the `projects` columns
/// written by `register_project` and `unregister_project`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    /// Set once `unregister_project` has been called; the row is kept.
    pub unregistered: bool,
}
1167
/// One user's activity within a time period, as produced by
/// `summarize_project_activity`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// Total active time per project, longest first.
    pub project_activity: Vec<(ProjectId, Duration)>,
}
1174
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Unique-looking identifier used for lookups by `find_org_by_slug`.
    pub slug: String,
}
1182
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Owning org id when `owner_is_user` is false (see `create_org_channel`);
    /// presumably a user id otherwise.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1191
id_type!(MessageId);
/// A message posted to a channel, as returned by `get_channel_messages`.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Selected `AT TIME ZONE 'UTC'` by `get_channel_messages`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied `u128` nonce stored as a UUID; it deduplicates re-sent messages.
    pub nonce: Uuid,
}
1202
/// A contact relationship as seen from one user (see `get_contacts`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both users are contacts. `should_notify` is only ever true for the
    /// user who sent the original request.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A pending request this user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request `user_id` sent to this user.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1217
1218impl Contact {
1219 pub fn user_id(&self) -> UserId {
1220 match self {
1221 Contact::Accepted { user_id, .. } => *user_id,
1222 Contact::Outgoing { user_id } => *user_id,
1223 Contact::Incoming { user_id, .. } => *user_id,
1224 }
1225 }
1226}
1227
/// A contact request received from `requester_id`.
/// NOTE(review): not used in the visible part of this file; confirm callers.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1233
/// Builds a `LIKE` pattern that matches any string containing the alphanumeric
/// characters of `string` in order, with anything in between.
///
/// For example `"a-b"` becomes `"%a%b%"`. Non-alphanumeric characters
/// (including the `LIKE` wildcards `%` and `_`) are dropped. An empty input
/// yields `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    pattern.extend(
        string
            .chars()
            .filter(|ch| ch.is_alphanumeric())
            .flat_map(|ch| ['%', ch]),
    );
    pattern.push('%');
    pattern
}
1245
1246#[cfg(test)]
1247pub mod tests {
1248 use super::*;
1249 use anyhow::anyhow;
1250 use collections::BTreeMap;
1251 use gpui::executor::Background;
1252 use lazy_static::lazy_static;
1253 use parking_lot::Mutex;
1254 use rand::prelude::*;
1255 use sqlx::{
1256 migrate::{MigrateDatabase, Migrator},
1257 Postgres,
1258 };
1259 use std::{path::Path, sync::Arc};
1260 use util::post_inc;
1261
    /// Fetching several users at once returns each of them, in the order the
    /// ids were requested, on both the Postgres and the in-memory backends.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_get_users_by_ids() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user = db.create_user("user", None, false).await.unwrap();
            let friend1 = db.create_user("friend-1", None, false).await.unwrap();
            let friend2 = db.create_user("friend-2", None, false).await.unwrap();
            let friend3 = db.create_user("friend-3", None, false).await.unwrap();

            assert_eq!(
                db.get_users_by_ids(vec![user, friend1, friend2, friend3])
                    .await
                    .unwrap(),
                vec![
                    User {
                        id: user,
                        github_login: "user".to_string(),
                        admin: false,
                        ..Default::default()
                    },
                    User {
                        id: friend1,
                        github_login: "friend-1".to_string(),
                        admin: false,
                        ..Default::default()
                    },
                    User {
                        id: friend2,
                        github_login: "friend-2".to_string(),
                        admin: false,
                        ..Default::default()
                    },
                    User {
                        id: friend3,
                        github_login: "friend-3".to_string(),
                        admin: false,
                        ..Default::default()
                    }
                ]
            );
        }
    }
1308
    /// Bulk user creation assigns invite counts and unique invite codes, and
    /// re-creating an existing user updates its invite count but keeps its code.
    /// Postgres only: `FakeDb::create_users` is unimplemented.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_users() {
        let db = TestDb::postgres().await;
        let db = db.db();

        // Create the first batch of users, ensuring invite counts are assigned
        // correctly and the respective invite codes are unique.
        let user_ids_batch_1 = db
            .create_users(vec![
                ("user1".to_string(), "hi@user1.com".to_string(), 5),
                ("user2".to_string(), "hi@user2.com".to_string(), 4),
                ("user3".to_string(), "hi@user3.com".to_string(), 3),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_1.len(), 3);

        let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
        assert_eq!(users.len(), 3);
        assert_eq!(users[0].github_login, "user1");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
        assert_eq!(users[0].invite_count, 5);
        assert_eq!(users[1].github_login, "user2");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[1].invite_count, 4);
        assert_eq!(users[2].github_login, "user3");
        assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
        assert_eq!(users[2].invite_count, 3);

        let invite_code_1 = users[0].invite_code.clone().unwrap();
        let invite_code_2 = users[1].invite_code.clone().unwrap();
        let invite_code_3 = users[2].invite_code.clone().unwrap();
        assert_ne!(invite_code_1, invite_code_2);
        assert_ne!(invite_code_1, invite_code_3);
        assert_ne!(invite_code_2, invite_code_3);

        // Create the second batch of users and include a user that is already in the database, ensuring
        // the invite count for the existing user is updated without changing their invite code.
        let user_ids_batch_2 = db
            .create_users(vec![
                ("user2".to_string(), "hi@user2.com".to_string(), 10),
                ("user4".to_string(), "hi@user4.com".to_string(), 2),
            ])
            .await
            .unwrap();
        assert_eq!(user_ids_batch_2.len(), 2);
        // The existing user keeps their original id.
        assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);

        let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
        assert_eq!(users.len(), 2);
        assert_eq!(users[0].github_login, "user2");
        assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
        assert_eq!(users[0].invite_count, 10);
        assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
        assert_eq!(users[1].github_login, "user4");
        assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
        assert_eq!(users[1].invite_count, 2);

        let invite_code_4 = users[1].invite_code.clone().unwrap();
        assert_ne!(invite_code_4, invite_code_1);
        assert_ne!(invite_code_4, invite_code_2);
        assert_ne!(invite_code_4, invite_code_3);
    }
1372
    /// Worktree extension counts are stored per (project, worktree), and a
    /// later update for the same worktree overwrites earlier counts.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_worktree_extensions() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();

        let user = db.create_user("user_1", None, false).await.unwrap();
        let project = db.register_project(user).await.unwrap();

        // An empty update is accepted.
        db.update_worktree_extensions(project, 100, Default::default())
            .await
            .unwrap();
        db.update_worktree_extensions(
            project,
            100,
            [("rs".to_string(), 5), ("md".to_string(), 3)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();
        // Same worktree again: these counts should replace the ones above.
        db.update_worktree_extensions(
            project,
            100,
            [("rs".to_string(), 6), ("md".to_string(), 5)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();
        db.update_worktree_extensions(
            project,
            101,
            [("ts".to_string(), 2), ("md".to_string(), 1)]
                .into_iter()
                .collect(),
        )
        .await
        .unwrap();

        assert_eq!(
            db.get_project_extensions(project).await.unwrap(),
            [
                (
                    100,
                    [("rs".into(), 6), ("md".into(), 5),]
                        .into_iter()
                        .collect::<HashMap<_, _>>()
                ),
                (
                    101,
                    [("ts".into(), 2), ("md".into(), 1),]
                        .into_iter()
                        .collect::<HashMap<_, _>>()
                )
            ]
            .into_iter()
            .collect()
        );
    }
1432
    /// Activity recorded over consecutive periods is summed per (user, project).
    #[tokio::test(flavor = "multi_thread")]
    async fn test_project_activity() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();

        let user_1 = db.create_user("user_1", None, false).await.unwrap();
        let user_2 = db.create_user("user_2", None, false).await.unwrap();
        let user_3 = db.create_user("user_3", None, false).await.unwrap();
        let project_1 = db.register_project(user_1).await.unwrap();
        let project_2 = db.register_project(user_2).await.unwrap();
        let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);

        // User 2 opens a project
        let t1 = t0 + Duration::from_secs(10);
        db.record_project_activity(t0..t1, &[(user_2, project_2)])
            .await
            .unwrap();

        let t2 = t1 + Duration::from_secs(10);
        db.record_project_activity(t1..t2, &[(user_2, project_2)])
            .await
            .unwrap();

        // User 1 joins the project
        let t3 = t2 + Duration::from_secs(10);
        db.record_project_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
            .await
            .unwrap();

        // User 1 opens another project
        let t4 = t3 + Duration::from_secs(10);
        db.record_project_activity(
            t3..t4,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
            ],
        )
        .await
        .unwrap();

        // User 3 joins that project
        let t5 = t4 + Duration::from_secs(10);
        db.record_project_activity(
            t4..t5,
            &[
                (user_2, project_2),
                (user_1, project_2),
                (user_1, project_1),
                (user_3, project_1),
            ],
        )
        .await
        .unwrap();

        // User 2 leaves
        let t6 = t5 + Duration::from_secs(5);
        db.record_project_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
            .await
            .unwrap();

        // Expected totals:
        // - user_1 / project_2: t2..t5 = 10 + 10 + 10 = 30s
        // - user_1 / project_1: t3..t6 = 10 + 10 + 5 = 25s
        // - user_2 / project_2: t0..t5 = 5 * 10 = 50s
        // - user_3 / project_1: t4..t6 = 10 + 5 = 15s
        let summary = db.summarize_project_activity(t0..t6, 10).await.unwrap();
        assert_eq!(
            summary,
            &[
                UserActivitySummary {
                    id: user_1,
                    github_login: "user_1".to_string(),
                    project_activity: vec![
                        (project_2, Duration::from_secs(30)),
                        (project_1, Duration::from_secs(25))
                    ]
                },
                UserActivitySummary {
                    id: user_2,
                    github_login: "user_2".to_string(),
                    project_activity: vec![(project_2, Duration::from_secs(50))]
                },
                UserActivitySummary {
                    id: user_3,
                    github_login: "user_3".to_string(),
                    project_activity: vec![(project_1, Duration::from_secs(15))]
                },
            ]
        );
    }
1520
    /// `get_channel_messages` returns the most recent `count` messages in
    /// chronological order, and paginates backwards when given a `before` id.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_recent_channel_messages() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();
            let user = db.create_user("user", None, false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();
            // Ten messages with bodies "0".."9" and distinct nonces.
            for i in 0..10 {
                db.create_channel_message(
                    channel,
                    user,
                    &i.to_string(),
                    OffsetDateTime::now_utc(),
                    i,
                )
                .await
                .unwrap();
            }

            let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
            assert_eq!(
                messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["5", "6", "7", "8", "9"]
            );

            // The page before message "5", limited to 4 entries.
            let prev_messages = db
                .get_channel_messages(channel, 4, Some(messages[0].id))
                .await
                .unwrap();
            assert_eq!(
                prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
                ["1", "2", "3", "4"]
            );
        }
    }
1559
    /// Re-sending a message with a nonce that was already used returns the
    /// original message's id instead of creating a duplicate.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_channel_message_nonces() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();
            let user = db.create_user("user", None, false).await.unwrap();
            let org = db.create_org("org", "org").await.unwrap();
            let channel = db.create_org_channel(org, "channel").await.unwrap();

            let msg1_id = db
                .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            let msg2_id = db
                .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();
            // Reuses nonce 1 (with a different body).
            let msg3_id = db
                .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
                .await
                .unwrap();
            // Reuses nonce 2 (with a different body).
            let msg4_id = db
                .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
                .await
                .unwrap();

            assert_ne!(msg1_id, msg2_id);
            assert_eq!(msg1_id, msg3_id);
            assert_eq!(msg2_id, msg4_id);
        }
    }
1593
    /// Access token hashes are returned newest-first, and only the most recent
    /// `max_access_token_count` (here 3) are kept.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_create_access_tokens() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();
        let user = db.create_user("the-user", None, false).await.unwrap();

        db.create_access_token_hash(user, "h1", 3).await.unwrap();
        db.create_access_token_hash(user, "h2", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h2".to_string(), "h1".to_string()]
        );

        db.create_access_token_hash(user, "h3", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
        );

        // Exceeding the limit evicts the oldest hash.
        db.create_access_token_hash(user, "h4", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
        );

        db.create_access_token_hash(user, "h5", 3).await.unwrap();
        assert_eq!(
            db.get_access_token_hashes(user).await.unwrap(),
            &["h5".to_string(), "h4".to_string(), "h3".to_string()]
        );
    }
1625
1626 #[test]
1627 fn test_fuzzy_like_string() {
1628 assert_eq!(fuzzy_like_string("abcd"), "%a%b%c%d%");
1629 assert_eq!(fuzzy_like_string("x y"), "%x%y%");
1630 assert_eq!(fuzzy_like_string(" z "), "%z%");
1631 }
1632
    /// Fuzzy user search matches logins containing the query's characters in
    /// order, case-insensitively ("clr" matches "California").
    // NOTE(review): the expected result ordering is whatever the Postgres query
    // ranks by; it is not derivable from this test alone.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_fuzzy_search_users() {
        let test_db = TestDb::postgres().await;
        let db = test_db.db();
        for github_login in [
            "California",
            "colorado",
            "oregon",
            "washington",
            "florida",
            "delaware",
            "rhode-island",
        ] {
            db.create_user(github_login, None, false).await.unwrap();
        }

        assert_eq!(
            fuzzy_search_user_names(db, "clr").await,
            &["colorado", "California"]
        );
        assert_eq!(
            fuzzy_search_user_names(db, "ro").await,
            &["rhode-island", "colorado", "oregon"],
        );

        /// Runs a fuzzy search (limit 10) and returns just the logins.
        async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
            db.fuzzy_search_users(query, 10)
                .await
                .unwrap()
                .into_iter()
                .map(|user| user.github_login)
                .collect::<Vec<_>>()
        }
    }
1667
    /// Walks through the contact lifecycle: requesting, dismissing
    /// notifications, accepting, mutual requests, and declining. Runs against
    /// both the Postgres and the in-memory backends.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_add_contacts() {
        for test_db in [
            TestDb::postgres().await,
            TestDb::fake(Arc::new(gpui::executor::Background::new())),
        ] {
            let db = test_db.db();

            let user_1 = db.create_user("user1", None, false).await.unwrap();
            let user_2 = db.create_user("user2", None, false).await.unwrap();
            let user_3 = db.create_user("user3", None, false).await.unwrap();

            // User starts with no contacts other than themselves.
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                vec![Contact::Accepted {
                    user_id: user_1,
                    should_notify: false
                }],
            );

            // User requests a contact. Both users see the pending request.
            db.send_contact_request(user_1, user_2).await.unwrap();
            assert!(!db.has_contact(user_1, user_2).await.unwrap());
            assert!(!db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Outgoing { user_id: user_2 }
                ],
            );
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: true
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User 2 dismisses the contact request notification without accepting or rejecting.
            // We shouldn't notify them again.
            // (The requester has no notification to dismiss, so that call fails.)
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap_err();
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Incoming {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    },
                ]
            );

            // User can't accept their own contact request
            db.respond_to_contact_request(user_1, user_2, true)
                .await
                .unwrap_err();

            // User accepts a contact request. Both users see the contact.
            // Only the requester (user 1) is flagged to be notified of the acceptance.
            db.respond_to_contact_request(user_2, user_1, true)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true
                    }
                ],
            );
            assert!(db.has_contact(user_1, user_2).await.unwrap());
            assert!(db.has_contact(user_2, user_1).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users cannot re-request existing contacts.
            db.send_contact_request(user_1, user_2).await.unwrap_err();
            db.send_contact_request(user_2, user_1).await.unwrap_err();

            // Users can't dismiss notifications of them accepting other users' requests.
            db.dismiss_contact_notification(user_2, user_1)
                .await
                .unwrap_err();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: true,
                    },
                ]
            );

            // Users can dismiss notifications of other users accepting their requests.
            db.dismiss_contact_notification(user_1, user_2)
                .await
                .unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                ]
            );

            // Users send each other concurrent contact requests and
            // see that they are immediately accepted.
            // Neither side gets an acceptance notification in this case.
            db.send_contact_request(user_1, user_3).await.unwrap();
            db.send_contact_request(user_3, user_1).await.unwrap();
            assert_eq!(
                db.get_contacts(user_1).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false,
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    },
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );

            // User declines a contact request. Both users see that it is gone.
            db.send_contact_request(user_2, user_3).await.unwrap();
            db.respond_to_contact_request(user_3, user_2, false)
                .await
                .unwrap();
            assert!(!db.has_contact(user_2, user_3).await.unwrap());
            assert!(!db.has_contact(user_3, user_2).await.unwrap());
            assert_eq!(
                db.get_contacts(user_2).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_2,
                        should_notify: false
                    }
                ]
            );
            assert_eq!(
                db.get_contacts(user_3).await.unwrap(),
                &[
                    Contact::Accepted {
                        user_id: user_1,
                        should_notify: false
                    },
                    Contact::Accepted {
                        user_id: user_3,
                        should_notify: false
                    }
                ],
            );
        }
    }
1887
    /// Invite codes: creation via `set_invite_count`, redemption (which makes
    /// the new user a contact of the inviter), exhaustion, top-up, and
    /// rejection of existing users. Postgres only.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_invite_codes() {
        let postgres = TestDb::postgres().await;
        let db = postgres.db();
        let user1 = db.create_user("user-1", None, false).await.unwrap();

        // Initially, user 1 has no invite code
        assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);

        // Setting invite count to 0 when no code is assigned does not assign a new code
        db.set_invite_count(user1, 0).await.unwrap();
        assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());

        // User 1 creates an invite code that can be used twice.
        db.set_invite_count(user1, 2).await.unwrap();
        let (invite_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 2);

        // User 2 redeems the invite code and becomes a contact of user 1.
        let user2 = db
            .redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user2).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: false
                }
            ]
        );

        // User 3 redeems the invite code and becomes a contact of user 1.
        let user3 = db
            .redeem_invite_code(&invite_code, "user-3", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 0);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user3).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: false
                },
            ]
        );

        // Trying to redeem the code for the third time results in an error.
        db.redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap_err();

        // Invite count can be updated after the code has been created.
        db.set_invite_count(user1, 2).await.unwrap();
        let (latest_code, invite_count) =
            db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
        assert_eq!(invite_count, 2);

        // User 4 can now redeem the invite code and becomes a contact of user 1.
        let user4 = db
            .redeem_invite_code(&invite_code, "user-4", None)
            .await
            .unwrap();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
        assert_eq!(
            db.get_contacts(user1).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user2,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user3,
                    should_notify: true
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: true
                }
            ]
        );
        assert_eq!(
            db.get_contacts(user4).await.unwrap(),
            [
                Contact::Accepted {
                    user_id: user1,
                    should_notify: false
                },
                Contact::Accepted {
                    user_id: user4,
                    should_notify: false
                },
            ]
        );

        // An existing user cannot redeem invite codes.
        // The failed attempt must not consume an invite.
        db.redeem_invite_code(&invite_code, "user-2", None)
            .await
            .unwrap_err();
        let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
        assert_eq!(invite_count, 1);
    }
2040
    /// A database instance for a single test: either a throwaway Postgres
    /// database or an in-memory `FakeDb`. Torn down on drop.
    pub struct TestDb {
        /// The database; `None` only after `Drop` has taken it for teardown.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres database; empty for the fake.
        pub url: String,
    }
2045
    impl TestDb {
        /// Creates a fresh, randomly named Postgres database on `localhost`
        /// (connecting as the `postgres` user), runs the crate's `migrations`
        /// directory against it, and wraps it in a `TestDb`.
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub async fn postgres() -> Self {
            // Process-wide lock held for this whole function, so database
            // creation and migration are serialized across tests.
            // NOTE(review): this is a blocking `parking_lot` guard held across
            // the `.await`s below, so a waiting test blocks its executor thread.
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            // Random suffix so each test gets its own database.
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            Postgres::create_database(&url)
                .await
                .expect("failed to create test db");
            // NOTE(review): `5` is presumably the pool's max connection count — confirm.
            let db = PostgresDb::new(&url, 5).await.unwrap();
            let migrator = Migrator::new(migrations_path).await.unwrap();
            migrator.run(&db.pool).await.unwrap();
            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Wraps an in-memory `FakeDb` that uses `background` to simulate
        /// query delays. `url` is left empty because there is no server.
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                url: Default::default(),
            }
        }

        /// Returns the wrapped database.
        ///
        /// Panics if the database has already been taken (as `Drop` does).
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }
2080
2081 impl Drop for TestDb {
2082 fn drop(&mut self) {
2083 if let Some(db) = self.db.take() {
2084 futures::executor::block_on(db.teardown(&self.url));
2085 }
2086 }
2087 }
2088
    /// In-memory implementation of `Db` for tests. Each table is a
    /// mutex-guarded map (or list), and ids come from the `next_*` counters.
    pub struct FakeDb {
        /// Executor used by many methods to inject a random delay before
        /// touching state, to simulate query latency.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        /// `(project, worktree id, extension)` -> count reported for that extension.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), usize>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // NOTE(review): meaning of the `bool` value is not visible here
        // (presumably an admin flag) — confirm.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        // NOTE(review): meaning of the `bool` value is not visible here — confirm.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Contact relationships; searched linearly.
        pub contacts: Mutex<Vec<FakeContact>>,
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
2106
    /// A contact relationship stored by `FakeDb`, recorded in the direction
    /// the original request was sent.
    #[derive(Debug)]
    pub struct FakeContact {
        pub requester_id: UserId,
        pub responder_id: UserId,
        /// Whether the request was accepted (or made mutually).
        pub accepted: bool,
        /// While pending: whether the responder should be notified of the
        /// request. Once accepted: whether the requester should be notified of
        /// the acceptance.
        pub should_notify: bool,
    }
2114
2115 impl FakeDb {
2116 pub fn new(background: Arc<Background>) -> Self {
2117 Self {
2118 background,
2119 users: Default::default(),
2120 next_user_id: Mutex::new(1),
2121 projects: Default::default(),
2122 worktree_extensions: Default::default(),
2123 next_project_id: Mutex::new(1),
2124 orgs: Default::default(),
2125 next_org_id: Mutex::new(1),
2126 org_memberships: Default::default(),
2127 channels: Default::default(),
2128 next_channel_id: Mutex::new(1),
2129 channel_memberships: Default::default(),
2130 channel_messages: Default::default(),
2131 next_channel_message_id: Mutex::new(1),
2132 contacts: Default::default(),
2133 }
2134 }
2135 }
2136
2137 #[async_trait]
2138 impl Db for FakeDb {
2139 async fn create_user(
2140 &self,
2141 github_login: &str,
2142 email_address: Option<&str>,
2143 admin: bool,
2144 ) -> Result<UserId> {
2145 self.background.simulate_random_delay().await;
2146
2147 let mut users = self.users.lock();
2148 if let Some(user) = users
2149 .values()
2150 .find(|user| user.github_login == github_login)
2151 {
2152 Ok(user.id)
2153 } else {
2154 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2155 users.insert(
2156 user_id,
2157 User {
2158 id: user_id,
2159 github_login: github_login.to_string(),
2160 email_address: email_address.map(str::to_string),
2161 admin,
2162 invite_code: None,
2163 invite_count: 0,
2164 connected_once: false,
2165 },
2166 );
2167 Ok(user_id)
2168 }
2169 }
2170
        // Not supported by the fake; panics if called.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2174
        // Not supported by the fake; panics if called (bulk creation is only
        // exercised against Postgres).
        async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
            unimplemented!()
        }
2178
        // Not supported by the fake; panics if called.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2182
2183 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2184 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2185 }
2186
2187 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2188 self.background.simulate_random_delay().await;
2189 let users = self.users.lock();
2190 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2191 }
2192
        // Not supported by the fake; panics if called.
        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }
2196
2197 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2198 Ok(self
2199 .users
2200 .lock()
2201 .values()
2202 .find(|user| user.github_login == github_login)
2203 .cloned())
2204 }
2205
        // Not supported by the fake; panics if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
2209
2210 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2211 self.background.simulate_random_delay().await;
2212 let mut users = self.users.lock();
2213 let mut user = users
2214 .get_mut(&id)
2215 .ok_or_else(|| anyhow!("user not found"))?;
2216 user.connected_once = connected_once;
2217 Ok(())
2218 }
2219
        // Not supported by the fake; panics if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }
2223
2224 // invite codes
2225
        // Not supported by the fake; panics if called.
        async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
2229
        // Stub: the fake never assigns invite codes, so every user reports none.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            Ok(None)
        }
2233
        // Not supported by the fake; panics if called.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }
2237
        // Not supported by the fake; panics if called.
        async fn redeem_invite_code(
            &self,
            _code: &str,
            _login: &str,
            _email_address: Option<&str>,
        ) -> Result<UserId> {
            unimplemented!()
        }
2246
2247 // projects
2248
2249 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2250 self.background.simulate_random_delay().await;
2251 if !self.users.lock().contains_key(&host_user_id) {
2252 Err(anyhow!("no such user"))?;
2253 }
2254
2255 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2256 self.projects.lock().insert(
2257 project_id,
2258 Project {
2259 id: project_id,
2260 host_user_id,
2261 unregistered: false,
2262 },
2263 );
2264 Ok(project_id)
2265 }
2266
2267 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2268 self.projects
2269 .lock()
2270 .get_mut(&project_id)
2271 .ok_or_else(|| anyhow!("no such project"))?
2272 .unregistered = true;
2273 Ok(())
2274 }
2275
2276 async fn update_worktree_extensions(
2277 &self,
2278 project_id: ProjectId,
2279 worktree_id: u64,
2280 extensions: HashMap<String, usize>,
2281 ) -> Result<()> {
2282 self.background.simulate_random_delay().await;
2283 if !self.projects.lock().contains_key(&project_id) {
2284 Err(anyhow!("no such project"))?;
2285 }
2286
2287 for (extension, count) in extensions {
2288 self.worktree_extensions
2289 .lock()
2290 .insert((project_id, worktree_id, extension), count);
2291 }
2292
2293 Ok(())
2294 }
2295
        // Not supported by the fake; panics if called.
        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }
2302
        // Not supported by the fake; panics if called.
        async fn record_project_activity(
            &self,
            _period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }
2310
        // Not supported by the fake; panics if called.
        async fn summarize_project_activity(
            &self,
            _period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }
2318
2319 // contacts
2320
2321 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2322 self.background.simulate_random_delay().await;
2323 let mut contacts = vec![Contact::Accepted {
2324 user_id: id,
2325 should_notify: false,
2326 }];
2327
2328 for contact in self.contacts.lock().iter() {
2329 if contact.requester_id == id {
2330 if contact.accepted {
2331 contacts.push(Contact::Accepted {
2332 user_id: contact.responder_id,
2333 should_notify: contact.should_notify,
2334 });
2335 } else {
2336 contacts.push(Contact::Outgoing {
2337 user_id: contact.responder_id,
2338 });
2339 }
2340 } else if contact.responder_id == id {
2341 if contact.accepted {
2342 contacts.push(Contact::Accepted {
2343 user_id: contact.requester_id,
2344 should_notify: false,
2345 });
2346 } else {
2347 contacts.push(Contact::Incoming {
2348 user_id: contact.requester_id,
2349 should_notify: contact.should_notify,
2350 });
2351 }
2352 }
2353 }
2354
2355 contacts.sort_unstable_by_key(|contact| contact.user_id());
2356 Ok(contacts)
2357 }
2358
2359 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2360 self.background.simulate_random_delay().await;
2361 Ok(self.contacts.lock().iter().any(|contact| {
2362 contact.accepted
2363 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2364 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2365 }))
2366 }
2367
        /// Records a contact request from `requester_id` to `responder_id`.
        ///
        /// Errors if the two users are already contacts, or if this request is
        /// already pending. If the responder had already sent a pending request
        /// to the requester, that request is accepted instead of a new one
        /// being added.
        async fn send_contact_request(
            &self,
            requester_id: UserId,
            responder_id: UserId,
        ) -> Result<()> {
            let mut contacts = self.contacts.lock();
            for contact in contacts.iter_mut() {
                // The same request already exists in this direction.
                if contact.requester_id == requester_id && contact.responder_id == responder_id {
                    if contact.accepted {
                        Err(anyhow!("contact already exists"))?;
                    } else {
                        Err(anyhow!("contact already requested"))?;
                    }
                }
                // The other user already asked: accept their request rather than
                // storing a second, reversed record.
                if contact.responder_id == requester_id && contact.requester_id == responder_id {
                    if contact.accepted {
                        Err(anyhow!("contact already exists"))?;
                    } else {
                        contact.accepted = true;
                        // Mutual requests don't produce an acceptance notification.
                        contact.should_notify = false;
                        return Ok(());
                    }
                }
            }
            contacts.push(FakeContact {
                requester_id,
                responder_id,
                accepted: false,
                should_notify: true,
            });
            Ok(())
        }
2400
2401 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2402 self.contacts.lock().retain(|contact| {
2403 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2404 });
2405 Ok(())
2406 }
2407
2408 async fn dismiss_contact_notification(
2409 &self,
2410 user_id: UserId,
2411 contact_user_id: UserId,
2412 ) -> Result<()> {
2413 let mut contacts = self.contacts.lock();
2414 for contact in contacts.iter_mut() {
2415 if contact.requester_id == contact_user_id
2416 && contact.responder_id == user_id
2417 && !contact.accepted
2418 {
2419 contact.should_notify = false;
2420 return Ok(());
2421 }
2422 if contact.requester_id == user_id
2423 && contact.responder_id == contact_user_id
2424 && contact.accepted
2425 {
2426 contact.should_notify = false;
2427 return Ok(());
2428 }
2429 }
2430 Err(anyhow!("no such notification"))?
2431 }
2432
2433 async fn respond_to_contact_request(
2434 &self,
2435 responder_id: UserId,
2436 requester_id: UserId,
2437 accept: bool,
2438 ) -> Result<()> {
2439 let mut contacts = self.contacts.lock();
2440 for (ix, contact) in contacts.iter_mut().enumerate() {
2441 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2442 if contact.accepted {
2443 Err(anyhow!("contact already confirmed"))?;
2444 }
2445 if accept {
2446 contact.accepted = true;
2447 contact.should_notify = true;
2448 } else {
2449 contacts.remove(ix);
2450 }
2451 return Ok(());
2452 }
2453 }
2454 Err(anyhow!("no such contact request"))?
2455 }
2456
/// Not supported by the fake database: panics via `unimplemented!()` if called.
async fn create_access_token_hash(
    &self,
    _user_id: UserId,
    _access_token_hash: &str,
    _max_access_token_count: usize,
) -> Result<()> {
    unimplemented!()
}
2465
/// Not supported by the fake database: panics via `unimplemented!()` if called.
async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
    unimplemented!()
}
2469
2470 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2471 unimplemented!()
2472 }
2473
2474 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2475 self.background.simulate_random_delay().await;
2476 let mut orgs = self.orgs.lock();
2477 if orgs.values().any(|org| org.slug == slug) {
2478 Err(anyhow!("org already exists"))?
2479 } else {
2480 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2481 orgs.insert(
2482 org_id,
2483 Org {
2484 id: org_id,
2485 name: name.to_string(),
2486 slug: slug.to_string(),
2487 },
2488 );
2489 Ok(org_id)
2490 }
2491 }
2492
2493 async fn add_org_member(
2494 &self,
2495 org_id: OrgId,
2496 user_id: UserId,
2497 is_admin: bool,
2498 ) -> Result<()> {
2499 self.background.simulate_random_delay().await;
2500 if !self.orgs.lock().contains_key(&org_id) {
2501 Err(anyhow!("org does not exist"))?;
2502 }
2503 if !self.users.lock().contains_key(&user_id) {
2504 Err(anyhow!("user does not exist"))?;
2505 }
2506
2507 self.org_memberships
2508 .lock()
2509 .entry((org_id, user_id))
2510 .or_insert(is_admin);
2511 Ok(())
2512 }
2513
2514 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2515 self.background.simulate_random_delay().await;
2516 if !self.orgs.lock().contains_key(&org_id) {
2517 Err(anyhow!("org does not exist"))?;
2518 }
2519
2520 let mut channels = self.channels.lock();
2521 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2522 channels.insert(
2523 channel_id,
2524 Channel {
2525 id: channel_id,
2526 name: name.to_string(),
2527 owner_id: org_id.0,
2528 owner_is_user: false,
2529 },
2530 );
2531 Ok(channel_id)
2532 }
2533
2534 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2535 self.background.simulate_random_delay().await;
2536 Ok(self
2537 .channels
2538 .lock()
2539 .values()
2540 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2541 .cloned()
2542 .collect())
2543 }
2544
2545 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2546 self.background.simulate_random_delay().await;
2547 let channels = self.channels.lock();
2548 let memberships = self.channel_memberships.lock();
2549 Ok(channels
2550 .values()
2551 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2552 .cloned()
2553 .collect())
2554 }
2555
2556 async fn can_user_access_channel(
2557 &self,
2558 user_id: UserId,
2559 channel_id: ChannelId,
2560 ) -> Result<bool> {
2561 self.background.simulate_random_delay().await;
2562 Ok(self
2563 .channel_memberships
2564 .lock()
2565 .contains_key(&(channel_id, user_id)))
2566 }
2567
2568 async fn add_channel_member(
2569 &self,
2570 channel_id: ChannelId,
2571 user_id: UserId,
2572 is_admin: bool,
2573 ) -> Result<()> {
2574 self.background.simulate_random_delay().await;
2575 if !self.channels.lock().contains_key(&channel_id) {
2576 Err(anyhow!("channel does not exist"))?;
2577 }
2578 if !self.users.lock().contains_key(&user_id) {
2579 Err(anyhow!("user does not exist"))?;
2580 }
2581
2582 self.channel_memberships
2583 .lock()
2584 .entry((channel_id, user_id))
2585 .or_insert(is_admin);
2586 Ok(())
2587 }
2588
2589 async fn create_channel_message(
2590 &self,
2591 channel_id: ChannelId,
2592 sender_id: UserId,
2593 body: &str,
2594 timestamp: OffsetDateTime,
2595 nonce: u128,
2596 ) -> Result<MessageId> {
2597 self.background.simulate_random_delay().await;
2598 if !self.channels.lock().contains_key(&channel_id) {
2599 Err(anyhow!("channel does not exist"))?;
2600 }
2601 if !self.users.lock().contains_key(&sender_id) {
2602 Err(anyhow!("user does not exist"))?;
2603 }
2604
2605 let mut messages = self.channel_messages.lock();
2606 if let Some(message) = messages
2607 .values()
2608 .find(|message| message.nonce.as_u128() == nonce)
2609 {
2610 Ok(message.id)
2611 } else {
2612 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2613 messages.insert(
2614 message_id,
2615 ChannelMessage {
2616 id: message_id,
2617 channel_id,
2618 sender_id,
2619 body: body.to_string(),
2620 sent_at: timestamp,
2621 nonce: Uuid::from_u128(nonce),
2622 },
2623 );
2624 Ok(message_id)
2625 }
2626 }
2627
2628 async fn get_channel_messages(
2629 &self,
2630 channel_id: ChannelId,
2631 count: usize,
2632 before_id: Option<MessageId>,
2633 ) -> Result<Vec<ChannelMessage>> {
2634 let mut messages = self
2635 .channel_messages
2636 .lock()
2637 .values()
2638 .rev()
2639 .filter(|message| {
2640 message.channel_id == channel_id
2641 && message.id < before_id.unwrap_or(MessageId::MAX)
2642 })
2643 .take(count)
2644 .cloned()
2645 .collect::<Vec<_>>();
2646 messages.sort_unstable_by_key(|message| message.id);
2647 Ok(messages)
2648 }
2649
/// No-op in the fake: the argument is ignored and nothing is released.
async fn teardown(&self, _: &str) {}
2651
/// Returns `self` as the concrete `FakeDb`; only compiled under `cfg(test)`.
#[cfg(test)]
fn as_fake(&self) -> Option<&FakeDb> {
    Some(self)
}
2656 }
2657}