1use crate::{Error, Result};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use axum::http::StatusCode;
5use collections::HashMap;
6use futures::StreamExt;
7use serde::{Deserialize, Serialize};
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9use sqlx::{
10 migrate::{Migrate as _, Migration, MigrationSource},
11 types::Uuid,
12 FromRow, QueryBuilder,
13};
14use std::{cmp, ops::Range, path::Path, time::Duration};
15use time::{OffsetDateTime, PrimitiveDateTime};
16
/// Persistence interface used by the collab server.
///
/// [`PostgresDb`] in this file implements it on top of a Postgres pool. The
/// `#[cfg(test)]` `as_fake` hook suggests a `FakeDb` test implementation exists
/// elsewhere (not visible here).
#[async_trait]
pub trait Db: Send + Sync {
    // users

    /// Creates a user, or returns the existing one if the GitHub login is already
    /// taken (in `PostgresDb` this is an upsert on `github_login`).
    async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult>;
    /// Returns page `page` (zero-based) of users, `limit` per page, ordered by GitHub login.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Returns up to `limit` users whose GitHub login fuzzily matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Returns the user with the given id, or `None` if there is none.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the user's metrics id rendered as a string.
    async fn get_user_metrics_id(&self, id: UserId) -> Result<String>;
    /// Returns the users matching `ids`; ids without a matching user are omitted.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns users whose invite count is zero, filtered on whether they were
    /// themselves invited by another user.
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    /// Finds a user by GitHub account. When `github_user_id` is provided, `PostgresDb`
    /// also updates the stored login / GitHub user id of the matched row.
    async fn get_user_by_github_account(
        &self,
        github_login: &str,
        github_user_id: Option<i32>,
    ) -> Result<Option<User>>;
    /// Sets the user's `admin` flag.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets the user's `connected_once` flag.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes the user together with their access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    // invite codes

    /// Sets how many invites the user may send; a positive count also gives the user
    /// an invite code if they do not have one yet.
    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, if they have a code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user who owns the invite `code`.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Records a signup for `email_address` invited through `code` and returns the invite.
    async fn create_invite_from_code(
        &self,
        code: &str,
        email_address: &str,
        device_id: Option<&str>,
    ) -> Result<Invite>;

    // signups

