1use std::{ops::Range, time::Duration};
2
3use crate::{Error, Result};
4use anyhow::{anyhow, Context};
5use async_trait::async_trait;
6use axum::http::StatusCode;
7use collections::HashMap;
8use futures::StreamExt;
9use nanoid::nanoid;
10use serde::Serialize;
11pub use sqlx::postgres::PgPoolOptions as DbOptions;
12use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
13use time::OffsetDateTime;
14
15#[async_trait]
16pub trait Db: Send + Sync {
17 async fn create_user(
18 &self,
19 github_login: &str,
20 email_address: Option<&str>,
21 admin: bool,
22 ) -> Result<UserId>;
23 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
24 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
25 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
26 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
27 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
28 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
29 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
30 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
31 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
32 async fn destroy_user(&self, id: UserId) -> Result<()>;
33
34 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()>;
35 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
36 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
37 async fn redeem_invite_code(
38 &self,
39 code: &str,
40 login: &str,
41 email_address: Option<&str>,
42 ) -> Result<UserId>;
43
44 /// Registers a new project for the given user.
45 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
46
47 /// Unregisters a project for the given project id.
48 async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
49
50 /// Update file counts by extension for the given project and worktree.
51 async fn update_worktree_extensions(
52 &self,
53 project_id: ProjectId,
54 worktree_id: u64,
55 extensions: HashMap<String, usize>,
56 ) -> Result<()>;
57
58 /// Get the file counts on the given project keyed by their worktree and extension.
59 async fn get_project_extensions(
60 &self,
61 project_id: ProjectId,
62 ) -> Result<HashMap<u64, HashMap<String, usize>>>;
63
64 /// Record which users have been active in which projects during
65 /// a given period of time.
66 async fn record_project_activity(
67 &self,
68 time_period: Range<OffsetDateTime>,
69 active_projects: &[(UserId, ProjectId)],
70 ) -> Result<()>;
71
72 /// Get the users that have been most active during the given time period,
73 /// along with the amount of time they have been active in each project.
74 async fn summarize_project_activity(
75 &self,
76 time_period: Range<OffsetDateTime>,
77 max_user_count: usize,
78 ) -> Result<Vec<UserActivitySummary>>;
79
80 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
81 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
82 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
83 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
84 async fn dismiss_contact_notification(
85 &self,
86 responder_id: UserId,
87 requester_id: UserId,
88 ) -> Result<()>;
89 async fn respond_to_contact_request(
90 &self,
91 responder_id: UserId,
92 requester_id: UserId,
93 accept: bool,
94 ) -> Result<()>;
95
96 async fn create_access_token_hash(
97 &self,
98 user_id: UserId,
99 access_token_hash: &str,
100 max_access_token_count: usize,
101 ) -> Result<()>;
102 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
103 #[cfg(any(test, feature = "seed-support"))]
104
105 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
106 #[cfg(any(test, feature = "seed-support"))]
107 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
108 #[cfg(any(test, feature = "seed-support"))]
109 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
110 #[cfg(any(test, feature = "seed-support"))]
111 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
112 #[cfg(any(test, feature = "seed-support"))]
113
114 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
115 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
116 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
117 -> Result<bool>;
118 #[cfg(any(test, feature = "seed-support"))]
119 async fn add_channel_member(
120 &self,
121 channel_id: ChannelId,
122 user_id: UserId,
123 is_admin: bool,
124 ) -> Result<()>;
125 async fn create_channel_message(
126 &self,
127 channel_id: ChannelId,
128 sender_id: UserId,
129 body: &str,
130 timestamp: OffsetDateTime,
131 nonce: u128,
132 ) -> Result<MessageId>;
133 async fn get_channel_messages(
134 &self,
135 channel_id: ChannelId,
136 count: usize,
137 before_id: Option<MessageId>,
138 ) -> Result<Vec<ChannelMessage>>;
139 #[cfg(test)]
140 async fn teardown(&self, url: &str);
141 #[cfg(test)]
142 fn as_fake<'a>(&'a self) -> Option<&'a tests::FakeDb>;
143}
144
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    /// Pool that every query in this implementation runs against.
    pool: sqlx::PgPool,
}
148
149impl PostgresDb {
150 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
151 let pool = DbOptions::new()
152 .max_connections(max_connections)
153 .connect(&url)
154 .await
155 .context("failed to connect to postgres database")?;
156 Ok(Self { pool })
157 }
158}
159
160#[async_trait]
161impl Db for PostgresDb {
162 // users
163
164 async fn create_user(
165 &self,
166 github_login: &str,
167 email_address: Option<&str>,
168 admin: bool,
169 ) -> Result<UserId> {
170 let query = "
171 INSERT INTO users (github_login, email_address, admin)
172 VALUES ($1, $2, $3)
173 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
174 RETURNING id
175 ";
176 Ok(sqlx::query_scalar(query)
177 .bind(github_login)
178 .bind(email_address)
179 .bind(admin)
180 .fetch_one(&self.pool)
181 .await
182 .map(UserId)?)
183 }
184
185 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
186 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
187 Ok(sqlx::query_as(query)
188 .bind(limit as i32)
189 .bind((page * limit) as i32)
190 .fetch_all(&self.pool)
191 .await?)
192 }
193
194 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
195 let mut query = QueryBuilder::new(
196 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
197 );
198 query.push_values(
199 users,
200 |mut query, (github_login, email_address, invite_count)| {
201 query
202 .push_bind(github_login)
203 .push_bind(email_address)
204 .push_bind(false)
205 .push_bind(nanoid!(16))
206 .push_bind(invite_count as i32);
207 },
208 );
209 query.push(
210 "
211 ON CONFLICT (github_login) DO UPDATE SET
212 github_login = excluded.github_login,
213 invite_count = excluded.invite_count,
214 invite_code = CASE WHEN users.invite_code IS NULL
215 THEN excluded.invite_code
216 ELSE users.invite_code
217 END
218 RETURNING id
219 ",
220 );
221
222 let rows = query.build().fetch_all(&self.pool).await?;
223 Ok(rows
224 .into_iter()
225 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
226 .collect())
227 }
228
229 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
230 let like_string = fuzzy_like_string(name_query);
231 let query = "
232 SELECT users.*
233 FROM users
234 WHERE github_login ILIKE $1
235 ORDER BY github_login <-> $2
236 LIMIT $3
237 ";
238 Ok(sqlx::query_as(query)
239 .bind(like_string)
240 .bind(name_query)
241 .bind(limit as i32)
242 .fetch_all(&self.pool)
243 .await?)
244 }
245
246 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
247 let users = self.get_users_by_ids(vec![id]).await?;
248 Ok(users.into_iter().next())
249 }
250
251 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
252 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
253 let query = "
254 SELECT users.*
255 FROM users
256 WHERE users.id = ANY ($1)
257 ";
258 Ok(sqlx::query_as(query)
259 .bind(&ids)
260 .fetch_all(&self.pool)
261 .await?)
262 }
263
264 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
265 let query = "
266 SELECT users.*
267 FROM users
268 WHERE invite_count = 0
269 ";
270 Ok(sqlx::query_as(query)
271 .bind(&ids)
272 .fetch_all(&self.pool)
273 .await?)
274 }
275
276 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
277 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
278 Ok(sqlx::query_as(query)
279 .bind(github_login)
280 .fetch_optional(&self.pool)
281 .await?)
282 }
283
284 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
285 let query = "UPDATE users SET admin = $1 WHERE id = $2";
286 Ok(sqlx::query(query)
287 .bind(is_admin)
288 .bind(id.0)
289 .execute(&self.pool)
290 .await
291 .map(drop)?)
292 }
293
294 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
295 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
296 Ok(sqlx::query(query)
297 .bind(connected_once)
298 .bind(id.0)
299 .execute(&self.pool)
300 .await
301 .map(drop)?)
302 }
303
304 async fn destroy_user(&self, id: UserId) -> Result<()> {
305 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
306 sqlx::query(query)
307 .bind(id.0)
308 .execute(&self.pool)
309 .await
310 .map(drop)?;
311 let query = "DELETE FROM users WHERE id = $1;";
312 Ok(sqlx::query(query)
313 .bind(id.0)
314 .execute(&self.pool)
315 .await
316 .map(drop)?)
317 }
318
319 // invite codes
320
321 async fn set_invite_count(&self, id: UserId, count: u32) -> Result<()> {
322 let mut tx = self.pool.begin().await?;
323 if count > 0 {
324 sqlx::query(
325 "
326 UPDATE users
327 SET invite_code = $1
328 WHERE id = $2 AND invite_code IS NULL
329 ",
330 )
331 .bind(nanoid!(16))
332 .bind(id)
333 .execute(&mut tx)
334 .await?;
335 }
336
337 sqlx::query(
338 "
339 UPDATE users
340 SET invite_count = $1
341 WHERE id = $2
342 ",
343 )
344 .bind(count as i32)
345 .bind(id)
346 .execute(&mut tx)
347 .await?;
348 tx.commit().await?;
349 Ok(())
350 }
351
352 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
353 let result: Option<(String, i32)> = sqlx::query_as(
354 "
355 SELECT invite_code, invite_count
356 FROM users
357 WHERE id = $1 AND invite_code IS NOT NULL
358 ",
359 )
360 .bind(id)
361 .fetch_optional(&self.pool)
362 .await?;
363 if let Some((code, count)) = result {
364 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
365 } else {
366 Ok(None)
367 }
368 }
369
370 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
371 sqlx::query_as(
372 "
373 SELECT *
374 FROM users
375 WHERE invite_code = $1
376 ",
377 )
378 .bind(code)
379 .fetch_optional(&self.pool)
380 .await?
381 .ok_or_else(|| {
382 Error::Http(
383 StatusCode::NOT_FOUND,
384 "that invite code does not exist".to_string(),
385 )
386 })
387 }
388
389 async fn redeem_invite_code(
390 &self,
391 code: &str,
392 login: &str,
393 email_address: Option<&str>,
394 ) -> Result<UserId> {
395 let mut tx = self.pool.begin().await?;
396
397 let inviter_id: Option<UserId> = sqlx::query_scalar(
398 "
399 UPDATE users
400 SET invite_count = invite_count - 1
401 WHERE
402 invite_code = $1 AND
403 invite_count > 0
404 RETURNING id
405 ",
406 )
407 .bind(code)
408 .fetch_optional(&mut tx)
409 .await?;
410
411 let inviter_id = match inviter_id {
412 Some(inviter_id) => inviter_id,
413 None => {
414 if sqlx::query_scalar::<_, i32>("SELECT 1 FROM users WHERE invite_code = $1")
415 .bind(code)
416 .fetch_optional(&mut tx)
417 .await?
418 .is_some()
419 {
420 Err(Error::Http(
421 StatusCode::UNAUTHORIZED,
422 "no invites remaining".to_string(),
423 ))?
424 } else {
425 Err(Error::Http(
426 StatusCode::NOT_FOUND,
427 "invite code not found".to_string(),
428 ))?
429 }
430 }
431 };
432
433 let invitee_id = sqlx::query_scalar(
434 "
435 INSERT INTO users
436 (github_login, email_address, admin, inviter_id)
437 VALUES
438 ($1, $2, 'f', $3)
439 RETURNING id
440 ",
441 )
442 .bind(login)
443 .bind(email_address)
444 .bind(inviter_id)
445 .fetch_one(&mut tx)
446 .await
447 .map(UserId)?;
448
449 sqlx::query(
450 "
451 INSERT INTO contacts
452 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
453 VALUES
454 ($1, $2, 't', 't', 't')
455 ",
456 )
457 .bind(inviter_id)
458 .bind(invitee_id)
459 .execute(&mut tx)
460 .await?;
461
462 tx.commit().await?;
463 Ok(invitee_id)
464 }
465
466 // projects
467
468 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
469 Ok(sqlx::query_scalar(
470 "
471 INSERT INTO projects(host_user_id)
472 VALUES ($1)
473 RETURNING id
474 ",
475 )
476 .bind(host_user_id)
477 .fetch_one(&self.pool)
478 .await
479 .map(ProjectId)?)
480 }
481
482 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
483 sqlx::query(
484 "
485 UPDATE projects
486 SET unregistered = 't'
487 WHERE id = $1
488 ",
489 )
490 .bind(project_id)
491 .execute(&self.pool)
492 .await?;
493 Ok(())
494 }
495
496 async fn update_worktree_extensions(
497 &self,
498 project_id: ProjectId,
499 worktree_id: u64,
500 extensions: HashMap<String, usize>,
501 ) -> Result<()> {
502 if extensions.is_empty() {
503 return Ok(());
504 }
505
506 let mut query = QueryBuilder::new(
507 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
508 );
509 query.push_values(extensions, |mut query, (extension, count)| {
510 query
511 .push_bind(project_id)
512 .push_bind(worktree_id as i32)
513 .push_bind(extension)
514 .push_bind(count as i32);
515 });
516 query.push(
517 "
518 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
519 count = excluded.count
520 ",
521 );
522 query.build().execute(&self.pool).await?;
523
524 Ok(())
525 }
526
527 async fn get_project_extensions(
528 &self,
529 project_id: ProjectId,
530 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
531 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
532 struct WorktreeExtension {
533 worktree_id: i32,
534 extension: String,
535 count: i32,
536 }
537
538 let query = "
539 SELECT worktree_id, extension, count
540 FROM worktree_extensions
541 WHERE project_id = $1
542 ";
543 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
544 .bind(&project_id)
545 .fetch_all(&self.pool)
546 .await?;
547
548 let mut extension_counts = HashMap::default();
549 for count in counts {
550 extension_counts
551 .entry(count.worktree_id as u64)
552 .or_insert(HashMap::default())
553 .insert(count.extension, count.count as usize);
554 }
555 Ok(extension_counts)
556 }
557
558 async fn record_project_activity(
559 &self,
560 time_period: Range<OffsetDateTime>,
561 projects: &[(UserId, ProjectId)],
562 ) -> Result<()> {
563 let query = "
564 INSERT INTO project_activity_periods
565 (ended_at, duration_millis, user_id, project_id)
566 VALUES
567 ($1, $2, $3, $4);
568 ";
569
570 let mut tx = self.pool.begin().await?;
571 let duration_millis =
572 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
573 for (user_id, project_id) in projects {
574 sqlx::query(query)
575 .bind(time_period.end)
576 .bind(duration_millis)
577 .bind(user_id)
578 .bind(project_id)
579 .execute(&mut tx)
580 .await?;
581 }
582 tx.commit().await?;
583
584 Ok(())
585 }
586
587 async fn summarize_project_activity(
588 &self,
589 time_period: Range<OffsetDateTime>,
590 max_user_count: usize,
591 ) -> Result<Vec<UserActivitySummary>> {
592 let query = "
593 WITH
594 project_durations AS (
595 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
596 FROM project_activity_periods
597 WHERE $1 <= ended_at AND ended_at <= $2
598 GROUP BY user_id, project_id
599 ),
600 user_durations AS (
601 SELECT user_id, SUM(project_duration) as total_duration
602 FROM project_durations
603 GROUP BY user_id
604 ORDER BY total_duration DESC
605 LIMIT $3
606 )
607 SELECT user_durations.user_id, users.github_login, project_id, project_duration
608 FROM user_durations, project_durations, users
609 WHERE
610 user_durations.user_id = project_durations.user_id AND
611 user_durations.user_id = users.id
612 ORDER BY user_id ASC, project_duration DESC
613 ";
614
615 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64)>(query)
616 .bind(time_period.start)
617 .bind(time_period.end)
618 .bind(max_user_count as i32)
619 .fetch(&self.pool);
620
621 let mut result = Vec::<UserActivitySummary>::new();
622 while let Some(row) = rows.next().await {
623 let (user_id, github_login, project_id, duration_millis) = row?;
624 let project_id = project_id;
625 let duration = Duration::from_millis(duration_millis as u64);
626 if let Some(last_summary) = result.last_mut() {
627 if last_summary.id == user_id {
628 last_summary.project_activity.push((project_id, duration));
629 continue;
630 }
631 }
632 result.push(UserActivitySummary {
633 id: user_id,
634 project_activity: vec![(project_id, duration)],
635 github_login,
636 });
637 }
638
639 Ok(result)
640 }
641
642 // contacts
643
644 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
645 let query = "
646 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
647 FROM contacts
648 WHERE user_id_a = $1 OR user_id_b = $1;
649 ";
650
651 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
652 .bind(user_id)
653 .fetch(&self.pool);
654
655 let mut contacts = vec![Contact::Accepted {
656 user_id,
657 should_notify: false,
658 }];
659 while let Some(row) = rows.next().await {
660 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
661
662 if user_id_a == user_id {
663 if accepted {
664 contacts.push(Contact::Accepted {
665 user_id: user_id_b,
666 should_notify: should_notify && a_to_b,
667 });
668 } else if a_to_b {
669 contacts.push(Contact::Outgoing { user_id: user_id_b })
670 } else {
671 contacts.push(Contact::Incoming {
672 user_id: user_id_b,
673 should_notify,
674 });
675 }
676 } else {
677 if accepted {
678 contacts.push(Contact::Accepted {
679 user_id: user_id_a,
680 should_notify: should_notify && !a_to_b,
681 });
682 } else if a_to_b {
683 contacts.push(Contact::Incoming {
684 user_id: user_id_a,
685 should_notify,
686 });
687 } else {
688 contacts.push(Contact::Outgoing { user_id: user_id_a });
689 }
690 }
691 }
692
693 contacts.sort_unstable_by_key(|contact| contact.user_id());
694
695 Ok(contacts)
696 }
697
698 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
699 let (id_a, id_b) = if user_id_1 < user_id_2 {
700 (user_id_1, user_id_2)
701 } else {
702 (user_id_2, user_id_1)
703 };
704
705 let query = "
706 SELECT 1 FROM contacts
707 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
708 LIMIT 1
709 ";
710 Ok(sqlx::query_scalar::<_, i32>(query)
711 .bind(id_a.0)
712 .bind(id_b.0)
713 .fetch_optional(&self.pool)
714 .await?
715 .is_some())
716 }
717
718 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
719 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
720 (sender_id, receiver_id, true)
721 } else {
722 (receiver_id, sender_id, false)
723 };
724 let query = "
725 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
726 VALUES ($1, $2, $3, 'f', 't')
727 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
728 SET
729 accepted = 't',
730 should_notify = 'f'
731 WHERE
732 NOT contacts.accepted AND
733 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
734 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
735 ";
736 let result = sqlx::query(query)
737 .bind(id_a.0)
738 .bind(id_b.0)
739 .bind(a_to_b)
740 .execute(&self.pool)
741 .await?;
742
743 if result.rows_affected() == 1 {
744 Ok(())
745 } else {
746 Err(anyhow!("contact already requested"))?
747 }
748 }
749
750 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
751 let (id_a, id_b) = if responder_id < requester_id {
752 (responder_id, requester_id)
753 } else {
754 (requester_id, responder_id)
755 };
756 let query = "
757 DELETE FROM contacts
758 WHERE user_id_a = $1 AND user_id_b = $2;
759 ";
760 let result = sqlx::query(query)
761 .bind(id_a.0)
762 .bind(id_b.0)
763 .execute(&self.pool)
764 .await?;
765
766 if result.rows_affected() == 1 {
767 Ok(())
768 } else {
769 Err(anyhow!("no such contact"))?
770 }
771 }
772
773 async fn dismiss_contact_notification(
774 &self,
775 user_id: UserId,
776 contact_user_id: UserId,
777 ) -> Result<()> {
778 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
779 (user_id, contact_user_id, true)
780 } else {
781 (contact_user_id, user_id, false)
782 };
783
784 let query = "
785 UPDATE contacts
786 SET should_notify = 'f'
787 WHERE
788 user_id_a = $1 AND user_id_b = $2 AND
789 (
790 (a_to_b = $3 AND accepted) OR
791 (a_to_b != $3 AND NOT accepted)
792 );
793 ";
794
795 let result = sqlx::query(query)
796 .bind(id_a.0)
797 .bind(id_b.0)
798 .bind(a_to_b)
799 .execute(&self.pool)
800 .await?;
801
802 if result.rows_affected() == 0 {
803 Err(anyhow!("no such contact request"))?;
804 }
805
806 Ok(())
807 }
808
809 async fn respond_to_contact_request(
810 &self,
811 responder_id: UserId,
812 requester_id: UserId,
813 accept: bool,
814 ) -> Result<()> {
815 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
816 (responder_id, requester_id, false)
817 } else {
818 (requester_id, responder_id, true)
819 };
820 let result = if accept {
821 let query = "
822 UPDATE contacts
823 SET accepted = 't', should_notify = 't'
824 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
825 ";
826 sqlx::query(query)
827 .bind(id_a.0)
828 .bind(id_b.0)
829 .bind(a_to_b)
830 .execute(&self.pool)
831 .await?
832 } else {
833 let query = "
834 DELETE FROM contacts
835 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
836 ";
837 sqlx::query(query)
838 .bind(id_a.0)
839 .bind(id_b.0)
840 .bind(a_to_b)
841 .execute(&self.pool)
842 .await?
843 };
844 if result.rows_affected() == 1 {
845 Ok(())
846 } else {
847 Err(anyhow!("no such contact request"))?
848 }
849 }
850
851 // access tokens
852
853 async fn create_access_token_hash(
854 &self,
855 user_id: UserId,
856 access_token_hash: &str,
857 max_access_token_count: usize,
858 ) -> Result<()> {
859 let insert_query = "
860 INSERT INTO access_tokens (user_id, hash)
861 VALUES ($1, $2);
862 ";
863 let cleanup_query = "
864 DELETE FROM access_tokens
865 WHERE id IN (
866 SELECT id from access_tokens
867 WHERE user_id = $1
868 ORDER BY id DESC
869 OFFSET $3
870 )
871 ";
872
873 let mut tx = self.pool.begin().await?;
874 sqlx::query(insert_query)
875 .bind(user_id.0)
876 .bind(access_token_hash)
877 .execute(&mut tx)
878 .await?;
879 sqlx::query(cleanup_query)
880 .bind(user_id.0)
881 .bind(access_token_hash)
882 .bind(max_access_token_count as i32)
883 .execute(&mut tx)
884 .await?;
885 Ok(tx.commit().await?)
886 }
887
888 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
889 let query = "
890 SELECT hash
891 FROM access_tokens
892 WHERE user_id = $1
893 ORDER BY id DESC
894 ";
895 Ok(sqlx::query_scalar(query)
896 .bind(user_id.0)
897 .fetch_all(&self.pool)
898 .await?)
899 }
900
901 // orgs
902
903 #[allow(unused)] // Help rust-analyzer
904 #[cfg(any(test, feature = "seed-support"))]
905 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
906 let query = "
907 SELECT *
908 FROM orgs
909 WHERE slug = $1
910 ";
911 Ok(sqlx::query_as(query)
912 .bind(slug)
913 .fetch_optional(&self.pool)
914 .await?)
915 }
916
917 #[cfg(any(test, feature = "seed-support"))]
918 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
919 let query = "
920 INSERT INTO orgs (name, slug)
921 VALUES ($1, $2)
922 RETURNING id
923 ";
924 Ok(sqlx::query_scalar(query)
925 .bind(name)
926 .bind(slug)
927 .fetch_one(&self.pool)
928 .await
929 .map(OrgId)?)
930 }
931
932 #[cfg(any(test, feature = "seed-support"))]
933 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
934 let query = "
935 INSERT INTO org_memberships (org_id, user_id, admin)
936 VALUES ($1, $2, $3)
937 ON CONFLICT DO NOTHING
938 ";
939 Ok(sqlx::query(query)
940 .bind(org_id.0)
941 .bind(user_id.0)
942 .bind(is_admin)
943 .execute(&self.pool)
944 .await
945 .map(drop)?)
946 }
947
948 // channels
949
950 #[cfg(any(test, feature = "seed-support"))]
951 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
952 let query = "
953 INSERT INTO channels (owner_id, owner_is_user, name)
954 VALUES ($1, false, $2)
955 RETURNING id
956 ";
957 Ok(sqlx::query_scalar(query)
958 .bind(org_id.0)
959 .bind(name)
960 .fetch_one(&self.pool)
961 .await
962 .map(ChannelId)?)
963 }
964
965 #[allow(unused)] // Help rust-analyzer
966 #[cfg(any(test, feature = "seed-support"))]
967 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
968 let query = "
969 SELECT *
970 FROM channels
971 WHERE
972 channels.owner_is_user = false AND
973 channels.owner_id = $1
974 ";
975 Ok(sqlx::query_as(query)
976 .bind(org_id.0)
977 .fetch_all(&self.pool)
978 .await?)
979 }
980
981 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
982 let query = "
983 SELECT
984 channels.*
985 FROM
986 channel_memberships, channels
987 WHERE
988 channel_memberships.user_id = $1 AND
989 channel_memberships.channel_id = channels.id
990 ";
991 Ok(sqlx::query_as(query)
992 .bind(user_id.0)
993 .fetch_all(&self.pool)
994 .await?)
995 }
996
997 async fn can_user_access_channel(
998 &self,
999 user_id: UserId,
1000 channel_id: ChannelId,
1001 ) -> Result<bool> {
1002 let query = "
1003 SELECT id
1004 FROM channel_memberships
1005 WHERE user_id = $1 AND channel_id = $2
1006 LIMIT 1
1007 ";
1008 Ok(sqlx::query_scalar::<_, i32>(query)
1009 .bind(user_id.0)
1010 .bind(channel_id.0)
1011 .fetch_optional(&self.pool)
1012 .await
1013 .map(|e| e.is_some())?)
1014 }
1015
1016 #[cfg(any(test, feature = "seed-support"))]
1017 async fn add_channel_member(
1018 &self,
1019 channel_id: ChannelId,
1020 user_id: UserId,
1021 is_admin: bool,
1022 ) -> Result<()> {
1023 let query = "
1024 INSERT INTO channel_memberships (channel_id, user_id, admin)
1025 VALUES ($1, $2, $3)
1026 ON CONFLICT DO NOTHING
1027 ";
1028 Ok(sqlx::query(query)
1029 .bind(channel_id.0)
1030 .bind(user_id.0)
1031 .bind(is_admin)
1032 .execute(&self.pool)
1033 .await
1034 .map(drop)?)
1035 }
1036
1037 // messages
1038
1039 async fn create_channel_message(
1040 &self,
1041 channel_id: ChannelId,
1042 sender_id: UserId,
1043 body: &str,
1044 timestamp: OffsetDateTime,
1045 nonce: u128,
1046 ) -> Result<MessageId> {
1047 let query = "
1048 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1049 VALUES ($1, $2, $3, $4, $5)
1050 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1051 RETURNING id
1052 ";
1053 Ok(sqlx::query_scalar(query)
1054 .bind(channel_id.0)
1055 .bind(sender_id.0)
1056 .bind(body)
1057 .bind(timestamp)
1058 .bind(Uuid::from_u128(nonce))
1059 .fetch_one(&self.pool)
1060 .await
1061 .map(MessageId)?)
1062 }
1063
1064 async fn get_channel_messages(
1065 &self,
1066 channel_id: ChannelId,
1067 count: usize,
1068 before_id: Option<MessageId>,
1069 ) -> Result<Vec<ChannelMessage>> {
1070 let query = r#"
1071 SELECT * FROM (
1072 SELECT
1073 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1074 FROM
1075 channel_messages
1076 WHERE
1077 channel_id = $1 AND
1078 id < $2
1079 ORDER BY id DESC
1080 LIMIT $3
1081 ) as recent_messages
1082 ORDER BY id ASC
1083 "#;
1084 Ok(sqlx::query_as(query)
1085 .bind(channel_id.0)
1086 .bind(before_id.unwrap_or(MessageId::MAX))
1087 .bind(count as i64)
1088 .fetch_all(&self.pool)
1089 .await?)
1090 }
1091
1092 #[cfg(test)]
1093 async fn teardown(&self, url: &str) {
1094 use util::ResultExt;
1095
1096 let query = "
1097 SELECT pg_terminate_backend(pg_stat_activity.pid)
1098 FROM pg_stat_activity
1099 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1100 ";
1101 sqlx::query(query).execute(&self.pool).await.log_err();
1102 self.pool.close().await;
1103 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1104 .await
1105 .log_err();
1106 }
1107
1108 #[cfg(test)]
1109 fn as_fake(&self) -> Option<&tests::FakeDb> {
1110 None
1111 }
1112}
1113
/// Declares a strongly typed id: a public newtype around `i32` that is
/// transparent to both sqlx and serde, with conversions to and from the `u64`
/// representation and a `Display` impl that prints the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, sqlx::Type, Serialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from its `u64` form using a plain `as` cast, so
            /// values beyond the `i32` range are truncated.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                Self(raw)
            }

            /// Converts the id to its `u64` form using a plain `as` cast.
            #[allow(unused)]
            pub fn to_proto(&self) -> u64 {
                let Self(raw) = *self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1145
// Typed id of a row in the `users` table.
id_type!(UserId);
/// A row of the `users` table, as loaded by the `SELECT *` user queries.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub login of the user; the insert queries treat it as unique.
    pub github_login: String,
    pub email_address: Option<String>,
    /// Whether the user is an administrator.
    pub admin: bool,
    /// Code other people can redeem to sign up; `None` until one is assigned.
    pub invite_code: Option<String>,
    /// Number of invites the user may still hand out.
    pub invite_count: i32,
    pub connected_once: bool,
}
1157
// Typed id of a row in the `projects` table.
id_type!(ProjectId);
/// A row of the `projects` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// User the project was registered for.
    pub host_user_id: UserId,
    /// Set once the project has been unregistered.
    pub unregistered: bool,
}
1165
/// Activity of a single user as produced by `summarize_project_activity`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// Time the user was active in each project.
    pub project_activity: Vec<(ProjectId, Duration)>,
}
1172
// Typed id of a row in the `orgs` table.
id_type!(OrgId);
/// A row of the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Key that `find_org_by_slug` looks orgs up by.
    pub slug: String,
}
1180
// Typed id of a row in the `channels` table.
id_type!(ChannelId);
/// A row of the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owner; an org id when `owner_is_user` is false.
    pub owner_id: i32,
    /// Whether `owner_id` refers to a user rather than an org.
    pub owner_is_user: bool,
}
1189
// Typed id of a row in the `channel_messages` table.
id_type!(MessageId);
/// A row of the `channel_messages` table.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    pub sent_at: OffsetDateTime,
    /// Nonce supplied when the message was created; inserts conflict on it.
    pub nonce: Uuid,
}
1200
/// A contact relationship seen from the point of view of one user; `user_id`
/// always identifies the other party.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The contact request has been accepted.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A pending request that this user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request that `user_id` sent to this user.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1215
1216impl Contact {
1217 pub fn user_id(&self) -> UserId {
1218 match self {
1219 Contact::Accepted { user_id, .. } => *user_id,
1220 Contact::Outgoing { user_id } => *user_id,
1221 Contact::Incoming { user_id, .. } => *user_id,
1222 }
1223 }
1224}
1225
/// A contact request received from `requester_id`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1231
/// Builds a SQL `LIKE` pattern that matches any text containing the
/// alphanumeric characters of `string` in order, with anything in between.
///
/// Non-alphanumeric characters are dropped; every kept character is preceded
/// by `%` and the pattern ends with `%`. For example `"a-b"` becomes
/// `"%a%b%"` and the empty string becomes `"%"`.
fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    let kept = string.chars().filter(|ch| ch.is_alphanumeric());
    for ch in kept {
        pattern.push('%');
        pattern.push(ch);
    }
    pattern.push('%');
    pattern
}
1243
1244#[cfg(test)]
1245pub mod tests {
1246 use super::*;
1247 use anyhow::anyhow;
1248 use collections::BTreeMap;
1249 use gpui::executor::Background;
1250 use lazy_static::lazy_static;
1251 use parking_lot::Mutex;
1252 use rand::prelude::*;
1253 use sqlx::{
1254 migrate::{MigrateDatabase, Migrator},
1255 Postgres,
1256 };
1257 use std::{path::Path, sync::Arc};
1258 use util::post_inc;
1259
1260 #[tokio::test(flavor = "multi_thread")]
1261 async fn test_get_users_by_ids() {
1262 for test_db in [
1263 TestDb::postgres().await,
1264 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1265 ] {
1266 let db = test_db.db();
1267
1268 let user = db.create_user("user", None, false).await.unwrap();
1269 let friend1 = db.create_user("friend-1", None, false).await.unwrap();
1270 let friend2 = db.create_user("friend-2", None, false).await.unwrap();
1271 let friend3 = db.create_user("friend-3", None, false).await.unwrap();
1272
1273 assert_eq!(
1274 db.get_users_by_ids(vec![user, friend1, friend2, friend3])
1275 .await
1276 .unwrap(),
1277 vec![
1278 User {
1279 id: user,
1280 github_login: "user".to_string(),
1281 admin: false,
1282 ..Default::default()
1283 },
1284 User {
1285 id: friend1,
1286 github_login: "friend-1".to_string(),
1287 admin: false,
1288 ..Default::default()
1289 },
1290 User {
1291 id: friend2,
1292 github_login: "friend-2".to_string(),
1293 admin: false,
1294 ..Default::default()
1295 },
1296 User {
1297 id: friend3,
1298 github_login: "friend-3".to_string(),
1299 admin: false,
1300 ..Default::default()
1301 }
1302 ]
1303 );
1304 }
1305 }
1306
1307 #[tokio::test(flavor = "multi_thread")]
1308 async fn test_create_users() {
1309 let db = TestDb::postgres().await;
1310 let db = db.db();
1311
1312 // Create the first batch of users, ensuring invite counts are assigned
1313 // correctly and the respective invite codes are unique.
1314 let user_ids_batch_1 = db
1315 .create_users(vec![
1316 ("user1".to_string(), "hi@user1.com".to_string(), 5),
1317 ("user2".to_string(), "hi@user2.com".to_string(), 4),
1318 ("user3".to_string(), "hi@user3.com".to_string(), 3),
1319 ])
1320 .await
1321 .unwrap();
1322 assert_eq!(user_ids_batch_1.len(), 3);
1323
1324 let users = db.get_users_by_ids(user_ids_batch_1.clone()).await.unwrap();
1325 assert_eq!(users.len(), 3);
1326 assert_eq!(users[0].github_login, "user1");
1327 assert_eq!(users[0].email_address.as_deref(), Some("hi@user1.com"));
1328 assert_eq!(users[0].invite_count, 5);
1329 assert_eq!(users[1].github_login, "user2");
1330 assert_eq!(users[1].email_address.as_deref(), Some("hi@user2.com"));
1331 assert_eq!(users[1].invite_count, 4);
1332 assert_eq!(users[2].github_login, "user3");
1333 assert_eq!(users[2].email_address.as_deref(), Some("hi@user3.com"));
1334 assert_eq!(users[2].invite_count, 3);
1335
1336 let invite_code_1 = users[0].invite_code.clone().unwrap();
1337 let invite_code_2 = users[1].invite_code.clone().unwrap();
1338 let invite_code_3 = users[2].invite_code.clone().unwrap();
1339 assert_ne!(invite_code_1, invite_code_2);
1340 assert_ne!(invite_code_1, invite_code_3);
1341 assert_ne!(invite_code_2, invite_code_3);
1342
1343 // Create the second batch of users and include a user that is already in the database, ensuring
1344 // the invite count for the existing user is updated without changing their invite code.
1345 let user_ids_batch_2 = db
1346 .create_users(vec![
1347 ("user2".to_string(), "hi@user2.com".to_string(), 10),
1348 ("user4".to_string(), "hi@user4.com".to_string(), 2),
1349 ])
1350 .await
1351 .unwrap();
1352 assert_eq!(user_ids_batch_2.len(), 2);
1353 assert_eq!(user_ids_batch_2[0], user_ids_batch_1[1]);
1354
1355 let users = db.get_users_by_ids(user_ids_batch_2).await.unwrap();
1356 assert_eq!(users.len(), 2);
1357 assert_eq!(users[0].github_login, "user2");
1358 assert_eq!(users[0].email_address.as_deref(), Some("hi@user2.com"));
1359 assert_eq!(users[0].invite_count, 10);
1360 assert_eq!(users[0].invite_code, Some(invite_code_2.clone()));
1361 assert_eq!(users[1].github_login, "user4");
1362 assert_eq!(users[1].email_address.as_deref(), Some("hi@user4.com"));
1363 assert_eq!(users[1].invite_count, 2);
1364
1365 let invite_code_4 = users[1].invite_code.clone().unwrap();
1366 assert_ne!(invite_code_4, invite_code_1);
1367 assert_ne!(invite_code_4, invite_code_2);
1368 assert_ne!(invite_code_4, invite_code_3);
1369 }
1370
1371 #[tokio::test(flavor = "multi_thread")]
1372 async fn test_worktree_extensions() {
1373 let test_db = TestDb::postgres().await;
1374 let db = test_db.db();
1375
1376 let user = db.create_user("user_1", None, false).await.unwrap();
1377 let project = db.register_project(user).await.unwrap();
1378
1379 db.update_worktree_extensions(project, 100, Default::default())
1380 .await
1381 .unwrap();
1382 db.update_worktree_extensions(
1383 project,
1384 100,
1385 [("rs".to_string(), 5), ("md".to_string(), 3)]
1386 .into_iter()
1387 .collect(),
1388 )
1389 .await
1390 .unwrap();
1391 db.update_worktree_extensions(
1392 project,
1393 100,
1394 [("rs".to_string(), 6), ("md".to_string(), 5)]
1395 .into_iter()
1396 .collect(),
1397 )
1398 .await
1399 .unwrap();
1400 db.update_worktree_extensions(
1401 project,
1402 101,
1403 [("ts".to_string(), 2), ("md".to_string(), 1)]
1404 .into_iter()
1405 .collect(),
1406 )
1407 .await
1408 .unwrap();
1409
1410 assert_eq!(
1411 db.get_project_extensions(project).await.unwrap(),
1412 [
1413 (
1414 100,
1415 [("rs".into(), 6), ("md".into(), 5),]
1416 .into_iter()
1417 .collect::<HashMap<_, _>>()
1418 ),
1419 (
1420 101,
1421 [("ts".into(), 2), ("md".into(), 1),]
1422 .into_iter()
1423 .collect::<HashMap<_, _>>()
1424 )
1425 ]
1426 .into_iter()
1427 .collect()
1428 );
1429 }
1430
1431 #[tokio::test(flavor = "multi_thread")]
1432 async fn test_project_activity() {
1433 let test_db = TestDb::postgres().await;
1434 let db = test_db.db();
1435
1436 let user_1 = db.create_user("user_1", None, false).await.unwrap();
1437 let user_2 = db.create_user("user_2", None, false).await.unwrap();
1438 let user_3 = db.create_user("user_3", None, false).await.unwrap();
1439 let project_1 = db.register_project(user_1).await.unwrap();
1440 let project_2 = db.register_project(user_2).await.unwrap();
1441 let t0 = OffsetDateTime::now_utc() - Duration::from_secs(60 * 60);
1442
1443 // User 2 opens a project
1444 let t1 = t0 + Duration::from_secs(10);
1445 db.record_project_activity(t0..t1, &[(user_2, project_2)])
1446 .await
1447 .unwrap();
1448
1449 let t2 = t1 + Duration::from_secs(10);
1450 db.record_project_activity(t1..t2, &[(user_2, project_2)])
1451 .await
1452 .unwrap();
1453
1454 // User 1 joins the project
1455 let t3 = t2 + Duration::from_secs(10);
1456 db.record_project_activity(t2..t3, &[(user_2, project_2), (user_1, project_2)])
1457 .await
1458 .unwrap();
1459
1460 // User 1 opens another project
1461 let t4 = t3 + Duration::from_secs(10);
1462 db.record_project_activity(
1463 t3..t4,
1464 &[
1465 (user_2, project_2),
1466 (user_1, project_2),
1467 (user_1, project_1),
1468 ],
1469 )
1470 .await
1471 .unwrap();
1472
1473 // User 3 joins that project
1474 let t5 = t4 + Duration::from_secs(10);
1475 db.record_project_activity(
1476 t4..t5,
1477 &[
1478 (user_2, project_2),
1479 (user_1, project_2),
1480 (user_1, project_1),
1481 (user_3, project_1),
1482 ],
1483 )
1484 .await
1485 .unwrap();
1486
1487 // User 2 leaves
1488 let t6 = t5 + Duration::from_secs(5);
1489 db.record_project_activity(t5..t6, &[(user_1, project_1), (user_3, project_1)])
1490 .await
1491 .unwrap();
1492
1493 let summary = db.summarize_project_activity(t0..t6, 10).await.unwrap();
1494 assert_eq!(
1495 summary,
1496 &[
1497 UserActivitySummary {
1498 id: user_1,
1499 github_login: "user_1".to_string(),
1500 project_activity: vec![
1501 (project_2, Duration::from_secs(30)),
1502 (project_1, Duration::from_secs(25))
1503 ]
1504 },
1505 UserActivitySummary {
1506 id: user_2,
1507 github_login: "user_2".to_string(),
1508 project_activity: vec![(project_2, Duration::from_secs(50))]
1509 },
1510 UserActivitySummary {
1511 id: user_3,
1512 github_login: "user_3".to_string(),
1513 project_activity: vec![(project_1, Duration::from_secs(15))]
1514 },
1515 ]
1516 );
1517 }
1518
1519 #[tokio::test(flavor = "multi_thread")]
1520 async fn test_recent_channel_messages() {
1521 for test_db in [
1522 TestDb::postgres().await,
1523 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1524 ] {
1525 let db = test_db.db();
1526 let user = db.create_user("user", None, false).await.unwrap();
1527 let org = db.create_org("org", "org").await.unwrap();
1528 let channel = db.create_org_channel(org, "channel").await.unwrap();
1529 for i in 0..10 {
1530 db.create_channel_message(
1531 channel,
1532 user,
1533 &i.to_string(),
1534 OffsetDateTime::now_utc(),
1535 i,
1536 )
1537 .await
1538 .unwrap();
1539 }
1540
1541 let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
1542 assert_eq!(
1543 messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1544 ["5", "6", "7", "8", "9"]
1545 );
1546
1547 let prev_messages = db
1548 .get_channel_messages(channel, 4, Some(messages[0].id))
1549 .await
1550 .unwrap();
1551 assert_eq!(
1552 prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
1553 ["1", "2", "3", "4"]
1554 );
1555 }
1556 }
1557
1558 #[tokio::test(flavor = "multi_thread")]
1559 async fn test_channel_message_nonces() {
1560 for test_db in [
1561 TestDb::postgres().await,
1562 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1563 ] {
1564 let db = test_db.db();
1565 let user = db.create_user("user", None, false).await.unwrap();
1566 let org = db.create_org("org", "org").await.unwrap();
1567 let channel = db.create_org_channel(org, "channel").await.unwrap();
1568
1569 let msg1_id = db
1570 .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
1571 .await
1572 .unwrap();
1573 let msg2_id = db
1574 .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
1575 .await
1576 .unwrap();
1577 let msg3_id = db
1578 .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
1579 .await
1580 .unwrap();
1581 let msg4_id = db
1582 .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
1583 .await
1584 .unwrap();
1585
1586 assert_ne!(msg1_id, msg2_id);
1587 assert_eq!(msg1_id, msg3_id);
1588 assert_eq!(msg2_id, msg4_id);
1589 }
1590 }
1591
1592 #[tokio::test(flavor = "multi_thread")]
1593 async fn test_create_access_tokens() {
1594 let test_db = TestDb::postgres().await;
1595 let db = test_db.db();
1596 let user = db.create_user("the-user", None, false).await.unwrap();
1597
1598 db.create_access_token_hash(user, "h1", 3).await.unwrap();
1599 db.create_access_token_hash(user, "h2", 3).await.unwrap();
1600 assert_eq!(
1601 db.get_access_token_hashes(user).await.unwrap(),
1602 &["h2".to_string(), "h1".to_string()]
1603 );
1604
1605 db.create_access_token_hash(user, "h3", 3).await.unwrap();
1606 assert_eq!(
1607 db.get_access_token_hashes(user).await.unwrap(),
1608 &["h3".to_string(), "h2".to_string(), "h1".to_string(),]
1609 );
1610
1611 db.create_access_token_hash(user, "h4", 3).await.unwrap();
1612 assert_eq!(
1613 db.get_access_token_hashes(user).await.unwrap(),
1614 &["h4".to_string(), "h3".to_string(), "h2".to_string(),]
1615 );
1616
1617 db.create_access_token_hash(user, "h5", 3).await.unwrap();
1618 assert_eq!(
1619 db.get_access_token_hashes(user).await.unwrap(),
1620 &["h5".to_string(), "h4".to_string(), "h3".to_string()]
1621 );
1622 }
1623
#[test]
fn test_fuzzy_like_string() {
    // (input, expected LIKE pattern): whitespace is dropped, letters are '%'-separated.
    let cases = [("abcd", "%a%b%c%d%"), ("x y", "%x%y%"), (" z ", "%z%")];
    for (input, expected) in cases {
        assert_eq!(fuzzy_like_string(input), expected);
    }
}
1630
1631 #[tokio::test(flavor = "multi_thread")]
1632 async fn test_fuzzy_search_users() {
1633 let test_db = TestDb::postgres().await;
1634 let db = test_db.db();
1635 for github_login in [
1636 "California",
1637 "colorado",
1638 "oregon",
1639 "washington",
1640 "florida",
1641 "delaware",
1642 "rhode-island",
1643 ] {
1644 db.create_user(github_login, None, false).await.unwrap();
1645 }
1646
1647 assert_eq!(
1648 fuzzy_search_user_names(db, "clr").await,
1649 &["colorado", "California"]
1650 );
1651 assert_eq!(
1652 fuzzy_search_user_names(db, "ro").await,
1653 &["rhode-island", "colorado", "oregon"],
1654 );
1655
1656 async fn fuzzy_search_user_names(db: &Arc<dyn Db>, query: &str) -> Vec<String> {
1657 db.fuzzy_search_users(query, 10)
1658 .await
1659 .unwrap()
1660 .into_iter()
1661 .map(|user| user.github_login)
1662 .collect::<Vec<_>>()
1663 }
1664 }
1665
1666 #[tokio::test(flavor = "multi_thread")]
1667 async fn test_add_contacts() {
1668 for test_db in [
1669 TestDb::postgres().await,
1670 TestDb::fake(Arc::new(gpui::executor::Background::new())),
1671 ] {
1672 let db = test_db.db();
1673
1674 let user_1 = db.create_user("user1", None, false).await.unwrap();
1675 let user_2 = db.create_user("user2", None, false).await.unwrap();
1676 let user_3 = db.create_user("user3", None, false).await.unwrap();
1677
1678 // User starts with no contacts
1679 assert_eq!(
1680 db.get_contacts(user_1).await.unwrap(),
1681 vec![Contact::Accepted {
1682 user_id: user_1,
1683 should_notify: false
1684 }],
1685 );
1686
1687 // User requests a contact. Both users see the pending request.
1688 db.send_contact_request(user_1, user_2).await.unwrap();
1689 assert!(!db.has_contact(user_1, user_2).await.unwrap());
1690 assert!(!db.has_contact(user_2, user_1).await.unwrap());
1691 assert_eq!(
1692 db.get_contacts(user_1).await.unwrap(),
1693 &[
1694 Contact::Accepted {
1695 user_id: user_1,
1696 should_notify: false
1697 },
1698 Contact::Outgoing { user_id: user_2 }
1699 ],
1700 );
1701 assert_eq!(
1702 db.get_contacts(user_2).await.unwrap(),
1703 &[
1704 Contact::Incoming {
1705 user_id: user_1,
1706 should_notify: true
1707 },
1708 Contact::Accepted {
1709 user_id: user_2,
1710 should_notify: false
1711 },
1712 ]
1713 );
1714
1715 // User 2 dismisses the contact request notification without accepting or rejecting.
1716 // We shouldn't notify them again.
1717 db.dismiss_contact_notification(user_1, user_2)
1718 .await
1719 .unwrap_err();
1720 db.dismiss_contact_notification(user_2, user_1)
1721 .await
1722 .unwrap();
1723 assert_eq!(
1724 db.get_contacts(user_2).await.unwrap(),
1725 &[
1726 Contact::Incoming {
1727 user_id: user_1,
1728 should_notify: false
1729 },
1730 Contact::Accepted {
1731 user_id: user_2,
1732 should_notify: false
1733 },
1734 ]
1735 );
1736
1737 // User can't accept their own contact request
1738 db.respond_to_contact_request(user_1, user_2, true)
1739 .await
1740 .unwrap_err();
1741
1742 // User accepts a contact request. Both users see the contact.
1743 db.respond_to_contact_request(user_2, user_1, true)
1744 .await
1745 .unwrap();
1746 assert_eq!(
1747 db.get_contacts(user_1).await.unwrap(),
1748 &[
1749 Contact::Accepted {
1750 user_id: user_1,
1751 should_notify: false
1752 },
1753 Contact::Accepted {
1754 user_id: user_2,
1755 should_notify: true
1756 }
1757 ],
1758 );
1759 assert!(db.has_contact(user_1, user_2).await.unwrap());
1760 assert!(db.has_contact(user_2, user_1).await.unwrap());
1761 assert_eq!(
1762 db.get_contacts(user_2).await.unwrap(),
1763 &[
1764 Contact::Accepted {
1765 user_id: user_1,
1766 should_notify: false,
1767 },
1768 Contact::Accepted {
1769 user_id: user_2,
1770 should_notify: false,
1771 },
1772 ]
1773 );
1774
1775 // Users cannot re-request existing contacts.
1776 db.send_contact_request(user_1, user_2).await.unwrap_err();
1777 db.send_contact_request(user_2, user_1).await.unwrap_err();
1778
1779 // Users can't dismiss notifications of them accepting other users' requests.
1780 db.dismiss_contact_notification(user_2, user_1)
1781 .await
1782 .unwrap_err();
1783 assert_eq!(
1784 db.get_contacts(user_1).await.unwrap(),
1785 &[
1786 Contact::Accepted {
1787 user_id: user_1,
1788 should_notify: false
1789 },
1790 Contact::Accepted {
1791 user_id: user_2,
1792 should_notify: true,
1793 },
1794 ]
1795 );
1796
1797 // Users can dismiss notifications of other users accepting their requests.
1798 db.dismiss_contact_notification(user_1, user_2)
1799 .await
1800 .unwrap();
1801 assert_eq!(
1802 db.get_contacts(user_1).await.unwrap(),
1803 &[
1804 Contact::Accepted {
1805 user_id: user_1,
1806 should_notify: false
1807 },
1808 Contact::Accepted {
1809 user_id: user_2,
1810 should_notify: false,
1811 },
1812 ]
1813 );
1814
1815 // Users send each other concurrent contact requests and
1816 // see that they are immediately accepted.
1817 db.send_contact_request(user_1, user_3).await.unwrap();
1818 db.send_contact_request(user_3, user_1).await.unwrap();
1819 assert_eq!(
1820 db.get_contacts(user_1).await.unwrap(),
1821 &[
1822 Contact::Accepted {
1823 user_id: user_1,
1824 should_notify: false
1825 },
1826 Contact::Accepted {
1827 user_id: user_2,
1828 should_notify: false,
1829 },
1830 Contact::Accepted {
1831 user_id: user_3,
1832 should_notify: false
1833 },
1834 ]
1835 );
1836 assert_eq!(
1837 db.get_contacts(user_3).await.unwrap(),
1838 &[
1839 Contact::Accepted {
1840 user_id: user_1,
1841 should_notify: false
1842 },
1843 Contact::Accepted {
1844 user_id: user_3,
1845 should_notify: false
1846 }
1847 ],
1848 );
1849
1850 // User declines a contact request. Both users see that it is gone.
1851 db.send_contact_request(user_2, user_3).await.unwrap();
1852 db.respond_to_contact_request(user_3, user_2, false)
1853 .await
1854 .unwrap();
1855 assert!(!db.has_contact(user_2, user_3).await.unwrap());
1856 assert!(!db.has_contact(user_3, user_2).await.unwrap());
1857 assert_eq!(
1858 db.get_contacts(user_2).await.unwrap(),
1859 &[
1860 Contact::Accepted {
1861 user_id: user_1,
1862 should_notify: false
1863 },
1864 Contact::Accepted {
1865 user_id: user_2,
1866 should_notify: false
1867 }
1868 ]
1869 );
1870 assert_eq!(
1871 db.get_contacts(user_3).await.unwrap(),
1872 &[
1873 Contact::Accepted {
1874 user_id: user_1,
1875 should_notify: false
1876 },
1877 Contact::Accepted {
1878 user_id: user_3,
1879 should_notify: false
1880 }
1881 ],
1882 );
1883 }
1884 }
1885
1886 #[tokio::test(flavor = "multi_thread")]
1887 async fn test_invite_codes() {
1888 let postgres = TestDb::postgres().await;
1889 let db = postgres.db();
1890 let user1 = db.create_user("user-1", None, false).await.unwrap();
1891
1892 // Initially, user 1 has no invite code
1893 assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None);
1894
1895 // Setting invite count to 0 when no code is assigned does not assign a new code
1896 db.set_invite_count(user1, 0).await.unwrap();
1897 assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none());
1898
1899 // User 1 creates an invite code that can be used twice.
1900 db.set_invite_count(user1, 2).await.unwrap();
1901 let (invite_code, invite_count) =
1902 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1903 assert_eq!(invite_count, 2);
1904
1905 // User 2 redeems the invite code and becomes a contact of user 1.
1906 let user2 = db
1907 .redeem_invite_code(&invite_code, "user-2", None)
1908 .await
1909 .unwrap();
1910 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1911 assert_eq!(invite_count, 1);
1912 assert_eq!(
1913 db.get_contacts(user1).await.unwrap(),
1914 [
1915 Contact::Accepted {
1916 user_id: user1,
1917 should_notify: false
1918 },
1919 Contact::Accepted {
1920 user_id: user2,
1921 should_notify: true
1922 }
1923 ]
1924 );
1925 assert_eq!(
1926 db.get_contacts(user2).await.unwrap(),
1927 [
1928 Contact::Accepted {
1929 user_id: user1,
1930 should_notify: false
1931 },
1932 Contact::Accepted {
1933 user_id: user2,
1934 should_notify: false
1935 }
1936 ]
1937 );
1938
1939 // User 3 redeems the invite code and becomes a contact of user 1.
1940 let user3 = db
1941 .redeem_invite_code(&invite_code, "user-3", None)
1942 .await
1943 .unwrap();
1944 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1945 assert_eq!(invite_count, 0);
1946 assert_eq!(
1947 db.get_contacts(user1).await.unwrap(),
1948 [
1949 Contact::Accepted {
1950 user_id: user1,
1951 should_notify: false
1952 },
1953 Contact::Accepted {
1954 user_id: user2,
1955 should_notify: true
1956 },
1957 Contact::Accepted {
1958 user_id: user3,
1959 should_notify: true
1960 }
1961 ]
1962 );
1963 assert_eq!(
1964 db.get_contacts(user3).await.unwrap(),
1965 [
1966 Contact::Accepted {
1967 user_id: user1,
1968 should_notify: false
1969 },
1970 Contact::Accepted {
1971 user_id: user3,
1972 should_notify: false
1973 },
1974 ]
1975 );
1976
1977 // Trying to reedem the code for the third time results in an error.
1978 db.redeem_invite_code(&invite_code, "user-4", None)
1979 .await
1980 .unwrap_err();
1981
1982 // Invite count can be updated after the code has been created.
1983 db.set_invite_count(user1, 2).await.unwrap();
1984 let (latest_code, invite_count) =
1985 db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1986 assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
1987 assert_eq!(invite_count, 2);
1988
1989 // User 4 can now redeem the invite code and becomes a contact of user 1.
1990 let user4 = db
1991 .redeem_invite_code(&invite_code, "user-4", None)
1992 .await
1993 .unwrap();
1994 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
1995 assert_eq!(invite_count, 1);
1996 assert_eq!(
1997 db.get_contacts(user1).await.unwrap(),
1998 [
1999 Contact::Accepted {
2000 user_id: user1,
2001 should_notify: false
2002 },
2003 Contact::Accepted {
2004 user_id: user2,
2005 should_notify: true
2006 },
2007 Contact::Accepted {
2008 user_id: user3,
2009 should_notify: true
2010 },
2011 Contact::Accepted {
2012 user_id: user4,
2013 should_notify: true
2014 }
2015 ]
2016 );
2017 assert_eq!(
2018 db.get_contacts(user4).await.unwrap(),
2019 [
2020 Contact::Accepted {
2021 user_id: user1,
2022 should_notify: false
2023 },
2024 Contact::Accepted {
2025 user_id: user4,
2026 should_notify: false
2027 },
2028 ]
2029 );
2030
2031 // An existing user cannot redeem invite codes.
2032 db.redeem_invite_code(&invite_code, "user-2", None)
2033 .await
2034 .unwrap_err();
2035 let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
2036 assert_eq!(invite_count, 1);
2037 }
2038
/// Owns a database backend for the duration of a test and tears it down on drop.
pub struct TestDb {
    /// The backend under test. It is an `Option` so that `Drop` can take it out;
    /// both constructors in this file set it to `Some`.
    pub db: Option<Arc<dyn Db>>,
    /// Connection URL of the backing database; empty for the in-memory fake.
    pub url: String,
}
2043
2044 impl TestDb {
2045 pub async fn postgres() -> Self {
2046 lazy_static! {
2047 static ref LOCK: Mutex<()> = Mutex::new(());
2048 }
2049
2050 let _guard = LOCK.lock();
2051 let mut rng = StdRng::from_entropy();
2052 let name = format!("zed-test-{}", rng.gen::<u128>());
2053 let url = format!("postgres://postgres@localhost/{}", name);
2054 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2055 Postgres::create_database(&url)
2056 .await
2057 .expect("failed to create test db");
2058 let db = PostgresDb::new(&url, 5).await.unwrap();
2059 let migrator = Migrator::new(migrations_path).await.unwrap();
2060 migrator.run(&db.pool).await.unwrap();
2061 Self {
2062 db: Some(Arc::new(db)),
2063 url,
2064 }
2065 }
2066
2067 pub fn fake(background: Arc<Background>) -> Self {
2068 Self {
2069 db: Some(Arc::new(FakeDb::new(background))),
2070 url: Default::default(),
2071 }
2072 }
2073
2074 pub fn db(&self) -> &Arc<dyn Db> {
2075 self.db.as_ref().unwrap()
2076 }
2077 }
2078
2079 impl Drop for TestDb {
2080 fn drop(&mut self) {
2081 if let Some(db) = self.db.take() {
2082 futures::executor::block_on(db.teardown(&self.url));
2083 }
2084 }
2085 }
2086
/// In-memory `Db` implementation for tests. Each table is a mutex-guarded
/// collection, and ids come from per-table counters.
pub struct FakeDb {
    /// Executor used to inject random delays into operations.
    background: Arc<Background>,
    /// Users keyed by id.
    pub users: Mutex<BTreeMap<UserId, User>>,
    /// Projects keyed by id.
    pub projects: Mutex<BTreeMap<ProjectId, Project>>,
    /// File counts keyed by `(project, worktree id, extension)`.
    pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), usize>>,
    /// Orgs keyed by id.
    pub orgs: Mutex<BTreeMap<OrgId, Org>>,
    /// Org memberships keyed by `(org, user)`.
    /// NOTE(review): the `bool` is presumably an admin flag; confirm where it is written.
    pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
    /// Channels keyed by id.
    pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
    /// Channel memberships keyed by `(channel, user)`.
    /// NOTE(review): the `bool` is presumably an admin flag; confirm where it is written.
    pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
    /// Channel messages keyed by id.
    pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
    /// Contact requests and accepted contacts, in insertion order.
    pub contacts: Mutex<Vec<FakeContact>>,
    // Next id to hand out for each table.
    next_channel_message_id: Mutex<i32>,
    next_user_id: Mutex<i32>,
    next_org_id: Mutex<i32>,
    next_channel_id: Mutex<i32>,
    next_project_id: Mutex<i32>,
}
2104
/// A contact relationship as stored by `FakeDb`: one record per pair of users,
/// directed from the requester to the responder.
#[derive(Debug)]
pub struct FakeContact {
    /// User who sent the contact request.
    pub requester_id: UserId,
    /// User who received the contact request.
    pub responder_id: UserId,
    /// Whether the request has been accepted.
    pub accepted: bool,
    /// Whether a notification is still pending. It is reported to the responder while
    /// the request is pending and to the requester once it has been accepted.
    pub should_notify: bool,
}
2112
2113 impl FakeDb {
2114 pub fn new(background: Arc<Background>) -> Self {
2115 Self {
2116 background,
2117 users: Default::default(),
2118 next_user_id: Mutex::new(1),
2119 projects: Default::default(),
2120 worktree_extensions: Default::default(),
2121 next_project_id: Mutex::new(1),
2122 orgs: Default::default(),
2123 next_org_id: Mutex::new(1),
2124 org_memberships: Default::default(),
2125 channels: Default::default(),
2126 next_channel_id: Mutex::new(1),
2127 channel_memberships: Default::default(),
2128 channel_messages: Default::default(),
2129 next_channel_message_id: Mutex::new(1),
2130 contacts: Default::default(),
2131 }
2132 }
2133 }
2134
2135 #[async_trait]
2136 impl Db for FakeDb {
2137 async fn create_user(
2138 &self,
2139 github_login: &str,
2140 email_address: Option<&str>,
2141 admin: bool,
2142 ) -> Result<UserId> {
2143 self.background.simulate_random_delay().await;
2144
2145 let mut users = self.users.lock();
2146 if let Some(user) = users
2147 .values()
2148 .find(|user| user.github_login == github_login)
2149 {
2150 Ok(user.id)
2151 } else {
2152 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
2153 users.insert(
2154 user_id,
2155 User {
2156 id: user_id,
2157 github_login: github_login.to_string(),
2158 email_address: email_address.map(str::to_string),
2159 admin,
2160 invite_code: None,
2161 invite_count: 0,
2162 connected_once: false,
2163 },
2164 );
2165 Ok(user_id)
2166 }
2167 }
2168
/// Not supported by the fake: panics with `unimplemented!()` when called.
async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
    unimplemented!()
}
2172
/// Not supported by the fake: panics with `unimplemented!()` when called.
async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
    unimplemented!()
}
2176
/// Not supported by the fake: panics with `unimplemented!()` when called.
async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
    unimplemented!()
}
2180
2181 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2182 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2183 }
2184
2185 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2186 self.background.simulate_random_delay().await;
2187 let users = self.users.lock();
2188 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2189 }
2190
2191 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
2192 Ok(self
2193 .users
2194 .lock()
2195 .values()
2196 .find(|user| user.github_login == github_login)
2197 .cloned())
2198 }
2199
/// Not supported by the fake: panics with `unimplemented!()` when called.
async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
    unimplemented!()
}
2203
2204 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2205 self.background.simulate_random_delay().await;
2206 let mut users = self.users.lock();
2207 let mut user = users
2208 .get_mut(&id)
2209 .ok_or_else(|| anyhow!("user not found"))?;
2210 user.connected_once = connected_once;
2211 Ok(())
2212 }
2213
/// Not supported by the fake: panics with `unimplemented!()` when called.
async fn destroy_user(&self, _id: UserId) -> Result<()> {
    unimplemented!()
}
2217
2218 // invite codes
2219
/// Not supported by the fake: panics with `unimplemented!()` when called.
async fn set_invite_count(&self, _id: UserId, _count: u32) -> Result<()> {
    unimplemented!()
}
2223
/// The fake does not track invite codes: this always reports that the user has none.
async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
    Ok(None)
}
2227
/// Not supported by the fake: panics with `unimplemented!()` when called.
async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
    unimplemented!()
}
2231
/// Not supported by the fake: panics with `unimplemented!()` when called.
async fn redeem_invite_code(
    &self,
    _code: &str,
    _login: &str,
    _email_address: Option<&str>,
) -> Result<UserId> {
    unimplemented!()
}
2240
2241 // projects
2242
2243 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2244 self.background.simulate_random_delay().await;
2245 if !self.users.lock().contains_key(&host_user_id) {
2246 Err(anyhow!("no such user"))?;
2247 }
2248
2249 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2250 self.projects.lock().insert(
2251 project_id,
2252 Project {
2253 id: project_id,
2254 host_user_id,
2255 unregistered: false,
2256 },
2257 );
2258 Ok(project_id)
2259 }
2260
2261 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2262 self.projects
2263 .lock()
2264 .get_mut(&project_id)
2265 .ok_or_else(|| anyhow!("no such project"))?
2266 .unregistered = true;
2267 Ok(())
2268 }
2269
2270 async fn update_worktree_extensions(
2271 &self,
2272 project_id: ProjectId,
2273 worktree_id: u64,
2274 extensions: HashMap<String, usize>,
2275 ) -> Result<()> {
2276 self.background.simulate_random_delay().await;
2277 if !self.projects.lock().contains_key(&project_id) {
2278 Err(anyhow!("no such project"))?;
2279 }
2280
2281 for (extension, count) in extensions {
2282 self.worktree_extensions
2283 .lock()
2284 .insert((project_id, worktree_id, extension), count);
2285 }
2286
2287 Ok(())
2288 }
2289
/// Not supported by the fake: panics with `unimplemented!()` when called.
async fn get_project_extensions(
    &self,
    _project_id: ProjectId,
) -> Result<HashMap<u64, HashMap<String, usize>>> {
    unimplemented!()
}
2296
/// Not supported by the fake: panics with `unimplemented!()` when called.
async fn record_project_activity(
    &self,
    _period: Range<OffsetDateTime>,
    _active_projects: &[(UserId, ProjectId)],
) -> Result<()> {
    unimplemented!()
}
2304
/// Not supported by the fake: panics with `unimplemented!()` when called.
async fn summarize_project_activity(
    &self,
    _period: Range<OffsetDateTime>,
    _limit: usize,
) -> Result<Vec<UserActivitySummary>> {
    unimplemented!()
}
2312
2313 // contacts
2314
2315 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2316 self.background.simulate_random_delay().await;
2317 let mut contacts = vec![Contact::Accepted {
2318 user_id: id,
2319 should_notify: false,
2320 }];
2321
2322 for contact in self.contacts.lock().iter() {
2323 if contact.requester_id == id {
2324 if contact.accepted {
2325 contacts.push(Contact::Accepted {
2326 user_id: contact.responder_id,
2327 should_notify: contact.should_notify,
2328 });
2329 } else {
2330 contacts.push(Contact::Outgoing {
2331 user_id: contact.responder_id,
2332 });
2333 }
2334 } else if contact.responder_id == id {
2335 if contact.accepted {
2336 contacts.push(Contact::Accepted {
2337 user_id: contact.requester_id,
2338 should_notify: false,
2339 });
2340 } else {
2341 contacts.push(Contact::Incoming {
2342 user_id: contact.requester_id,
2343 should_notify: contact.should_notify,
2344 });
2345 }
2346 }
2347 }
2348
2349 contacts.sort_unstable_by_key(|contact| contact.user_id());
2350 Ok(contacts)
2351 }
2352
2353 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2354 self.background.simulate_random_delay().await;
2355 Ok(self.contacts.lock().iter().any(|contact| {
2356 contact.accepted
2357 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2358 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2359 }))
2360 }
2361
2362 async fn send_contact_request(
2363 &self,
2364 requester_id: UserId,
2365 responder_id: UserId,
2366 ) -> Result<()> {
2367 let mut contacts = self.contacts.lock();
2368 for contact in contacts.iter_mut() {
2369 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2370 if contact.accepted {
2371 Err(anyhow!("contact already exists"))?;
2372 } else {
2373 Err(anyhow!("contact already requested"))?;
2374 }
2375 }
2376 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2377 if contact.accepted {
2378 Err(anyhow!("contact already exists"))?;
2379 } else {
2380 contact.accepted = true;
2381 contact.should_notify = false;
2382 return Ok(());
2383 }
2384 }
2385 }
2386 contacts.push(FakeContact {
2387 requester_id,
2388 responder_id,
2389 accepted: false,
2390 should_notify: true,
2391 });
2392 Ok(())
2393 }
2394
2395 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2396 self.contacts.lock().retain(|contact| {
2397 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2398 });
2399 Ok(())
2400 }
2401
2402 async fn dismiss_contact_notification(
2403 &self,
2404 user_id: UserId,
2405 contact_user_id: UserId,
2406 ) -> Result<()> {
2407 let mut contacts = self.contacts.lock();
2408 for contact in contacts.iter_mut() {
2409 if contact.requester_id == contact_user_id
2410 && contact.responder_id == user_id
2411 && !contact.accepted
2412 {
2413 contact.should_notify = false;
2414 return Ok(());
2415 }
2416 if contact.requester_id == user_id
2417 && contact.responder_id == contact_user_id
2418 && contact.accepted
2419 {
2420 contact.should_notify = false;
2421 return Ok(());
2422 }
2423 }
2424 Err(anyhow!("no such notification"))?
2425 }
2426
2427 async fn respond_to_contact_request(
2428 &self,
2429 responder_id: UserId,
2430 requester_id: UserId,
2431 accept: bool,
2432 ) -> Result<()> {
2433 let mut contacts = self.contacts.lock();
2434 for (ix, contact) in contacts.iter_mut().enumerate() {
2435 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2436 if contact.accepted {
2437 Err(anyhow!("contact already confirmed"))?;
2438 }
2439 if accept {
2440 contact.accepted = true;
2441 contact.should_notify = true;
2442 } else {
2443 contacts.remove(ix);
2444 }
2445 return Ok(());
2446 }
2447 }
2448 Err(anyhow!("no such contact request"))?
2449 }
2450
        /// Not provided by this fake: every call panics through `unimplemented!()`.
        /// All arguments are ignored.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }
2459
        /// Not provided by this fake: every call panics through `unimplemented!()`.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }
2463
        /// Not provided by this fake: every call panics through `unimplemented!()`.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
2467
2468 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2469 self.background.simulate_random_delay().await;
2470 let mut orgs = self.orgs.lock();
2471 if orgs.values().any(|org| org.slug == slug) {
2472 Err(anyhow!("org already exists"))?
2473 } else {
2474 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2475 orgs.insert(
2476 org_id,
2477 Org {
2478 id: org_id,
2479 name: name.to_string(),
2480 slug: slug.to_string(),
2481 },
2482 );
2483 Ok(org_id)
2484 }
2485 }
2486
2487 async fn add_org_member(
2488 &self,
2489 org_id: OrgId,
2490 user_id: UserId,
2491 is_admin: bool,
2492 ) -> Result<()> {
2493 self.background.simulate_random_delay().await;
2494 if !self.orgs.lock().contains_key(&org_id) {
2495 Err(anyhow!("org does not exist"))?;
2496 }
2497 if !self.users.lock().contains_key(&user_id) {
2498 Err(anyhow!("user does not exist"))?;
2499 }
2500
2501 self.org_memberships
2502 .lock()
2503 .entry((org_id, user_id))
2504 .or_insert(is_admin);
2505 Ok(())
2506 }
2507
2508 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2509 self.background.simulate_random_delay().await;
2510 if !self.orgs.lock().contains_key(&org_id) {
2511 Err(anyhow!("org does not exist"))?;
2512 }
2513
2514 let mut channels = self.channels.lock();
2515 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2516 channels.insert(
2517 channel_id,
2518 Channel {
2519 id: channel_id,
2520 name: name.to_string(),
2521 owner_id: org_id.0,
2522 owner_is_user: false,
2523 },
2524 );
2525 Ok(channel_id)
2526 }
2527
2528 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2529 self.background.simulate_random_delay().await;
2530 Ok(self
2531 .channels
2532 .lock()
2533 .values()
2534 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2535 .cloned()
2536 .collect())
2537 }
2538
2539 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2540 self.background.simulate_random_delay().await;
2541 let channels = self.channels.lock();
2542 let memberships = self.channel_memberships.lock();
2543 Ok(channels
2544 .values()
2545 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2546 .cloned()
2547 .collect())
2548 }
2549
2550 async fn can_user_access_channel(
2551 &self,
2552 user_id: UserId,
2553 channel_id: ChannelId,
2554 ) -> Result<bool> {
2555 self.background.simulate_random_delay().await;
2556 Ok(self
2557 .channel_memberships
2558 .lock()
2559 .contains_key(&(channel_id, user_id)))
2560 }
2561
2562 async fn add_channel_member(
2563 &self,
2564 channel_id: ChannelId,
2565 user_id: UserId,
2566 is_admin: bool,
2567 ) -> Result<()> {
2568 self.background.simulate_random_delay().await;
2569 if !self.channels.lock().contains_key(&channel_id) {
2570 Err(anyhow!("channel does not exist"))?;
2571 }
2572 if !self.users.lock().contains_key(&user_id) {
2573 Err(anyhow!("user does not exist"))?;
2574 }
2575
2576 self.channel_memberships
2577 .lock()
2578 .entry((channel_id, user_id))
2579 .or_insert(is_admin);
2580 Ok(())
2581 }
2582
2583 async fn create_channel_message(
2584 &self,
2585 channel_id: ChannelId,
2586 sender_id: UserId,
2587 body: &str,
2588 timestamp: OffsetDateTime,
2589 nonce: u128,
2590 ) -> Result<MessageId> {
2591 self.background.simulate_random_delay().await;
2592 if !self.channels.lock().contains_key(&channel_id) {
2593 Err(anyhow!("channel does not exist"))?;
2594 }
2595 if !self.users.lock().contains_key(&sender_id) {
2596 Err(anyhow!("user does not exist"))?;
2597 }
2598
2599 let mut messages = self.channel_messages.lock();
2600 if let Some(message) = messages
2601 .values()
2602 .find(|message| message.nonce.as_u128() == nonce)
2603 {
2604 Ok(message.id)
2605 } else {
2606 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2607 messages.insert(
2608 message_id,
2609 ChannelMessage {
2610 id: message_id,
2611 channel_id,
2612 sender_id,
2613 body: body.to_string(),
2614 sent_at: timestamp,
2615 nonce: Uuid::from_u128(nonce),
2616 },
2617 );
2618 Ok(message_id)
2619 }
2620 }
2621
2622 async fn get_channel_messages(
2623 &self,
2624 channel_id: ChannelId,
2625 count: usize,
2626 before_id: Option<MessageId>,
2627 ) -> Result<Vec<ChannelMessage>> {
2628 let mut messages = self
2629 .channel_messages
2630 .lock()
2631 .values()
2632 .rev()
2633 .filter(|message| {
2634 message.channel_id == channel_id
2635 && message.id < before_id.unwrap_or(MessageId::MAX)
2636 })
2637 .take(count)
2638 .cloned()
2639 .collect::<Vec<_>>();
2640 messages.sort_unstable_by_key(|message| message.id);
2641 Ok(messages)
2642 }
2643
        /// No-op: the body is empty, so the string argument is ignored and nothing is torn down.
        async fn teardown(&self, _: &str) {}
2645
        /// Gives access to this value as a concrete `FakeDb`; this implementation always
        /// returns `Some(self)`. Only compiled in test builds.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
2650 }
2651}