    /// Records a waitlist signup.
    async fn create_signup(&self, signup: Signup) -> Result<()>;
    /// Summarizes, by platform, the signups that have not been sent a confirmation email.
    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
    /// Returns up to `count` signups that have not yet been sent a confirmation email.
    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
    /// Marks the confirmation emails for the given invites as sent.
    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
    /// Creates a user from a redeemed invite; `None` means the invite was already used.
    async fn create_user_from_invite(
        &self,
        invite: &Invite,
        user: NewUserParams,
    ) -> Result<Option<NewUserResult>>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id. `PostgresDb` marks the
    /// project as unregistered rather than deleting its row.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, u32>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_user_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the number of users who have been active in the given
    /// time period for at least the given time duration.
    async fn get_active_user_count(
        &self,
        time_period: Range<OffsetDateTime>,
        min_duration: Duration,
        only_collaborative: bool,
    ) -> Result<usize>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn get_top_users_activity_summary(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Get the project activity for the given user and time period.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>>;

    // contacts

    /// Returns all of the user's contacts, accepted or pending, sorted by user id.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users are accepted contacts of each other.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    // access tokens

    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;

    // orgs and channels

    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    // NOTE(review): the `cfg` attribute below is followed by a blank line, but an outer
    // attribute still applies to the next item, so `get_org_channels` only exists in
    // test / seed-support builds. Confirm that is intended; the neighbouring channel
    // queries are unconditional.
    #[cfg(any(test, feature = "seed-support"))]

    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;

    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;

    #[cfg(test)]
    async fn teardown(&self, url: &str);

    #[cfg(test)]
    fn as_fake(&self) -> Option<&FakeDb>;
}
179
180#[cfg(any(test, debug_assertions))]
181pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> =
182 Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
183
184#[cfg(not(any(test, debug_assertions)))]
185pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = None;
186
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    // sqlx connection pool; the queries in this file run against it directly or
    // inside transactions begun from it.
    pool: sqlx::PgPool,
}
190
/// Runs an async database body so it works in both test and non-test builds.
///
/// The `{ ... }` tokens are wrapped in an `async` block. In test builds
/// (`cfg!(test)`) that block is driven to completion on a freshly built
/// current-thread Tokio runtime with the I/O and time drivers enabled, presumably
/// so sqlx can run even when the caller's executor is not Tokio. Otherwise the
/// block is simply awaited.
///
/// `$self` is accepted for call-site symmetry but is not used by the expansion.
///
/// # Panics
///
/// In test builds, this panics if the Tokio runtime cannot be built. Per Tokio's
/// docs, `Runtime::block_on` also panics when it is called from within a Tokio
/// runtime context.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };

        if cfg!(test) {
            tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .expect("test_support!: failed to build a tokio runtime")
                .block_on(body)
        } else {
            body.await
        }
    }};
}
204
205impl PostgresDb {
206 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
207 let pool = DbOptions::new()
208 .max_connections(max_connections)
209 .connect(url)
210 .await
211 .context("failed to connect to postgres database")?;
212 Ok(Self { pool })
213 }
214
215 pub async fn migrate(
216 &self,
217 migrations_path: &Path,
218 ignore_checksum_mismatch: bool,
219 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
220 let migrations = MigrationSource::resolve(migrations_path)
221 .await
222 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
223
224 let mut conn = self.pool.acquire().await?;
225
226 conn.ensure_migrations_table().await?;
227 let applied_migrations: HashMap<_, _> = conn
228 .list_applied_migrations()
229 .await?
230 .into_iter()
231 .map(|m| (m.version, m))
232 .collect();
233
234 let mut new_migrations = Vec::new();
235 for migration in migrations {
236 match applied_migrations.get(&migration.version) {
237 Some(applied_migration) => {
238 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
239 {
240 Err(anyhow!(
241 "checksum mismatch for applied migration {}",
242 migration.description
243 ))?;
244 }
245 }
246 None => {
247 let elapsed = conn.apply(&migration).await?;
248 new_migrations.push((migration, elapsed));
249 }
250 }
251 }
252
253 Ok(new_migrations)
254 }
255
256 pub fn fuzzy_like_string(string: &str) -> String {
257 let mut result = String::with_capacity(string.len() * 2 + 1);
258 for c in string.chars() {
259 if c.is_alphanumeric() {
260 result.push('%');
261 result.push(c);
262 }
263 }
264 result.push('%');
265 result
266 }
267}
268
269#[async_trait]
270impl Db for PostgresDb {
271 // users
272
273 async fn create_user(
274 &self,
275 email_address: &str,
276 admin: bool,
277 params: NewUserParams,
278 ) -> Result<NewUserResult> {
279 test_support!(self, {
280 let query = "
281 INSERT INTO users (email_address, github_login, github_user_id, admin)
282 VALUES ($1, $2, $3, $4)
283 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
284 RETURNING id, metrics_id::text
285 ";
286
287 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
288 .bind(email_address)
289 .bind(params.github_login)
290 .bind(params.github_user_id)
291 .bind(admin)
292 .fetch_one(&self.pool)
293 .await?;
294 Ok(NewUserResult {
295 user_id,
296 metrics_id,
297 signup_device_id: None,
298 inviting_user_id: None,
299 })
300 })
301 }
302
303 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
304 test_support!(self, {
305 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
306 Ok(sqlx::query_as(query)
307 .bind(limit as i32)
308 .bind((page * limit) as i32)
309 .fetch_all(&self.pool)
310 .await?)
311 })
312 }
313
314 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
315 test_support!(self, {
316 let like_string = Self::fuzzy_like_string(name_query);
317 let query = "
318 SELECT users.*
319 FROM users
320 WHERE github_login ILIKE $1
321 ORDER BY github_login <-> $2
322 LIMIT $3
323 ";
324 Ok(sqlx::query_as(query)
325 .bind(like_string)
326 .bind(name_query)
327 .bind(limit as i32)
328 .fetch_all(&self.pool)
329 .await?)
330 })
331 }
332
333 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
334 let users = self.get_users_by_ids(vec![id]).await?;
335 Ok(users.into_iter().next())
336 }
337
338 async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
339 test_support!(self, {
340 let query = "
341 SELECT metrics_id::text
342 FROM users
343 WHERE id = $1
344 ";
345 Ok(sqlx::query_scalar(query)
346 .bind(id)
347 .fetch_one(&self.pool)
348 .await?)
349 })
350 }
351
352 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
353 test_support!(self, {
354 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
355 let query = "
356 SELECT users.*
357 FROM users
358 WHERE users.id = ANY ($1)
359 ";
360 Ok(sqlx::query_as(query)
361 .bind(&ids)
362 .fetch_all(&self.pool)
363 .await?)
364 })
365 }
366
367 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
368 test_support!(self, {
369 let query = format!(
370 "
371 SELECT users.*
372 FROM users
373 WHERE invite_count = 0
374 AND inviter_id IS{} NULL
375 ",
376 if invited_by_another_user { " NOT" } else { "" }
377 );
378
379 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
380 })
381 }
382
383 async fn get_user_by_github_account(
384 &self,
385 github_login: &str,
386 github_user_id: Option<i32>,
387 ) -> Result<Option<User>> {
388 test_support!(self, {
389 if let Some(github_user_id) = github_user_id {
390 let mut user = sqlx::query_as::<_, User>(
391 "
392 UPDATE users
393 SET github_login = $1
394 WHERE github_user_id = $2
395 RETURNING *
396 ",
397 )
398 .bind(github_login)
399 .bind(github_user_id)
400 .fetch_optional(&self.pool)
401 .await?;
402
403 if user.is_none() {
404 user = sqlx::query_as::<_, User>(
405 "
406 UPDATE users
407 SET github_user_id = $1
408 WHERE github_login = $2
409 RETURNING *
410 ",
411 )
412 .bind(github_user_id)
413 .bind(github_login)
414 .fetch_optional(&self.pool)
415 .await?;
416 }
417
418 Ok(user)
419 } else {
420 let user = sqlx::query_as(
421 "
422 SELECT * FROM users
423 WHERE github_login = $1
424 LIMIT 1
425 ",
426 )
427 .bind(github_login)
428 .fetch_optional(&self.pool)
429 .await?;
430 Ok(user)
431 }
432 })
433 }
434
435 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
436 test_support!(self, {
437 let query = "UPDATE users SET admin = $1 WHERE id = $2";
438 Ok(sqlx::query(query)
439 .bind(is_admin)
440 .bind(id.0)
441 .execute(&self.pool)
442 .await
443 .map(drop)?)
444 })
445 }
446
447 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
448 test_support!(self, {
449 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
450 Ok(sqlx::query(query)
451 .bind(connected_once)
452 .bind(id.0)
453 .execute(&self.pool)
454 .await
455 .map(drop)?)
456 })
457 }
458
459 async fn destroy_user(&self, id: UserId) -> Result<()> {
460 test_support!(self, {
461 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
462 sqlx::query(query)
463 .bind(id.0)
464 .execute(&self.pool)
465 .await
466 .map(drop)?;
467 let query = "DELETE FROM users WHERE id = $1;";
468 Ok(sqlx::query(query)
469 .bind(id.0)
470 .execute(&self.pool)
471 .await
472 .map(drop)?)
473 })
474 }
475
476 // signups
477
478 async fn create_signup(&self, signup: Signup) -> Result<()> {
479 test_support!(self, {
480 sqlx::query(
481 "
482 INSERT INTO signups
483 (
484 email_address,
485 email_confirmation_code,
486 email_confirmation_sent,
487 platform_linux,
488 platform_mac,
489 platform_windows,
490 platform_unknown,
491 editor_features,
492 programming_languages,
493 device_id
494 )
495 VALUES
496 ($1, $2, 'f', $3, $4, $5, 'f', $6, $7, $8)
497 RETURNING id
498 ",
499 )
500 .bind(&signup.email_address)
501 .bind(&random_email_confirmation_code())
502 .bind(&signup.platform_linux)
503 .bind(&signup.platform_mac)
504 .bind(&signup.platform_windows)
505 .bind(&signup.editor_features)
506 .bind(&signup.programming_languages)
507 .bind(&signup.device_id)
508 .execute(&self.pool)
509 .await?;
510 Ok(())
511 })
512 }
513
514 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
515 test_support!(self, {
516 Ok(sqlx::query_as(
517 "
518 SELECT
519 COUNT(*) as count,
520 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
521 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
522 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
523 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
524 FROM (
525 SELECT *
526 FROM signups
527 WHERE
528 NOT email_confirmation_sent
529 ) AS unsent
530 ",
531 )
532 .fetch_one(&self.pool)
533 .await?)
534 })
535 }
536
537 async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
538 test_support!(self, {
539 Ok(sqlx::query_as(
540 "
541 SELECT
542 email_address, email_confirmation_code
543 FROM signups
544 WHERE
545 NOT email_confirmation_sent AND
546 (platform_mac OR platform_unknown)
547 LIMIT $1
548 ",
549 )
550 .bind(count as i32)
551 .fetch_all(&self.pool)
552 .await?)
553 })
554 }
555
556 async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
557 test_support!(self, {
558 sqlx::query(
559 "
560 UPDATE signups
561 SET email_confirmation_sent = 't'
562 WHERE email_address = ANY ($1)
563 ",
564 )
565 .bind(
566 &invites
567 .iter()
568 .map(|s| s.email_address.as_str())
569 .collect::<Vec<_>>(),
570 )
571 .execute(&self.pool)
572 .await?;
573 Ok(())
574 })
575 }
576
577 async fn create_user_from_invite(
578 &self,
579 invite: &Invite,
580 user: NewUserParams,
581 ) -> Result<Option<NewUserResult>> {
582 test_support!(self, {
583 let mut tx = self.pool.begin().await?;
584
585 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
586 i32,
587 Option<UserId>,
588 Option<UserId>,
589 Option<String>,
590 ) = sqlx::query_as(
591 "
592 SELECT id, user_id, inviting_user_id, device_id
593 FROM signups
594 WHERE
595 email_address = $1 AND
596 email_confirmation_code = $2
597 ",
598 )
599 .bind(&invite.email_address)
600 .bind(&invite.email_confirmation_code)
601 .fetch_optional(&mut tx)
602 .await?
603 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
604
605 if existing_user_id.is_some() {
606 return Ok(None);
607 }
608
609 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
610 "
611 INSERT INTO users
612 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
613 VALUES
614 ($1, $2, $3, 'f', $4, $5)
615 ON CONFLICT (github_login) DO UPDATE SET
616 email_address = excluded.email_address,
617 github_user_id = excluded.github_user_id,
618 admin = excluded.admin
619 RETURNING id, metrics_id::text
620 ",
621 )
622 .bind(&invite.email_address)
623 .bind(&user.github_login)
624 .bind(&user.github_user_id)
625 .bind(&user.invite_count)
626 .bind(random_invite_code())
627 .fetch_one(&mut tx)
628 .await?;
629
630 sqlx::query(
631 "
632 UPDATE signups
633 SET user_id = $1
634 WHERE id = $2
635 ",
636 )
637 .bind(&user_id)
638 .bind(&signup_id)
639 .execute(&mut tx)
640 .await?;
641
642 if let Some(inviting_user_id) = inviting_user_id {
643 let id: Option<UserId> = sqlx::query_scalar(
644 "
645 UPDATE users
646 SET invite_count = invite_count - 1
647 WHERE id = $1 AND invite_count > 0
648 RETURNING id
649 ",
650 )
651 .bind(&inviting_user_id)
652 .fetch_optional(&mut tx)
653 .await?;
654
655 if id.is_none() {
656 Err(Error::Http(
657 StatusCode::UNAUTHORIZED,
658 "no invites remaining".to_string(),
659 ))?;
660 }
661
662 sqlx::query(
663 "
664 INSERT INTO contacts
665 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
666 VALUES
667 ($1, $2, 't', 't', 't')
668 ON CONFLICT DO NOTHING
669 ",
670 )
671 .bind(inviting_user_id)
672 .bind(user_id)
673 .execute(&mut tx)
674 .await?;
675 }
676
677 tx.commit().await?;
678 Ok(Some(NewUserResult {
679 user_id,
680 metrics_id,
681 inviting_user_id,
682 signup_device_id,
683 }))
684 })
685 }
686
687 // invite codes
688
689 async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
690 test_support!(self, {
691 let mut tx = self.pool.begin().await?;
692 if count > 0 {
693 sqlx::query(
694 "
695 UPDATE users
696 SET invite_code = $1
697 WHERE id = $2 AND invite_code IS NULL
698 ",
699 )
700 .bind(random_invite_code())
701 .bind(id)
702 .execute(&mut tx)
703 .await?;
704 }
705
706 sqlx::query(
707 "
708 UPDATE users
709 SET invite_count = $1
710 WHERE id = $2
711 ",
712 )
713 .bind(count as i32)
714 .bind(id)
715 .execute(&mut tx)
716 .await?;
717 tx.commit().await?;
718 Ok(())
719 })
720 }
721
722 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
723 test_support!(self, {
724 let result: Option<(String, i32)> = sqlx::query_as(
725 "
726 SELECT invite_code, invite_count
727 FROM users
728 WHERE id = $1 AND invite_code IS NOT NULL
729 ",
730 )
731 .bind(id)
732 .fetch_optional(&self.pool)
733 .await?;
734 if let Some((code, count)) = result {
735 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
736 } else {
737 Ok(None)
738 }
739 })
740 }
741
742 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
743 test_support!(self, {
744 sqlx::query_as(
745 "
746 SELECT *
747 FROM users
748 WHERE invite_code = $1
749 ",
750 )
751 .bind(code)
752 .fetch_optional(&self.pool)
753 .await?
754 .ok_or_else(|| {
755 Error::Http(
756 StatusCode::NOT_FOUND,
757 "that invite code does not exist".to_string(),
758 )
759 })
760 })
761 }
762
763 async fn create_invite_from_code(
764 &self,
765 code: &str,
766 email_address: &str,
767 device_id: Option<&str>,
768 ) -> Result<Invite> {
769 test_support!(self, {
770 let mut tx = self.pool.begin().await?;
771
772 let existing_user: Option<UserId> = sqlx::query_scalar(
773 "
774 SELECT id
775 FROM users
776 WHERE email_address = $1
777 ",
778 )
779 .bind(email_address)
780 .fetch_optional(&mut tx)
781 .await?;
782 if existing_user.is_some() {
783 Err(anyhow!("email address is already in use"))?;
784 }
785
786 let row: Option<(UserId, i32)> = sqlx::query_as(
787 "
788 SELECT id, invite_count
789 FROM users
790 WHERE invite_code = $1
791 ",
792 )
793 .bind(code)
794 .fetch_optional(&mut tx)
795 .await?;
796
797 let (inviter_id, invite_count) = match row {
798 Some(row) => row,
799 None => Err(Error::Http(
800 StatusCode::NOT_FOUND,
801 "invite code not found".to_string(),
802 ))?,
803 };
804
805 if invite_count == 0 {
806 Err(Error::Http(
807 StatusCode::UNAUTHORIZED,
808 "no invites remaining".to_string(),
809 ))?;
810 }
811
812 let email_confirmation_code: String = sqlx::query_scalar(
813 "
814 INSERT INTO signups
815 (
816 email_address,
817 email_confirmation_code,
818 email_confirmation_sent,
819 inviting_user_id,
820 platform_linux,
821 platform_mac,
822 platform_windows,
823 platform_unknown,
824 device_id
825 )
826 VALUES
827 ($1, $2, 'f', $3, 'f', 'f', 'f', 't', $4)
828 ON CONFLICT (email_address)
829 DO UPDATE SET
830 inviting_user_id = excluded.inviting_user_id
831 RETURNING email_confirmation_code
832 ",
833 )
834 .bind(&email_address)
835 .bind(&random_email_confirmation_code())
836 .bind(&inviter_id)
837 .bind(&device_id)
838 .fetch_one(&mut tx)
839 .await?;
840
841 tx.commit().await?;
842
843 Ok(Invite {
844 email_address: email_address.into(),
845 email_confirmation_code,
846 })
847 })
848 }
849
850 // projects
851
852 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
853 test_support!(self, {
854 Ok(sqlx::query_scalar(
855 "
856 INSERT INTO projects(host_user_id)
857 VALUES ($1)
858 RETURNING id
859 ",
860 )
861 .bind(host_user_id)
862 .fetch_one(&self.pool)
863 .await
864 .map(ProjectId)?)
865 })
866 }
867
868 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
869 test_support!(self, {
870 sqlx::query(
871 "
872 UPDATE projects
873 SET unregistered = 't'
874 WHERE id = $1
875 ",
876 )
877 .bind(project_id)
878 .execute(&self.pool)
879 .await?;
880 Ok(())
881 })
882 }
883
884 async fn update_worktree_extensions(
885 &self,
886 project_id: ProjectId,
887 worktree_id: u64,
888 extensions: HashMap<String, u32>,
889 ) -> Result<()> {
890 test_support!(self, {
891 if extensions.is_empty() {
892 return Ok(());
893 }
894
895 let mut query = QueryBuilder::new(
896 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
897 );
898 query.push_values(extensions, |mut query, (extension, count)| {
899 query
900 .push_bind(project_id)
901 .push_bind(worktree_id as i32)
902 .push_bind(extension)
903 .push_bind(count as i32);
904 });
905 query.push(
906 "
907 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
908 count = excluded.count
909 ",
910 );
911 query.build().execute(&self.pool).await?;
912
913 Ok(())
914 })
915 }
916
917 async fn get_project_extensions(
918 &self,
919 project_id: ProjectId,
920 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
921 test_support!(self, {
922 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
923 struct WorktreeExtension {
924 worktree_id: i32,
925 extension: String,
926 count: i32,
927 }
928
929 let query = "
930 SELECT worktree_id, extension, count
931 FROM worktree_extensions
932 WHERE project_id = $1
933 ";
934 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
935 .bind(&project_id)
936 .fetch_all(&self.pool)
937 .await?;
938
939 let mut extension_counts = HashMap::default();
940 for count in counts {
941 extension_counts
942 .entry(count.worktree_id as u64)
943 .or_insert_with(HashMap::default)
944 .insert(count.extension, count.count as usize);
945 }
946 Ok(extension_counts)
947 })
948 }
949
950 async fn record_user_activity(
951 &self,
952 time_period: Range<OffsetDateTime>,
953 projects: &[(UserId, ProjectId)],
954 ) -> Result<()> {
955 test_support!(self, {
956 let query = "
957 INSERT INTO project_activity_periods
958 (ended_at, duration_millis, user_id, project_id)
959 VALUES
960 ($1, $2, $3, $4);
961 ";
962
963 let mut tx = self.pool.begin().await?;
964 let duration_millis =
965 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
966 for (user_id, project_id) in projects {
967 sqlx::query(query)
968 .bind(time_period.end)
969 .bind(duration_millis)
970 .bind(user_id)
971 .bind(project_id)
972 .execute(&mut tx)
973 .await?;
974 }
975 tx.commit().await?;
976
977 Ok(())
978 })
979 }
980
981 async fn get_active_user_count(
982 &self,
983 time_period: Range<OffsetDateTime>,
984 min_duration: Duration,
985 only_collaborative: bool,
986 ) -> Result<usize> {
987 test_support!(self, {
988 let mut with_clause = String::new();
989 with_clause.push_str("WITH\n");
990 with_clause.push_str(
991 "
992 project_durations AS (
993 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
994 FROM project_activity_periods
995 WHERE $1 < ended_at AND ended_at <= $2
996 GROUP BY user_id, project_id
997 ),
998 ",
999 );
1000 with_clause.push_str(
1001 "
1002 project_collaborators as (
1003 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
1004 FROM project_durations
1005 GROUP BY project_id
1006 ),
1007 ",
1008 );
1009
1010 if only_collaborative {
1011 with_clause.push_str(
1012 "
1013 user_durations AS (
1014 SELECT user_id, SUM(project_duration) as total_duration
1015 FROM project_durations, project_collaborators
1016 WHERE
1017 project_durations.project_id = project_collaborators.project_id AND
1018 max_collaborators > 1
1019 GROUP BY user_id
1020 ORDER BY total_duration DESC
1021 LIMIT $3
1022 )
1023 ",
1024 );
1025 } else {
1026 with_clause.push_str(
1027 "
1028 user_durations AS (
1029 SELECT user_id, SUM(project_duration) as total_duration
1030 FROM project_durations
1031 GROUP BY user_id
1032 ORDER BY total_duration DESC
1033 LIMIT $3
1034 )
1035 ",
1036 );
1037 }
1038
1039 let query = format!(
1040 "
1041 {with_clause}
1042 SELECT count(user_durations.user_id)
1043 FROM user_durations
1044 WHERE user_durations.total_duration >= $3
1045 "
1046 );
1047
1048 let count: i64 = sqlx::query_scalar(&query)
1049 .bind(time_period.start)
1050 .bind(time_period.end)
1051 .bind(min_duration.as_millis() as i64)
1052 .fetch_one(&self.pool)
1053 .await?;
1054 Ok(count as usize)
1055 })
1056 }
1057
1058 async fn get_top_users_activity_summary(
1059 &self,
1060 time_period: Range<OffsetDateTime>,
1061 max_user_count: usize,
1062 ) -> Result<Vec<UserActivitySummary>> {
1063 test_support!(self, {
1064 let query = "
1065 WITH
1066 project_durations AS (
1067 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
1068 FROM project_activity_periods
1069 WHERE $1 < ended_at AND ended_at <= $2
1070 GROUP BY user_id, project_id
1071 ),
1072 user_durations AS (
1073 SELECT user_id, SUM(project_duration) as total_duration
1074 FROM project_durations
1075 GROUP BY user_id
1076 ORDER BY total_duration DESC
1077 LIMIT $3
1078 ),
1079 project_collaborators as (
1080 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
1081 FROM project_durations
1082 GROUP BY project_id
1083 )
1084 SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
1085 FROM user_durations, project_durations, project_collaborators, users
1086 WHERE
1087 user_durations.user_id = project_durations.user_id AND
1088 user_durations.user_id = users.id AND
1089 project_durations.project_id = project_collaborators.project_id
1090 ORDER BY total_duration DESC, user_id ASC, project_id ASC
1091 ";
1092
1093 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
1094 .bind(time_period.start)
1095 .bind(time_period.end)
1096 .bind(max_user_count as i32)
1097 .fetch(&self.pool);
1098
1099 let mut result = Vec::<UserActivitySummary>::new();
1100 while let Some(row) = rows.next().await {
1101 let (user_id, github_login, project_id, duration_millis, project_collaborators) =
1102 row?;
1103 let project_id = project_id;
1104 let duration = Duration::from_millis(duration_millis as u64);
1105 let project_activity = ProjectActivitySummary {
1106 id: project_id,
1107 duration,
1108 max_collaborators: project_collaborators as usize,
1109 };
1110 if let Some(last_summary) = result.last_mut() {
1111 if last_summary.id == user_id {
1112 last_summary.project_activity.push(project_activity);
1113 continue;
1114 }
1115 }
1116 result.push(UserActivitySummary {
1117 id: user_id,
1118 project_activity: vec![project_activity],
1119 github_login,
1120 });
1121 }
1122
1123 Ok(result)
1124 })
1125 }
1126
1127 async fn get_user_activity_timeline(
1128 &self,
1129 time_period: Range<OffsetDateTime>,
1130 user_id: UserId,
1131 ) -> Result<Vec<UserActivityPeriod>> {
1132 test_support!(self, {
1133 const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
1134
1135 let query = "
1136 SELECT
1137 project_activity_periods.ended_at,
1138 project_activity_periods.duration_millis,
1139 project_activity_periods.project_id,
1140 worktree_extensions.extension,
1141 worktree_extensions.count
1142 FROM project_activity_periods
1143 LEFT OUTER JOIN
1144 worktree_extensions
1145 ON
1146 project_activity_periods.project_id = worktree_extensions.project_id
1147 WHERE
1148 project_activity_periods.user_id = $1 AND
1149 $2 < project_activity_periods.ended_at AND
1150 project_activity_periods.ended_at <= $3
1151 ORDER BY project_activity_periods.id ASC
1152 ";
1153
1154 let mut rows = sqlx::query_as::<
1155 _,
1156 (
1157 PrimitiveDateTime,
1158 i32,
1159 ProjectId,
1160 Option<String>,
1161 Option<i32>,
1162 ),
1163 >(query)
1164 .bind(user_id)
1165 .bind(time_period.start)
1166 .bind(time_period.end)
1167 .fetch(&self.pool);
1168
1169 let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
1170 while let Some(row) = rows.next().await {
1171 let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
1172 let ended_at = ended_at.assume_utc();
1173 let duration = Duration::from_millis(duration_millis as u64);
1174 let started_at = ended_at - duration;
1175 let project_time_periods = time_periods.entry(project_id).or_default();
1176
1177 if let Some(prev_duration) = project_time_periods.last_mut() {
1178 if started_at <= prev_duration.end + COALESCE_THRESHOLD
1179 && ended_at >= prev_duration.start
1180 {
1181 prev_duration.end = cmp::max(prev_duration.end, ended_at);
1182 } else {
1183 project_time_periods.push(UserActivityPeriod {
1184 project_id,
1185 start: started_at,
1186 end: ended_at,
1187 extensions: Default::default(),
1188 });
1189 }
1190 } else {
1191 project_time_periods.push(UserActivityPeriod {
1192 project_id,
1193 start: started_at,
1194 end: ended_at,
1195 extensions: Default::default(),
1196 });
1197 }
1198
1199 if let Some((extension, extension_count)) = extension.zip(extension_count) {
1200 project_time_periods
1201 .last_mut()
1202 .unwrap()
1203 .extensions
1204 .insert(extension, extension_count as usize);
1205 }
1206 }
1207
1208 let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
1209 durations.sort_unstable_by_key(|duration| duration.start);
1210 Ok(durations)
1211 })
1212 }
1213
1214 // contacts
1215
1216 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1217 test_support!(self, {
1218 let query = "
1219 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1220 FROM contacts
1221 WHERE user_id_a = $1 OR user_id_b = $1;
1222 ";
1223
1224 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1225 .bind(user_id)
1226 .fetch(&self.pool);
1227
1228 let mut contacts = Vec::new();
1229 while let Some(row) = rows.next().await {
1230 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1231
1232 if user_id_a == user_id {
1233 if accepted {
1234 contacts.push(Contact::Accepted {
1235 user_id: user_id_b,
1236 should_notify: should_notify && a_to_b,
1237 });
1238 } else if a_to_b {
1239 contacts.push(Contact::Outgoing { user_id: user_id_b })
1240 } else {
1241 contacts.push(Contact::Incoming {
1242 user_id: user_id_b,
1243 should_notify,
1244 });
1245 }
1246 } else if accepted {
1247 contacts.push(Contact::Accepted {
1248 user_id: user_id_a,
1249 should_notify: should_notify && !a_to_b,
1250 });
1251 } else if a_to_b {
1252 contacts.push(Contact::Incoming {
1253 user_id: user_id_a,
1254 should_notify,
1255 });
1256 } else {
1257 contacts.push(Contact::Outgoing { user_id: user_id_a });
1258 }
1259 }
1260
1261 contacts.sort_unstable_by_key(|contact| contact.user_id());
1262
1263 Ok(contacts)
1264 })
1265 }
1266
1267 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1268 test_support!(self, {
1269 let (id_a, id_b) = if user_id_1 < user_id_2 {
1270 (user_id_1, user_id_2)
1271 } else {
1272 (user_id_2, user_id_1)
1273 };
1274
1275 let query = "
1276 SELECT 1 FROM contacts
1277 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
1278 LIMIT 1
1279 ";
1280 Ok(sqlx::query_scalar::<_, i32>(query)
1281 .bind(id_a.0)
1282 .bind(id_b.0)
1283 .fetch_optional(&self.pool)
1284 .await?
1285 .is_some())
1286 })
1287 }
1288
1289 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1290 test_support!(self, {
1291 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1292 (sender_id, receiver_id, true)
1293 } else {
1294 (receiver_id, sender_id, false)
1295 };
1296 let query = "
1297 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1298 VALUES ($1, $2, $3, 'f', 't')
1299 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1300 SET
1301 accepted = 't',
1302 should_notify = 'f'
1303 WHERE
1304 NOT contacts.accepted AND
1305 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1306 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1307 ";
1308 let result = sqlx::query(query)
1309 .bind(id_a.0)
1310 .bind(id_b.0)
1311 .bind(a_to_b)
1312 .execute(&self.pool)
1313 .await?;
1314
1315 if result.rows_affected() == 1 {
1316 Ok(())
1317 } else {
1318 Err(anyhow!("contact already requested"))?
1319 }
1320 })
1321 }
1322
1323 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1324 test_support!(self, {
1325 let (id_a, id_b) = if responder_id < requester_id {
1326 (responder_id, requester_id)
1327 } else {
1328 (requester_id, responder_id)
1329 };
1330 let query = "
1331 DELETE FROM contacts
1332 WHERE user_id_a = $1 AND user_id_b = $2;
1333 ";
1334 let result = sqlx::query(query)
1335 .bind(id_a.0)
1336 .bind(id_b.0)
1337 .execute(&self.pool)
1338 .await?;
1339
1340 if result.rows_affected() == 1 {
1341 Ok(())
1342 } else {
1343 Err(anyhow!("no such contact"))?
1344 }
1345 })
1346 }
1347
1348 async fn dismiss_contact_notification(
1349 &self,
1350 user_id: UserId,
1351 contact_user_id: UserId,
1352 ) -> Result<()> {
1353 test_support!(self, {
1354 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1355 (user_id, contact_user_id, true)
1356 } else {
1357 (contact_user_id, user_id, false)
1358 };
1359
1360 let query = "
1361 UPDATE contacts
1362 SET should_notify = 'f'
1363 WHERE
1364 user_id_a = $1 AND user_id_b = $2 AND
1365 (
1366 (a_to_b = $3 AND accepted) OR
1367 (a_to_b != $3 AND NOT accepted)
1368 );
1369 ";
1370
1371 let result = sqlx::query(query)
1372 .bind(id_a.0)
1373 .bind(id_b.0)
1374 .bind(a_to_b)
1375 .execute(&self.pool)
1376 .await?;
1377
1378 if result.rows_affected() == 0 {
1379 Err(anyhow!("no such contact request"))?;
1380 }
1381
1382 Ok(())
1383 })
1384 }
1385
1386 async fn respond_to_contact_request(
1387 &self,
1388 responder_id: UserId,
1389 requester_id: UserId,
1390 accept: bool,
1391 ) -> Result<()> {
1392 test_support!(self, {
1393 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1394 (responder_id, requester_id, false)
1395 } else {
1396 (requester_id, responder_id, true)
1397 };
1398 let result = if accept {
1399 let query = "
1400 UPDATE contacts
1401 SET accepted = 't', should_notify = 't'
1402 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1403 ";
1404 sqlx::query(query)
1405 .bind(id_a.0)
1406 .bind(id_b.0)
1407 .bind(a_to_b)
1408 .execute(&self.pool)
1409 .await?
1410 } else {
1411 let query = "
1412 DELETE FROM contacts
1413 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1414 ";
1415 sqlx::query(query)
1416 .bind(id_a.0)
1417 .bind(id_b.0)
1418 .bind(a_to_b)
1419 .execute(&self.pool)
1420 .await?
1421 };
1422 if result.rows_affected() == 1 {
1423 Ok(())
1424 } else {
1425 Err(anyhow!("no such contact request"))?
1426 }
1427 })
1428 }
1429
1430 // access tokens
1431
1432 async fn create_access_token_hash(
1433 &self,
1434 user_id: UserId,
1435 access_token_hash: &str,
1436 max_access_token_count: usize,
1437 ) -> Result<()> {
1438 test_support!(self, {
1439 let insert_query = "
1440 INSERT INTO access_tokens (user_id, hash)
1441 VALUES ($1, $2);
1442 ";
1443 let cleanup_query = "
1444 DELETE FROM access_tokens
1445 WHERE id IN (
1446 SELECT id from access_tokens
1447 WHERE user_id = $1
1448 ORDER BY id DESC
1449 OFFSET $3
1450 )
1451 ";
1452
1453 let mut tx = self.pool.begin().await?;
1454 sqlx::query(insert_query)
1455 .bind(user_id.0)
1456 .bind(access_token_hash)
1457 .execute(&mut tx)
1458 .await?;
1459 sqlx::query(cleanup_query)
1460 .bind(user_id.0)
1461 .bind(access_token_hash)
1462 .bind(max_access_token_count as i32)
1463 .execute(&mut tx)
1464 .await?;
1465 Ok(tx.commit().await?)
1466 })
1467 }
1468
1469 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1470 test_support!(self, {
1471 let query = "
1472 SELECT hash
1473 FROM access_tokens
1474 WHERE user_id = $1
1475 ORDER BY id DESC
1476 ";
1477 Ok(sqlx::query_scalar(query)
1478 .bind(user_id.0)
1479 .fetch_all(&self.pool)
1480 .await?)
1481 })
1482 }
1483
1484 // orgs
1485
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
/// Looks up an org by its unique-looking `slug`; `None` when absent.
async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
    test_support!(self, {
        let query = "
            SELECT *
            FROM orgs
            WHERE slug = $1
        ";
        let org: Option<Org> = sqlx::query_as(query)
            .bind(slug)
            .fetch_optional(&self.pool)
            .await?;
        Ok(org)
    })
}

#[cfg(any(test, feature = "seed-support"))]
/// Inserts an org and returns its generated id.
async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
    test_support!(self, {
        let query = "
            INSERT INTO orgs (name, slug)
            VALUES ($1, $2)
            RETURNING id
        ";
        let id: i32 = sqlx::query_scalar(query)
            .bind(name)
            .bind(slug)
            .fetch_one(&self.pool)
            .await?;
        Ok(OrgId(id))
    })
}

#[cfg(any(test, feature = "seed-support"))]
/// Adds `user_id` to `org_id`; an existing membership is left untouched.
async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
    test_support!(self, {
        let query = "
            INSERT INTO org_memberships (org_id, user_id, admin)
            VALUES ($1, $2, $3)
            ON CONFLICT DO NOTHING
        ";
        sqlx::query(query)
            .bind(org_id.0)
            .bind(user_id.0)
            .bind(is_admin)
            .execute(&self.pool)
            .await?;
        Ok(())
    })
}
1536
1537 // channels
1538
#[cfg(any(test, feature = "seed-support"))]
/// Creates a channel owned by `org_id` (`owner_is_user = false`) and returns its id.
async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
    test_support!(self, {
        let query = "
            INSERT INTO channels (owner_id, owner_is_user, name)
            VALUES ($1, false, $2)
            RETURNING id
        ";
        let id: i32 = sqlx::query_scalar(query)
            .bind(org_id.0)
            .bind(name)
            .fetch_one(&self.pool)
            .await?;
        Ok(ChannelId(id))
    })
}

#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
/// Returns every channel owned by the org `org_id`.
async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
    test_support!(self, {
        let query = "
            SELECT *
            FROM channels
            WHERE
                channels.owner_is_user = false AND
                channels.owner_id = $1
        ";
        let channels: Vec<Channel> = sqlx::query_as(query)
            .bind(org_id.0)
            .fetch_all(&self.pool)
            .await?;
        Ok(channels)
    })
}
1573
1574 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1575 test_support!(self, {
1576 let query = "
1577 SELECT
1578 channels.*
1579 FROM
1580 channel_memberships, channels
1581 WHERE
1582 channel_memberships.user_id = $1 AND
1583 channel_memberships.channel_id = channels.id
1584 ";
1585 Ok(sqlx::query_as(query)
1586 .bind(user_id.0)
1587 .fetch_all(&self.pool)
1588 .await?)
1589 })
1590 }
1591
1592 async fn can_user_access_channel(
1593 &self,
1594 user_id: UserId,
1595 channel_id: ChannelId,
1596 ) -> Result<bool> {
1597 test_support!(self, {
1598 let query = "
1599 SELECT id
1600 FROM channel_memberships
1601 WHERE user_id = $1 AND channel_id = $2
1602 LIMIT 1
1603 ";
1604 Ok(sqlx::query_scalar::<_, i32>(query)
1605 .bind(user_id.0)
1606 .bind(channel_id.0)
1607 .fetch_optional(&self.pool)
1608 .await
1609 .map(|e| e.is_some())?)
1610 })
1611 }
1612
#[cfg(any(test, feature = "seed-support"))]
/// Adds `user_id` to `channel_id`; an existing membership is left untouched.
async fn add_channel_member(
    &self,
    channel_id: ChannelId,
    user_id: UserId,
    is_admin: bool,
) -> Result<()> {
    test_support!(self, {
        let query = "
            INSERT INTO channel_memberships (channel_id, user_id, admin)
            VALUES ($1, $2, $3)
            ON CONFLICT DO NOTHING
        ";
        sqlx::query(query)
            .bind(channel_id.0)
            .bind(user_id.0)
            .bind(is_admin)
            .execute(&self.pool)
            .await?;
        Ok(())
    })
}
1635
1636 // messages
1637
1638 async fn create_channel_message(
1639 &self,
1640 channel_id: ChannelId,
1641 sender_id: UserId,
1642 body: &str,
1643 timestamp: OffsetDateTime,
1644 nonce: u128,
1645 ) -> Result<MessageId> {
1646 test_support!(self, {
1647 let query = "
1648 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1649 VALUES ($1, $2, $3, $4, $5)
1650 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1651 RETURNING id
1652 ";
1653 Ok(sqlx::query_scalar(query)
1654 .bind(channel_id.0)
1655 .bind(sender_id.0)
1656 .bind(body)
1657 .bind(timestamp)
1658 .bind(Uuid::from_u128(nonce))
1659 .fetch_one(&self.pool)
1660 .await
1661 .map(MessageId)?)
1662 })
1663 }
1664
1665 async fn get_channel_messages(
1666 &self,
1667 channel_id: ChannelId,
1668 count: usize,
1669 before_id: Option<MessageId>,
1670 ) -> Result<Vec<ChannelMessage>> {
1671 test_support!(self, {
1672 let query = r#"
1673 SELECT * FROM (
1674 SELECT
1675 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1676 FROM
1677 channel_messages
1678 WHERE
1679 channel_id = $1 AND
1680 id < $2
1681 ORDER BY id DESC
1682 LIMIT $3
1683 ) as recent_messages
1684 ORDER BY id ASC
1685 "#;
1686 Ok(sqlx::query_as(query)
1687 .bind(channel_id.0)
1688 .bind(before_id.unwrap_or(MessageId::MAX))
1689 .bind(count as i64)
1690 .fetch_all(&self.pool)
1691 .await?)
1692 })
1693 }
1694
#[cfg(test)]
/// Test-only: terminates other sessions, closes the pool, and drops the database at `url`.
///
/// Every step is best-effort: failures are logged, not returned.
async fn teardown(&self, url: &str) {
    let teardown_started = std::time::Instant::now();
    eprintln!("tearing down database...");
    test_support!(self, {
        use util::ResultExt;

        // Kick every other connection off this database so it can be dropped.
        let terminate_other_backends = "
            SELECT pg_terminate_backend(pg_stat_activity.pid)
            FROM pg_stat_activity
            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
        ";
        sqlx::query(terminate_other_backends)
            .execute(&self.pool)
            .await
            .log_err();
        self.pool.close().await;
        let drop_result =
            <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url).await;
        drop_result.log_err();
        eprintln!("tore down database: {:?}", teardown_started.elapsed());
    })
}
1715
#[cfg(test)]
/// The Postgres-backed implementation is never the in-memory fake.
fn as_fake(&self) -> Option<&FakeDb> {
    Option::None
}
1720}
1721
// Declares `$name` as a `Copy` newtype over an `i32` database id. It is encoded
// transparently for both sqlx and serde, and can be converted to and from the `u64`
// protocol representation.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            Hash,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Converts a protocol id; `as` truncates values that do not fit in `i32`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Converts to the protocol representation; `as` sign-extends negatives.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1764
// `UserId`: newtype over a user's `i32` id.
id_type!(UserId);
/// A user account record, loadable from a query row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub login; the fake's `get_user_by_github_account` rewrites it when the
    /// GitHub user id matches, so it may change over time.
    pub github_login: String,
    /// Numeric GitHub user id, when known.
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    /// Whether the user is an administrator.
    pub admin: bool,
    /// This user's own invite code, if one exists (presumably — confirm against the
    /// invite-code queries).
    pub invite_code: Option<String>,
    /// NOTE(review): presumably the number of invites remaining; confirm.
    pub invite_count: i32,
    /// Updated by `set_user_connected_once`.
    pub connected_once: bool,
}
1777
// `ProjectId`: newtype over a project's `i32` id.
id_type!(ProjectId);
/// A project registered by a host user.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user who registered the project.
    pub host_user_id: UserId,
    /// Set to `true` when the project is unregistered.
    pub unregistered: bool,
}
1785
/// Activity totals for one user, broken down per project.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// One entry per project the user was active in.
    pub project_activity: Vec<ProjectActivitySummary>,
}
1792
/// Activity totals for a single project.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    pub id: ProjectId,
    /// Accumulated activity time in this project.
    pub duration: Duration,
    /// NOTE(review): presumably the peak number of concurrent collaborators; confirm
    /// against the summary query.
    pub max_collaborators: usize,
}
1799
/// A span of activity in one project. `get_user_activity_timeline` coalesces nearby
/// activity records into one span.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    pub project_id: ProjectId,
    /// Serialized as ISO-8601.
    #[serde(with = "time::serde::iso8601")]
    pub start: OffsetDateTime,
    /// Serialized as ISO-8601.
    #[serde(with = "time::serde::iso8601")]
    pub end: OffsetDateTime,
    /// Extension name -> count recorded for the project during this period.
    pub extensions: HashMap<String, usize>,
}
1809
// `OrgId`: newtype over an org's `i32` id.
id_type!(OrgId);
/// An organization record from the `orgs` table.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    /// Lookup key used by `find_org_by_slug`.
    pub slug: String,
}
1817
// `ChannelId`: newtype over a channel's `i32` id.
id_type!(ChannelId);
/// A chat channel record from the `channels` table.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Id of the owner; a user id when `owner_is_user`, otherwise an org id.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1826
// `MessageId`: newtype over a channel message's `i32` id.
id_type!(MessageId);
/// A message posted to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// `get_channel_messages` selects this `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied value used to deduplicate re-sent messages.
    pub nonce: Uuid,
}
1837
/// One of a user's contacts, seen from that user's point of view; `user_id` is always
/// the *other* user.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both users are contacts.
    Accepted {
        user_id: UserId,
        /// True only for the original requester, until they dismiss the acceptance.
        should_notify: bool,
    },
    /// A pending request the viewing user sent.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request the viewing user received.
    Incoming {
        user_id: UserId,
        /// Whether the viewing user should still be notified about it.
        should_notify: bool,
    },
}
1852
1853impl Contact {
1854 pub fn user_id(&self) -> UserId {
1855 match self {
1856 Contact::Accepted { user_id, .. } => *user_id,
1857 Contact::Outgoing { user_id } => *user_id,
1858 Contact::Incoming { user_id, .. } => *user_id,
1859 }
1860 }
1861}
1862
/// A contact request received from `requester_id`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    /// Whether the recipient should be notified about this request.
    pub should_notify: bool,
}
1868
/// A waitlist signup, deserialized from the incoming request body (presumably JSON —
/// confirm at the call site).
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    /// Platforms the person said they use.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    /// Optional device identifier supplied with the signup.
    pub device_id: Option<String>,
}
1879
/// Aggregate waitlist counts. With `#[sqlx(default)]`, a column missing from the row
/// falls back to the field's default (0).
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1893
/// An invite addressed to `email_address`.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    /// NOTE(review): presumably produced by `random_email_confirmation_code`; confirm.
    pub email_confirmation_code: String,
}
1899
/// Parameters for creating a user (see `Db::create_user`).
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
1906
/// Outcome of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// The user whose invite led to this account, if any.
    pub inviting_user_id: Option<UserId>,
    /// Device id recorded with the originating signup, if any.
    pub signup_device_id: Option<String>,
}
1914
1915fn random_invite_code() -> String {
1916 nanoid::nanoid!(16)
1917}
1918
1919fn random_email_confirmation_code() -> String {
1920 nanoid::nanoid!(64)
1921}
1922
1923#[cfg(test)]
1924pub use test::*;
1925
1926#[cfg(test)]
1927mod test {
1928 use super::*;
1929 use anyhow::anyhow;
1930 use collections::BTreeMap;
1931 use gpui::executor::Background;
1932 use lazy_static::lazy_static;
1933 use parking_lot::Mutex;
1934 use rand::prelude::*;
1935 use sqlx::{migrate::MigrateDatabase, Postgres};
1936 use std::sync::Arc;
1937 use util::post_inc;
1938
/// In-memory `Db` implementation for tests. Each table is a `Mutex`-guarded map, and
/// each id sequence is a `Mutex<i32>` counter advanced with `post_inc`.
pub struct FakeDb {
    /// Used to inject random delays before each operation.
    background: Arc<Background>,
    pub users: Mutex<BTreeMap<UserId, User>>,
    pub projects: Mutex<BTreeMap<ProjectId, Project>>,
    /// Keyed by (project, worktree id, extension) -> count.
    pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
    pub orgs: Mutex<BTreeMap<OrgId, Org>>,
    /// (org, user) -> is_admin.
    pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
    pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
    /// (channel, user) -> is_admin.
    pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
    pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
    pub contacts: Mutex<Vec<FakeContact>>,
    next_channel_message_id: Mutex<i32>,
    next_user_id: Mutex<i32>,
    next_org_id: Mutex<i32>,
    next_channel_id: Mutex<i32>,
    next_project_id: Mutex<i32>,
}
1956
/// A directed contact record in `FakeDb`: `requester_id` asked `responder_id`.
#[derive(Debug)]
pub struct FakeContact {
    pub requester_id: UserId,
    pub responder_id: UserId,
    pub accepted: bool,
    pub should_notify: bool,
}
1964
1965 impl FakeDb {
1966 pub fn new(background: Arc<Background>) -> Self {
1967 Self {
1968 background,
1969 users: Default::default(),
1970 next_user_id: Mutex::new(0),
1971 projects: Default::default(),
1972 worktree_extensions: Default::default(),
1973 next_project_id: Mutex::new(1),
1974 orgs: Default::default(),
1975 next_org_id: Mutex::new(1),
1976 org_memberships: Default::default(),
1977 channels: Default::default(),
1978 next_channel_id: Mutex::new(1),
1979 channel_memberships: Default::default(),
1980 channel_messages: Default::default(),
1981 next_channel_message_id: Mutex::new(1),
1982 contacts: Default::default(),
1983 }
1984 }
1985 }
1986
1987 #[async_trait]
1988 impl Db for FakeDb {
1989 async fn create_user(
1990 &self,
1991 email_address: &str,
1992 admin: bool,
1993 params: NewUserParams,
1994 ) -> Result<NewUserResult> {
1995 self.background.simulate_random_delay().await;
1996
1997 let mut users = self.users.lock();
1998 let user_id = if let Some(user) = users
1999 .values()
2000 .find(|user| user.github_login == params.github_login)
2001 {
2002 user.id
2003 } else {
2004 let id = post_inc(&mut *self.next_user_id.lock());
2005 let user_id = UserId(id);
2006 users.insert(
2007 user_id,
2008 User {
2009 id: user_id,
2010 github_login: params.github_login,
2011 github_user_id: Some(params.github_user_id),
2012 email_address: Some(email_address.to_string()),
2013 admin,
2014 invite_code: None,
2015 invite_count: 0,
2016 connected_once: false,
2017 },
2018 );
2019 user_id
2020 };
2021 Ok(NewUserResult {
2022 user_id,
2023 metrics_id: "the-metrics-id".to_string(),
2024 inviting_user_id: None,
2025 signup_device_id: None,
2026 })
2027 }
2028
// Not supported by the fake: panics via `unimplemented!()` if called.
async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
    unimplemented!()
}

// Not supported by the fake: panics via `unimplemented!()` if called.
async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
    unimplemented!()
}
2036
2037 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2038 self.background.simulate_random_delay().await;
2039 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2040 }
2041
2042 async fn get_user_metrics_id(&self, _id: UserId) -> Result<String> {
2043 Ok("the-metrics-id".to_string())
2044 }
2045
2046 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2047 self.background.simulate_random_delay().await;
2048 let users = self.users.lock();
2049 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2050 }
2051
// Not supported by the fake: panics via `unimplemented!()` if called.
async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
    unimplemented!()
}
2055
2056 async fn get_user_by_github_account(
2057 &self,
2058 github_login: &str,
2059 github_user_id: Option<i32>,
2060 ) -> Result<Option<User>> {
2061 self.background.simulate_random_delay().await;
2062 if let Some(github_user_id) = github_user_id {
2063 for user in self.users.lock().values_mut() {
2064 if user.github_user_id == Some(github_user_id) {
2065 user.github_login = github_login.into();
2066 return Ok(Some(user.clone()));
2067 }
2068 if user.github_login == github_login {
2069 user.github_user_id = Some(github_user_id);
2070 return Ok(Some(user.clone()));
2071 }
2072 }
2073 Ok(None)
2074 } else {
2075 Ok(self
2076 .users
2077 .lock()
2078 .values()
2079 .find(|user| user.github_login == github_login)
2080 .cloned())
2081 }
2082 }
2083
// Not supported by the fake: panics via `unimplemented!()` if called.
async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
    unimplemented!()
}
2087
2088 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2089 self.background.simulate_random_delay().await;
2090 let mut users = self.users.lock();
2091 let mut user = users
2092 .get_mut(&id)
2093 .ok_or_else(|| anyhow!("user not found"))?;
2094 user.connected_once = connected_once;
2095 Ok(())
2096 }
2097
// The methods in this block are not supported by the fake: each panics via
// `unimplemented!()` if called.
async fn destroy_user(&self, _id: UserId) -> Result<()> {
    unimplemented!()
}

// signups

async fn create_signup(&self, _signup: Signup) -> Result<()> {
    unimplemented!()
}

async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
    unimplemented!()
}

async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<Invite>> {
    unimplemented!()
}

async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
    unimplemented!()
}

async fn create_user_from_invite(
    &self,
    _invite: &Invite,
    _user: NewUserParams,
) -> Result<Option<NewUserResult>> {
    unimplemented!()
}

// invite codes

async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
    unimplemented!()
}
2133
2134 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
2135 self.background.simulate_random_delay().await;
2136 Ok(None)
2137 }
2138
// Not supported by the fake: panics via `unimplemented!()` if called.
async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
    unimplemented!()
}

// Not supported by the fake: panics via `unimplemented!()` if called.
async fn create_invite_from_code(
    &self,
    _code: &str,
    _email_address: &str,
    _device_id: Option<&str>,
) -> Result<Invite> {
    unimplemented!()
}
2151
2152 // projects
2153
2154 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2155 self.background.simulate_random_delay().await;
2156 if !self.users.lock().contains_key(&host_user_id) {
2157 Err(anyhow!("no such user"))?;
2158 }
2159
2160 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2161 self.projects.lock().insert(
2162 project_id,
2163 Project {
2164 id: project_id,
2165 host_user_id,
2166 unregistered: false,
2167 },
2168 );
2169 Ok(project_id)
2170 }
2171
2172 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2173 self.background.simulate_random_delay().await;
2174 self.projects
2175 .lock()
2176 .get_mut(&project_id)
2177 .ok_or_else(|| anyhow!("no such project"))?
2178 .unregistered = true;
2179 Ok(())
2180 }
2181
2182 async fn update_worktree_extensions(
2183 &self,
2184 project_id: ProjectId,
2185 worktree_id: u64,
2186 extensions: HashMap<String, u32>,
2187 ) -> Result<()> {
2188 self.background.simulate_random_delay().await;
2189 if !self.projects.lock().contains_key(&project_id) {
2190 Err(anyhow!("no such project"))?;
2191 }
2192
2193 for (extension, count) in extensions {
2194 self.worktree_extensions
2195 .lock()
2196 .insert((project_id, worktree_id, extension), count);
2197 }
2198
2199 Ok(())
2200 }
2201
// The activity/extension queries below are not supported by the fake: each panics
// via `unimplemented!()` if called.
async fn get_project_extensions(
    &self,
    _project_id: ProjectId,
) -> Result<HashMap<u64, HashMap<String, usize>>> {
    unimplemented!()
}

async fn record_user_activity(
    &self,
    _time_period: Range<OffsetDateTime>,
    _active_projects: &[(UserId, ProjectId)],
) -> Result<()> {
    unimplemented!()
}

async fn get_active_user_count(
    &self,
    _time_period: Range<OffsetDateTime>,
    _min_duration: Duration,
    _only_collaborative: bool,
) -> Result<usize> {
    unimplemented!()
}

async fn get_top_users_activity_summary(
    &self,
    _time_period: Range<OffsetDateTime>,
    _limit: usize,
) -> Result<Vec<UserActivitySummary>> {
    unimplemented!()
}

async fn get_user_activity_timeline(
    &self,
    _time_period: Range<OffsetDateTime>,
    _user_id: UserId,
) -> Result<Vec<UserActivityPeriod>> {
    unimplemented!()
}
2241
2242 // contacts
2243
2244 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2245 self.background.simulate_random_delay().await;
2246 let mut contacts = Vec::new();
2247
2248 for contact in self.contacts.lock().iter() {
2249 if contact.requester_id == id {
2250 if contact.accepted {
2251 contacts.push(Contact::Accepted {
2252 user_id: contact.responder_id,
2253 should_notify: contact.should_notify,
2254 });
2255 } else {
2256 contacts.push(Contact::Outgoing {
2257 user_id: contact.responder_id,
2258 });
2259 }
2260 } else if contact.responder_id == id {
2261 if contact.accepted {
2262 contacts.push(Contact::Accepted {
2263 user_id: contact.requester_id,
2264 should_notify: false,
2265 });
2266 } else {
2267 contacts.push(Contact::Incoming {
2268 user_id: contact.requester_id,
2269 should_notify: contact.should_notify,
2270 });
2271 }
2272 }
2273 }
2274
2275 contacts.sort_unstable_by_key(|contact| contact.user_id());
2276 Ok(contacts)
2277 }
2278
2279 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2280 self.background.simulate_random_delay().await;
2281 Ok(self.contacts.lock().iter().any(|contact| {
2282 contact.accepted
2283 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2284 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2285 }))
2286 }
2287
2288 async fn send_contact_request(
2289 &self,
2290 requester_id: UserId,
2291 responder_id: UserId,
2292 ) -> Result<()> {
2293 self.background.simulate_random_delay().await;
2294 let mut contacts = self.contacts.lock();
2295 for contact in contacts.iter_mut() {
2296 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2297 if contact.accepted {
2298 Err(anyhow!("contact already exists"))?;
2299 } else {
2300 Err(anyhow!("contact already requested"))?;
2301 }
2302 }
2303 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2304 if contact.accepted {
2305 Err(anyhow!("contact already exists"))?;
2306 } else {
2307 contact.accepted = true;
2308 contact.should_notify = false;
2309 return Ok(());
2310 }
2311 }
2312 }
2313 contacts.push(FakeContact {
2314 requester_id,
2315 responder_id,
2316 accepted: false,
2317 should_notify: true,
2318 });
2319 Ok(())
2320 }
2321
2322 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2323 self.background.simulate_random_delay().await;
2324 self.contacts.lock().retain(|contact| {
2325 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2326 });
2327 Ok(())
2328 }
2329
2330 async fn dismiss_contact_notification(
2331 &self,
2332 user_id: UserId,
2333 contact_user_id: UserId,
2334 ) -> Result<()> {
2335 self.background.simulate_random_delay().await;
2336 let mut contacts = self.contacts.lock();
2337 for contact in contacts.iter_mut() {
2338 if contact.requester_id == contact_user_id
2339 && contact.responder_id == user_id
2340 && !contact.accepted
2341 {
2342 contact.should_notify = false;
2343 return Ok(());
2344 }
2345 if contact.requester_id == user_id
2346 && contact.responder_id == contact_user_id
2347 && contact.accepted
2348 {
2349 contact.should_notify = false;
2350 return Ok(());
2351 }
2352 }
2353 Err(anyhow!("no such notification"))?
2354 }
2355
2356 async fn respond_to_contact_request(
2357 &self,
2358 responder_id: UserId,
2359 requester_id: UserId,
2360 accept: bool,
2361 ) -> Result<()> {
2362 self.background.simulate_random_delay().await;
2363 let mut contacts = self.contacts.lock();
2364 for (ix, contact) in contacts.iter_mut().enumerate() {
2365 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2366 if contact.accepted {
2367 Err(anyhow!("contact already confirmed"))?;
2368 }
2369 if accept {
2370 contact.accepted = true;
2371 contact.should_notify = true;
2372 } else {
2373 contacts.remove(ix);
2374 }
2375 return Ok(());
2376 }
2377 }
2378 Err(anyhow!("no such contact request"))?
2379 }
2380
// The access-token and org-lookup methods below are not supported by the fake: each
// panics via `unimplemented!()` if called.
async fn create_access_token_hash(
    &self,
    _user_id: UserId,
    _access_token_hash: &str,
    _max_access_token_count: usize,
) -> Result<()> {
    unimplemented!()
}

async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
    unimplemented!()
}

async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
    unimplemented!()
}
2397
2398 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2399 self.background.simulate_random_delay().await;
2400 let mut orgs = self.orgs.lock();
2401 if orgs.values().any(|org| org.slug == slug) {
2402 Err(anyhow!("org already exists"))?
2403 } else {
2404 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2405 orgs.insert(
2406 org_id,
2407 Org {
2408 id: org_id,
2409 name: name.to_string(),
2410 slug: slug.to_string(),
2411 },
2412 );
2413 Ok(org_id)
2414 }
2415 }
2416
2417 async fn add_org_member(
2418 &self,
2419 org_id: OrgId,
2420 user_id: UserId,
2421 is_admin: bool,
2422 ) -> Result<()> {
2423 self.background.simulate_random_delay().await;
2424 if !self.orgs.lock().contains_key(&org_id) {
2425 Err(anyhow!("org does not exist"))?;
2426 }
2427 if !self.users.lock().contains_key(&user_id) {
2428 Err(anyhow!("user does not exist"))?;
2429 }
2430
2431 self.org_memberships
2432 .lock()
2433 .entry((org_id, user_id))
2434 .or_insert(is_admin);
2435 Ok(())
2436 }
2437
2438 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2439 self.background.simulate_random_delay().await;
2440 if !self.orgs.lock().contains_key(&org_id) {
2441 Err(anyhow!("org does not exist"))?;
2442 }
2443
2444 let mut channels = self.channels.lock();
2445 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2446 channels.insert(
2447 channel_id,
2448 Channel {
2449 id: channel_id,
2450 name: name.to_string(),
2451 owner_id: org_id.0,
2452 owner_is_user: false,
2453 },
2454 );
2455 Ok(channel_id)
2456 }
2457
2458 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2459 self.background.simulate_random_delay().await;
2460 Ok(self
2461 .channels
2462 .lock()
2463 .values()
2464 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2465 .cloned()
2466 .collect())
2467 }
2468
2469 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2470 self.background.simulate_random_delay().await;
2471 let channels = self.channels.lock();
2472 let memberships = self.channel_memberships.lock();
2473 Ok(channels
2474 .values()
2475 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2476 .cloned()
2477 .collect())
2478 }
2479
2480 async fn can_user_access_channel(
2481 &self,
2482 user_id: UserId,
2483 channel_id: ChannelId,
2484 ) -> Result<bool> {
2485 self.background.simulate_random_delay().await;
2486 Ok(self
2487 .channel_memberships
2488 .lock()
2489 .contains_key(&(channel_id, user_id)))
2490 }
2491
2492 async fn add_channel_member(
2493 &self,
2494 channel_id: ChannelId,
2495 user_id: UserId,
2496 is_admin: bool,
2497 ) -> Result<()> {
2498 self.background.simulate_random_delay().await;
2499 if !self.channels.lock().contains_key(&channel_id) {
2500 Err(anyhow!("channel does not exist"))?;
2501 }
2502 if !self.users.lock().contains_key(&user_id) {
2503 Err(anyhow!("user does not exist"))?;
2504 }
2505
2506 self.channel_memberships
2507 .lock()
2508 .entry((channel_id, user_id))
2509 .or_insert(is_admin);
2510 Ok(())
2511 }
2512
2513 async fn create_channel_message(
2514 &self,
2515 channel_id: ChannelId,
2516 sender_id: UserId,
2517 body: &str,
2518 timestamp: OffsetDateTime,
2519 nonce: u128,
2520 ) -> Result<MessageId> {
2521 self.background.simulate_random_delay().await;
2522 if !self.channels.lock().contains_key(&channel_id) {
2523 Err(anyhow!("channel does not exist"))?;
2524 }
2525 if !self.users.lock().contains_key(&sender_id) {
2526 Err(anyhow!("user does not exist"))?;
2527 }
2528
2529 let mut messages = self.channel_messages.lock();
2530 if let Some(message) = messages
2531 .values()
2532 .find(|message| message.nonce.as_u128() == nonce)
2533 {
2534 Ok(message.id)
2535 } else {
2536 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2537 messages.insert(
2538 message_id,
2539 ChannelMessage {
2540 id: message_id,
2541 channel_id,
2542 sender_id,
2543 body: body.to_string(),
2544 sent_at: timestamp,
2545 nonce: Uuid::from_u128(nonce),
2546 },
2547 );
2548 Ok(message_id)
2549 }
2550 }
2551
2552 async fn get_channel_messages(
2553 &self,
2554 channel_id: ChannelId,
2555 count: usize,
2556 before_id: Option<MessageId>,
2557 ) -> Result<Vec<ChannelMessage>> {
2558 self.background.simulate_random_delay().await;
2559 let mut messages = self
2560 .channel_messages
2561 .lock()
2562 .values()
2563 .rev()
2564 .filter(|message| {
2565 message.channel_id == channel_id
2566 && message.id < before_id.unwrap_or(MessageId::MAX)
2567 })
2568 .take(count)
2569 .cloned()
2570 .collect::<Vec<_>>();
2571 messages.sort_unstable_by_key(|message| message.id);
2572 Ok(messages)
2573 }
2574
2575 async fn teardown(&self, _: &str) {}
2576
        /// Downcasts to the concrete fake; always `Some` for this implementation.
        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
2581 }
2582
    /// A database handle for tests.
    ///
    /// It is backed either by a freshly created Postgres database (`TestDb::postgres`)
    /// or by a `FakeDb` (`TestDb::fake`), and is torn down when dropped.
    pub struct TestDb {
        /// The database under test. Both constructors set this to `Some`; `Drop` takes
        /// it so it can run `teardown` on it.
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres database; left empty for the fake database.
        pub url: String,
    }
2587
2588 impl TestDb {
2589 #[allow(clippy::await_holding_lock)]
2590 pub async fn postgres() -> Self {
2591 lazy_static! {
2592 static ref LOCK: Mutex<()> = Mutex::new(());
2593 }
2594
2595 eprintln!("creating database...");
2596 let start = std::time::Instant::now();
2597 let _guard = LOCK.lock();
2598 let mut rng = StdRng::from_entropy();
2599 let name = format!("zed-test-{}", rng.gen::<u128>());
2600 let url = format!("postgres://postgres@localhost:5433/{}", name);
2601 Postgres::create_database(&url)
2602 .await
2603 .expect("failed to create test db");
2604 let db = PostgresDb::new(&url, 5).await.unwrap();
2605 db.migrate(Path::new(DEFAULT_MIGRATIONS_PATH.unwrap()), false)
2606 .await
2607 .unwrap();
2608
2609 eprintln!("created database: {:?}", start.elapsed());
2610 Self {
2611 db: Some(Arc::new(db)),
2612 url,
2613 }
2614 }
2615
2616 pub fn fake(background: Arc<Background>) -> Self {
2617 Self {
2618 db: Some(Arc::new(FakeDb::new(background))),
2619 url: Default::default(),
2620 }
2621 }
2622
2623 pub fn db(&self) -> &Arc<dyn Db> {
2624 self.db.as_ref().unwrap()
2625 }
2626 }
2627
2628 impl Drop for TestDb {
2629 fn drop(&mut self) {
2630 if let Some(db) = self.db.take() {
2631 futures::executor::block_on(db.teardown(&self.url));
2632 }
2633 }
2634 }
2635}