1use crate::{Error, Result};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use axum::http::StatusCode;
5use collections::HashMap;
6use futures::StreamExt;
7use serde::{Deserialize, Serialize};
8use sqlx::{
9 migrate::{Migrate as _, Migration, MigrationSource},
10 types::Uuid,
11 FromRow, QueryBuilder,
12};
13use std::{cmp, ops::Range, path::Path, time::Duration};
14use time::{OffsetDateTime, PrimitiveDateTime};
15
16#[async_trait]
17pub trait Db: Send + Sync {
18 async fn create_user(
19 &self,
20 email_address: &str,
21 admin: bool,
22 params: NewUserParams,
23 ) -> Result<NewUserResult>;
24 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
25 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
26 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
27 async fn get_user_metrics_id(&self, id: UserId) -> Result<String>;
28 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
29 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
30 async fn get_user_by_github_account(
31 &self,
32 github_login: &str,
33 github_user_id: Option<i32>,
34 ) -> Result<Option<User>>;
35 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
36 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
37 async fn destroy_user(&self, id: UserId) -> Result<()>;
38
39 async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
40 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
41 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
42 async fn create_invite_from_code(
43 &self,
44 code: &str,
45 email_address: &str,
46 device_id: Option<&str>,
47 ) -> Result<Invite>;
48
49 async fn create_signup(&self, signup: Signup) -> Result<()>;
50 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
51 async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
52 async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
53 async fn create_user_from_invite(
54 &self,
55 invite: &Invite,
56 user: NewUserParams,
57 ) -> Result<Option<NewUserResult>>;
58
59 /// Registers a new project for the given user.
60 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
61
62 /// Unregisters a project for the given project id.
63 async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
64
65 /// Update file counts by extension for the given project and worktree.
66 async fn update_worktree_extensions(
67 &self,
68 project_id: ProjectId,
69 worktree_id: u64,
70 extensions: HashMap<String, u32>,
71 ) -> Result<()>;
72
73 /// Get the file counts on the given project keyed by their worktree and extension.
74 async fn get_project_extensions(
75 &self,
76 project_id: ProjectId,
77 ) -> Result<HashMap<u64, HashMap<String, usize>>>;
78
79 /// Record which users have been active in which projects during
80 /// a given period of time.
81 async fn record_user_activity(
82 &self,
83 time_period: Range<OffsetDateTime>,
84 active_projects: &[(UserId, ProjectId)],
85 ) -> Result<()>;
86
87 /// Get the number of users who have been active in the given
88 /// time period for at least the given time duration.
89 async fn get_active_user_count(
90 &self,
91 time_period: Range<OffsetDateTime>,
92 min_duration: Duration,
93 only_collaborative: bool,
94 ) -> Result<usize>;
95
96 /// Get the users that have been most active during the given time period,
97 /// along with the amount of time they have been active in each project.
98 async fn get_top_users_activity_summary(
99 &self,
100 time_period: Range<OffsetDateTime>,
101 max_user_count: usize,
102 ) -> Result<Vec<UserActivitySummary>>;
103
104 /// Get the project activity for the given user and time period.
105 async fn get_user_activity_timeline(
106 &self,
107 time_period: Range<OffsetDateTime>,
108 user_id: UserId,
109 ) -> Result<Vec<UserActivityPeriod>>;
110
111 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
112 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
113 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
114 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
115 async fn dismiss_contact_notification(
116 &self,
117 responder_id: UserId,
118 requester_id: UserId,
119 ) -> Result<()>;
120 async fn respond_to_contact_request(
121 &self,
122 responder_id: UserId,
123 requester_id: UserId,
124 accept: bool,
125 ) -> Result<()>;
126
127 async fn create_access_token_hash(
128 &self,
129 user_id: UserId,
130 access_token_hash: &str,
131 max_access_token_count: usize,
132 ) -> Result<()>;
133 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
134
135 #[cfg(any(test, feature = "seed-support"))]
136 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
137 #[cfg(any(test, feature = "seed-support"))]
138 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
139 #[cfg(any(test, feature = "seed-support"))]
140 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
141 #[cfg(any(test, feature = "seed-support"))]
142 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
143 #[cfg(any(test, feature = "seed-support"))]
144
145 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
146 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
147 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
148 -> Result<bool>;
149
150 #[cfg(any(test, feature = "seed-support"))]
151 async fn add_channel_member(
152 &self,
153 channel_id: ChannelId,
154 user_id: UserId,
155 is_admin: bool,
156 ) -> Result<()>;
157 async fn create_channel_message(
158 &self,
159 channel_id: ChannelId,
160 sender_id: UserId,
161 body: &str,
162 timestamp: OffsetDateTime,
163 nonce: u128,
164 ) -> Result<MessageId>;
165 async fn get_channel_messages(
166 &self,
167 channel_id: ChannelId,
168 count: usize,
169 before_id: Option<MessageId>,
170 ) -> Result<Vec<ChannelMessage>>;
171
172 #[cfg(test)]
173 async fn teardown(&self, url: &str);
174
175 #[cfg(test)]
176 fn as_fake(&self) -> Option<&FakeDb>;
177}
178
179#[cfg(any(test, debug_assertions))]
180pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> =
181 Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
182
183pub const TEST_MIGRATIONS_PATH: Option<&'static str> =
184 Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"));
185
186#[cfg(not(any(test, debug_assertions)))]
187pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = None;
188
189pub struct RealDb {
190 pool: sqlx::SqlitePool,
191}
192
/// Wraps a database method body in an `async` block and drives it.
///
/// In test builds (`cfg!(test)`) the block runs on a freshly built single-threaded Tokio
/// runtime via `block_on`. Otherwise it is simply `.await`ed.
///
/// NOTE(review): `$self` is matched but never used by the expansion. Tokio's
/// `Runtime::block_on` panics when called from inside another runtime's async context, so the
/// test branch presumably assumes callers are not themselves running on Tokio. Confirm.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let future = async { $($token)* };

        if cfg!(test) {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();
            runtime.block_on(future)
        } else {
            future.await
        }
    }};
}
206
207impl RealDb {
208 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
209 eprintln!("{url}");
210 let pool = sqlx::sqlite::SqlitePoolOptions::new()
211 .max_connections(1)
212 .connect(url)
213 .await?;
214 Ok(Self { pool })
215 }
216
217 pub async fn migrate(
218 &self,
219 migrations_path: &Path,
220 ignore_checksum_mismatch: bool,
221 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
222 let migrations = MigrationSource::resolve(migrations_path)
223 .await
224 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
225
226 let mut conn = self.pool.acquire().await?;
227
228 conn.ensure_migrations_table().await?;
229 let applied_migrations: HashMap<_, _> = conn
230 .list_applied_migrations()
231 .await?
232 .into_iter()
233 .map(|m| (m.version, m))
234 .collect();
235
236 let mut new_migrations = Vec::new();
237 for migration in migrations {
238 match applied_migrations.get(&migration.version) {
239 Some(applied_migration) => {
240 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
241 {
242 Err(anyhow!(
243 "checksum mismatch for applied migration {}",
244 migration.description
245 ))?;
246 }
247 }
248 None => {
249 let elapsed = conn.apply(&migration).await?;
250 new_migrations.push((migration, elapsed));
251 }
252 }
253 }
254
255 Ok(new_migrations)
256 }
257
258 pub fn fuzzy_like_string(string: &str) -> String {
259 let mut result = String::with_capacity(string.len() * 2 + 1);
260 for c in string.chars() {
261 if c.is_alphanumeric() {
262 result.push('%');
263 result.push(c);
264 }
265 }
266 result.push('%');
267 result
268 }
269}
270
271#[async_trait]
272impl Db for RealDb {
273 // users
274
275 async fn create_user(
276 &self,
277 email_address: &str,
278 admin: bool,
279 params: NewUserParams,
280 ) -> Result<NewUserResult> {
281 test_support!(self, {
282 let query = "
283 INSERT INTO users (email_address, github_login, github_user_id, admin)
284 VALUES ($1, $2, $3, $4)
285 -- ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
286 RETURNING id, 'the-metrics-id'
287 ";
288
289 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
290 .bind(email_address)
291 .bind(params.github_login)
292 .bind(params.github_user_id)
293 .bind(admin)
294 .fetch_one(&self.pool)
295 .await?;
296 Ok(NewUserResult {
297 user_id,
298 metrics_id,
299 signup_device_id: None,
300 inviting_user_id: None,
301 })
302 })
303 }
304
305 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
306 test_support!(self, {
307 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
308 Ok(sqlx::query_as(query)
309 .bind(limit as i32)
310 .bind((page * limit) as i32)
311 .fetch_all(&self.pool)
312 .await?)
313 })
314 }
315
316 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
317 test_support!(self, {
318 let like_string = Self::fuzzy_like_string(name_query);
319 let query = "
320 SELECT users.*
321 FROM users
322 WHERE github_login ILIKE $1
323 ORDER BY github_login <-> $2
324 LIMIT $3
325 ";
326 Ok(sqlx::query_as(query)
327 .bind(like_string)
328 .bind(name_query)
329 .bind(limit as i32)
330 .fetch_all(&self.pool)
331 .await?)
332 })
333 }
334
335 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
336 test_support!(self, {
337 let query = "
338 SELECT users.*
339 FROM users
340 WHERE id = $1
341 LIMIT 1
342 ";
343 Ok(sqlx::query_as(query)
344 .bind(&id)
345 .fetch_optional(&self.pool)
346 .await?)
347 })
348 }
349
350 async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
351 test_support!(self, {
352 let query = "
353 SELECT metrics_id::text
354 FROM users
355 WHERE id = $1
356 ";
357 Ok(sqlx::query_scalar(query)
358 .bind(id)
359 .fetch_one(&self.pool)
360 .await?)
361 })
362 }
363
364 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
365 test_support!(self, {
366 let query = "
367 SELECT users.*
368 FROM users
369 WHERE users.id IN (SELECT value from json_each($1))
370 ";
371 Ok(sqlx::query_as(query)
372 .bind(&serde_json::json!(ids))
373 .fetch_all(&self.pool)
374 .await?)
375 })
376 }
377
378 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
379 test_support!(self, {
380 let query = format!(
381 "
382 SELECT users.*
383 FROM users
384 WHERE invite_count = 0
385 AND inviter_id IS{} NULL
386 ",
387 if invited_by_another_user { " NOT" } else { "" }
388 );
389
390 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
391 })
392 }
393
394 async fn get_user_by_github_account(
395 &self,
396 github_login: &str,
397 github_user_id: Option<i32>,
398 ) -> Result<Option<User>> {
399 test_support!(self, {
400 if let Some(github_user_id) = github_user_id {
401 let mut user = sqlx::query_as::<_, User>(
402 "
403 UPDATE users
404 SET github_login = $1
405 WHERE github_user_id = $2
406 RETURNING *
407 ",
408 )
409 .bind(github_login)
410 .bind(github_user_id)
411 .fetch_optional(&self.pool)
412 .await?;
413
414 if user.is_none() {
415 user = sqlx::query_as::<_, User>(
416 "
417 UPDATE users
418 SET github_user_id = $1
419 WHERE github_login = $2
420 RETURNING *
421 ",
422 )
423 .bind(github_user_id)
424 .bind(github_login)
425 .fetch_optional(&self.pool)
426 .await?;
427 }
428
429 Ok(user)
430 } else {
431 let user = sqlx::query_as(
432 "
433 SELECT * FROM users
434 WHERE github_login = $1
435 LIMIT 1
436 ",
437 )
438 .bind(github_login)
439 .fetch_optional(&self.pool)
440 .await?;
441 Ok(user)
442 }
443 })
444 }
445
446 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
447 test_support!(self, {
448 let query = "UPDATE users SET admin = $1 WHERE id = $2";
449 Ok(sqlx::query(query)
450 .bind(is_admin)
451 .bind(id.0)
452 .execute(&self.pool)
453 .await
454 .map(drop)?)
455 })
456 }
457
458 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
459 test_support!(self, {
460 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
461 Ok(sqlx::query(query)
462 .bind(connected_once)
463 .bind(id.0)
464 .execute(&self.pool)
465 .await
466 .map(drop)?)
467 })
468 }
469
470 async fn destroy_user(&self, id: UserId) -> Result<()> {
471 test_support!(self, {
472 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
473 sqlx::query(query)
474 .bind(id.0)
475 .execute(&self.pool)
476 .await
477 .map(drop)?;
478 let query = "DELETE FROM users WHERE id = $1;";
479 Ok(sqlx::query(query)
480 .bind(id.0)
481 .execute(&self.pool)
482 .await
483 .map(drop)?)
484 })
485 }
486
487 // signups
488
489 async fn create_signup(&self, signup: Signup) -> Result<()> {
490 test_support!(self, {
491 sqlx::query(
492 "
493 INSERT INTO signups
494 (
495 email_address,
496 email_confirmation_code,
497 email_confirmation_sent,
498 platform_linux,
499 platform_mac,
500 platform_windows,
501 platform_unknown,
502 editor_features,
503 programming_languages,
504 device_id
505 )
506 VALUES
507 ($1, $2, FALSE, $3, $4, $5, FALSE, $6)
508 RETURNING id
509 ",
510 )
511 .bind(&signup.email_address)
512 .bind(&random_email_confirmation_code())
513 .bind(&signup.platform_linux)
514 .bind(&signup.platform_mac)
515 .bind(&signup.platform_windows)
516 // .bind(&signup.editor_features)
517 // .bind(&signup.programming_languages)
518 .bind(&signup.device_id)
519 .execute(&self.pool)
520 .await?;
521 Ok(())
522 })
523 }
524
525 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
526 test_support!(self, {
527 Ok(sqlx::query_as(
528 "
529 SELECT
530 COUNT(*) as count,
531 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
532 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
533 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
534 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
535 FROM (
536 SELECT *
537 FROM signups
538 WHERE
539 NOT email_confirmation_sent
540 ) AS unsent
541 ",
542 )
543 .fetch_one(&self.pool)
544 .await?)
545 })
546 }
547
548 async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
549 test_support!(self, {
550 Ok(sqlx::query_as(
551 "
552 SELECT
553 email_address, email_confirmation_code
554 FROM signups
555 WHERE
556 NOT email_confirmation_sent AND
557 (platform_mac OR platform_unknown)
558 LIMIT $1
559 ",
560 )
561 .bind(count as i32)
562 .fetch_all(&self.pool)
563 .await?)
564 })
565 }
566
567 async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
568 test_support!(self, {
569 // sqlx::query(
570 // "
571 // UPDATE signups
572 // SET email_confirmation_sent = TRUE
573 // WHERE email_address = ANY ($1)
574 // ",
575 // )
576 // .bind(
577 // &invites
578 // .iter()
579 // .map(|s| s.email_address.as_str())
580 // .collect::<Vec<_>>(),
581 // )
582 // .execute(&self.pool)
583 // .await?;
584 Ok(())
585 })
586 }
587
588 async fn create_user_from_invite(
589 &self,
590 invite: &Invite,
591 user: NewUserParams,
592 ) -> Result<Option<NewUserResult>> {
593 test_support!(self, {
594 let mut tx = self.pool.begin().await?;
595
596 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
597 i32,
598 Option<UserId>,
599 Option<UserId>,
600 Option<String>,
601 ) = sqlx::query_as(
602 "
603 SELECT id, user_id, inviting_user_id, device_id
604 FROM signups
605 WHERE
606 email_address = $1 AND
607 email_confirmation_code = $2
608 ",
609 )
610 .bind(&invite.email_address)
611 .bind(&invite.email_confirmation_code)
612 .fetch_optional(&mut tx)
613 .await?
614 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
615
616 if existing_user_id.is_some() {
617 return Ok(None);
618 }
619
620 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
621 "
622 INSERT INTO users
623 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
624 VALUES
625 ($1, $2, $3, FALSE, $4, $5)
626 ON CONFLICT (github_login) DO UPDATE SET
627 email_address = excluded.email_address,
628 github_user_id = excluded.github_user_id,
629 admin = excluded.admin
630 RETURNING id, metrics_id::text
631 ",
632 )
633 .bind(&invite.email_address)
634 .bind(&user.github_login)
635 .bind(&user.github_user_id)
636 .bind(&user.invite_count)
637 .bind(random_invite_code())
638 .fetch_one(&mut tx)
639 .await?;
640
641 sqlx::query(
642 "
643 UPDATE signups
644 SET user_id = $1
645 WHERE id = $2
646 ",
647 )
648 .bind(&user_id)
649 .bind(&signup_id)
650 .execute(&mut tx)
651 .await?;
652
653 if let Some(inviting_user_id) = inviting_user_id {
654 let id: Option<UserId> = sqlx::query_scalar(
655 "
656 UPDATE users
657 SET invite_count = invite_count - 1
658 WHERE id = $1 AND invite_count > 0
659 RETURNING id
660 ",
661 )
662 .bind(&inviting_user_id)
663 .fetch_optional(&mut tx)
664 .await?;
665
666 if id.is_none() {
667 Err(Error::Http(
668 StatusCode::UNAUTHORIZED,
669 "no invites remaining".to_string(),
670 ))?;
671 }
672
673 sqlx::query(
674 "
675 INSERT INTO contacts
676 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
677 VALUES
678 ($1, $2, TRUE, TRUE, TRUE)
679 ON CONFLICT DO NOTHING
680 ",
681 )
682 .bind(inviting_user_id)
683 .bind(user_id)
684 .execute(&mut tx)
685 .await?;
686 }
687
688 tx.commit().await?;
689 Ok(Some(NewUserResult {
690 user_id,
691 metrics_id,
692 inviting_user_id,
693 signup_device_id,
694 }))
695 })
696 }
697
698 // invite codes
699
700 async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
701 test_support!(self, {
702 let mut tx = self.pool.begin().await?;
703 if count > 0 {
704 sqlx::query(
705 "
706 UPDATE users
707 SET invite_code = $1
708 WHERE id = $2 AND invite_code IS NULL
709 ",
710 )
711 .bind(random_invite_code())
712 .bind(id)
713 .execute(&mut tx)
714 .await?;
715 }
716
717 sqlx::query(
718 "
719 UPDATE users
720 SET invite_count = $1
721 WHERE id = $2
722 ",
723 )
724 .bind(count as i32)
725 .bind(id)
726 .execute(&mut tx)
727 .await?;
728 tx.commit().await?;
729 Ok(())
730 })
731 }
732
733 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
734 test_support!(self, {
735 let result: Option<(String, i32)> = sqlx::query_as(
736 "
737 SELECT invite_code, invite_count
738 FROM users
739 WHERE id = $1 AND invite_code IS NOT NULL
740 ",
741 )
742 .bind(id)
743 .fetch_optional(&self.pool)
744 .await?;
745 if let Some((code, count)) = result {
746 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
747 } else {
748 Ok(None)
749 }
750 })
751 }
752
753 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
754 test_support!(self, {
755 sqlx::query_as(
756 "
757 SELECT *
758 FROM users
759 WHERE invite_code = $1
760 ",
761 )
762 .bind(code)
763 .fetch_optional(&self.pool)
764 .await?
765 .ok_or_else(|| {
766 Error::Http(
767 StatusCode::NOT_FOUND,
768 "that invite code does not exist".to_string(),
769 )
770 })
771 })
772 }
773
774 async fn create_invite_from_code(
775 &self,
776 code: &str,
777 email_address: &str,
778 device_id: Option<&str>,
779 ) -> Result<Invite> {
780 test_support!(self, {
781 let mut tx = self.pool.begin().await?;
782
783 let existing_user: Option<UserId> = sqlx::query_scalar(
784 "
785 SELECT id
786 FROM users
787 WHERE email_address = $1
788 ",
789 )
790 .bind(email_address)
791 .fetch_optional(&mut tx)
792 .await?;
793 if existing_user.is_some() {
794 Err(anyhow!("email address is already in use"))?;
795 }
796
797 let row: Option<(UserId, i32)> = sqlx::query_as(
798 "
799 SELECT id, invite_count
800 FROM users
801 WHERE invite_code = $1
802 ",
803 )
804 .bind(code)
805 .fetch_optional(&mut tx)
806 .await?;
807
808 let (inviter_id, invite_count) = match row {
809 Some(row) => row,
810 None => Err(Error::Http(
811 StatusCode::NOT_FOUND,
812 "invite code not found".to_string(),
813 ))?,
814 };
815
816 if invite_count == 0 {
817 Err(Error::Http(
818 StatusCode::UNAUTHORIZED,
819 "no invites remaining".to_string(),
820 ))?;
821 }
822
823 let email_confirmation_code: String = sqlx::query_scalar(
824 "
825 INSERT INTO signups
826 (
827 email_address,
828 email_confirmation_code,
829 email_confirmation_sent,
830 inviting_user_id,
831 platform_linux,
832 platform_mac,
833 platform_windows,
834 platform_unknown,
835 device_id
836 )
837 VALUES
838 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
839 ON CONFLICT (email_address)
840 DO UPDATE SET
841 inviting_user_id = excluded.inviting_user_id
842 RETURNING email_confirmation_code
843 ",
844 )
845 .bind(&email_address)
846 .bind(&random_email_confirmation_code())
847 .bind(&inviter_id)
848 .bind(&device_id)
849 .fetch_one(&mut tx)
850 .await?;
851
852 tx.commit().await?;
853
854 Ok(Invite {
855 email_address: email_address.into(),
856 email_confirmation_code,
857 })
858 })
859 }
860
861 // projects
862
863 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
864 test_support!(self, {
865 Ok(sqlx::query_scalar(
866 "
867 INSERT INTO projects(host_user_id)
868 VALUES ($1)
869 RETURNING id
870 ",
871 )
872 .bind(host_user_id)
873 .fetch_one(&self.pool)
874 .await
875 .map(ProjectId)?)
876 })
877 }
878
879 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
880 test_support!(self, {
881 sqlx::query(
882 "
883 UPDATE projects
884 SET unregistered = TRUE
885 WHERE id = $1
886 ",
887 )
888 .bind(project_id)
889 .execute(&self.pool)
890 .await?;
891 Ok(())
892 })
893 }
894
895 async fn update_worktree_extensions(
896 &self,
897 project_id: ProjectId,
898 worktree_id: u64,
899 extensions: HashMap<String, u32>,
900 ) -> Result<()> {
901 test_support!(self, {
902 if extensions.is_empty() {
903 return Ok(());
904 }
905
906 let mut query = QueryBuilder::new(
907 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
908 );
909 query.push_values(extensions, |mut query, (extension, count)| {
910 query
911 .push_bind(project_id)
912 .push_bind(worktree_id as i32)
913 .push_bind(extension)
914 .push_bind(count as i32);
915 });
916 query.push(
917 "
918 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
919 count = excluded.count
920 ",
921 );
922 query.build().execute(&self.pool).await?;
923
924 Ok(())
925 })
926 }
927
928 async fn get_project_extensions(
929 &self,
930 project_id: ProjectId,
931 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
932 test_support!(self, {
933 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
934 struct WorktreeExtension {
935 worktree_id: i32,
936 extension: String,
937 count: i32,
938 }
939
940 let query = "
941 SELECT worktree_id, extension, count
942 FROM worktree_extensions
943 WHERE project_id = $1
944 ";
945 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
946 .bind(&project_id)
947 .fetch_all(&self.pool)
948 .await?;
949
950 let mut extension_counts = HashMap::default();
951 for count in counts {
952 extension_counts
953 .entry(count.worktree_id as u64)
954 .or_insert_with(HashMap::default)
955 .insert(count.extension, count.count as usize);
956 }
957 Ok(extension_counts)
958 })
959 }
960
961 async fn record_user_activity(
962 &self,
963 time_period: Range<OffsetDateTime>,
964 projects: &[(UserId, ProjectId)],
965 ) -> Result<()> {
966 test_support!(self, {
967 let query = "
968 INSERT INTO project_activity_periods
969 (ended_at, duration_millis, user_id, project_id)
970 VALUES
971 ($1, $2, $3, $4);
972 ";
973
974 let mut tx = self.pool.begin().await?;
975 let duration_millis =
976 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
977 for (user_id, project_id) in projects {
978 sqlx::query(query)
979 .bind(time_period.end)
980 .bind(duration_millis)
981 .bind(user_id)
982 .bind(project_id)
983 .execute(&mut tx)
984 .await?;
985 }
986 tx.commit().await?;
987
988 Ok(())
989 })
990 }
991
992 async fn get_active_user_count(
993 &self,
994 time_period: Range<OffsetDateTime>,
995 min_duration: Duration,
996 only_collaborative: bool,
997 ) -> Result<usize> {
998 test_support!(self, {
999 let mut with_clause = String::new();
1000 with_clause.push_str("WITH\n");
1001 with_clause.push_str(
1002 "
1003 project_durations AS (
1004 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
1005 FROM project_activity_periods
1006 WHERE $1 < ended_at AND ended_at <= $2
1007 GROUP BY user_id, project_id
1008 ),
1009 ",
1010 );
1011 with_clause.push_str(
1012 "
1013 project_collaborators as (
1014 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
1015 FROM project_durations
1016 GROUP BY project_id
1017 ),
1018 ",
1019 );
1020
1021 if only_collaborative {
1022 with_clause.push_str(
1023 "
1024 user_durations AS (
1025 SELECT user_id, SUM(project_duration) as total_duration
1026 FROM project_durations, project_collaborators
1027 WHERE
1028 project_durations.project_id = project_collaborators.project_id AND
1029 max_collaborators > 1
1030 GROUP BY user_id
1031 ORDER BY total_duration DESC
1032 LIMIT $3
1033 )
1034 ",
1035 );
1036 } else {
1037 with_clause.push_str(
1038 "
1039 user_durations AS (
1040 SELECT user_id, SUM(project_duration) as total_duration
1041 FROM project_durations
1042 GROUP BY user_id
1043 ORDER BY total_duration DESC
1044 LIMIT $3
1045 )
1046 ",
1047 );
1048 }
1049
1050 let query = format!(
1051 "
1052 {with_clause}
1053 SELECT count(user_durations.user_id)
1054 FROM user_durations
1055 WHERE user_durations.total_duration >= $3
1056 "
1057 );
1058
1059 let count: i64 = sqlx::query_scalar(&query)
1060 .bind(time_period.start)
1061 .bind(time_period.end)
1062 .bind(min_duration.as_millis() as i64)
1063 .fetch_one(&self.pool)
1064 .await?;
1065 Ok(count as usize)
1066 })
1067 }
1068
1069 async fn get_top_users_activity_summary(
1070 &self,
1071 time_period: Range<OffsetDateTime>,
1072 max_user_count: usize,
1073 ) -> Result<Vec<UserActivitySummary>> {
1074 test_support!(self, {
1075 let query = "
1076 WITH
1077 project_durations AS (
1078 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
1079 FROM project_activity_periods
1080 WHERE $1 < ended_at AND ended_at <= $2
1081 GROUP BY user_id, project_id
1082 ),
1083 user_durations AS (
1084 SELECT user_id, SUM(project_duration) as total_duration
1085 FROM project_durations
1086 GROUP BY user_id
1087 ORDER BY total_duration DESC
1088 LIMIT $3
1089 ),
1090 project_collaborators as (
1091 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
1092 FROM project_durations
1093 GROUP BY project_id
1094 )
1095 SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
1096 FROM user_durations, project_durations, project_collaborators, users
1097 WHERE
1098 user_durations.user_id = project_durations.user_id AND
1099 user_durations.user_id = users.id AND
1100 project_durations.project_id = project_collaborators.project_id
1101 ORDER BY total_duration DESC, user_id ASC, project_id ASC
1102 ";
1103
1104 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
1105 .bind(time_period.start)
1106 .bind(time_period.end)
1107 .bind(max_user_count as i32)
1108 .fetch(&self.pool);
1109
1110 let mut result = Vec::<UserActivitySummary>::new();
1111 while let Some(row) = rows.next().await {
1112 let (user_id, github_login, project_id, duration_millis, project_collaborators) =
1113 row?;
1114 let project_id = project_id;
1115 let duration = Duration::from_millis(duration_millis as u64);
1116 let project_activity = ProjectActivitySummary {
1117 id: project_id,
1118 duration,
1119 max_collaborators: project_collaborators as usize,
1120 };
1121 if let Some(last_summary) = result.last_mut() {
1122 if last_summary.id == user_id {
1123 last_summary.project_activity.push(project_activity);
1124 continue;
1125 }
1126 }
1127 result.push(UserActivitySummary {
1128 id: user_id,
1129 project_activity: vec![project_activity],
1130 github_login,
1131 });
1132 }
1133
1134 Ok(result)
1135 })
1136 }
1137
1138 async fn get_user_activity_timeline(
1139 &self,
1140 time_period: Range<OffsetDateTime>,
1141 user_id: UserId,
1142 ) -> Result<Vec<UserActivityPeriod>> {
1143 test_support!(self, {
1144 const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
1145
1146 let query = "
1147 SELECT
1148 project_activity_periods.ended_at,
1149 project_activity_periods.duration_millis,
1150 project_activity_periods.project_id,
1151 worktree_extensions.extension,
1152 worktree_extensions.count
1153 FROM project_activity_periods
1154 LEFT OUTER JOIN
1155 worktree_extensions
1156 ON
1157 project_activity_periods.project_id = worktree_extensions.project_id
1158 WHERE
1159 project_activity_periods.user_id = $1 AND
1160 $2 < project_activity_periods.ended_at AND
1161 project_activity_periods.ended_at <= $3
1162 ORDER BY project_activity_periods.id ASC
1163 ";
1164
1165 let mut rows = sqlx::query_as::<
1166 _,
1167 (
1168 PrimitiveDateTime,
1169 i32,
1170 ProjectId,
1171 Option<String>,
1172 Option<i32>,
1173 ),
1174 >(query)
1175 .bind(user_id)
1176 .bind(time_period.start)
1177 .bind(time_period.end)
1178 .fetch(&self.pool);
1179
1180 let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
1181 while let Some(row) = rows.next().await {
1182 let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
1183 let ended_at = ended_at.assume_utc();
1184 let duration = Duration::from_millis(duration_millis as u64);
1185 let started_at = ended_at - duration;
1186 let project_time_periods = time_periods.entry(project_id).or_default();
1187
1188 if let Some(prev_duration) = project_time_periods.last_mut() {
1189 if started_at <= prev_duration.end + COALESCE_THRESHOLD
1190 && ended_at >= prev_duration.start
1191 {
1192 prev_duration.end = cmp::max(prev_duration.end, ended_at);
1193 } else {
1194 project_time_periods.push(UserActivityPeriod {
1195 project_id,
1196 start: started_at,
1197 end: ended_at,
1198 extensions: Default::default(),
1199 });
1200 }
1201 } else {
1202 project_time_periods.push(UserActivityPeriod {
1203 project_id,
1204 start: started_at,
1205 end: ended_at,
1206 extensions: Default::default(),
1207 });
1208 }
1209
1210 if let Some((extension, extension_count)) = extension.zip(extension_count) {
1211 project_time_periods
1212 .last_mut()
1213 .unwrap()
1214 .extensions
1215 .insert(extension, extension_count as usize);
1216 }
1217 }
1218
1219 let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
1220 durations.sort_unstable_by_key(|duration| duration.start);
1221 Ok(durations)
1222 })
1223 }
1224
1225 // contacts
1226
1227 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1228 test_support!(self, {
1229 let query = "
1230 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1231 FROM contacts
1232 WHERE user_id_a = $1 OR user_id_b = $1;
1233 ";
1234
1235 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1236 .bind(user_id)
1237 .fetch(&self.pool);
1238
1239 let mut contacts = Vec::new();
1240 while let Some(row) = rows.next().await {
1241 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1242
1243 if user_id_a == user_id {
1244 if accepted {
1245 contacts.push(Contact::Accepted {
1246 user_id: user_id_b,
1247 should_notify: should_notify && a_to_b,
1248 });
1249 } else if a_to_b {
1250 contacts.push(Contact::Outgoing { user_id: user_id_b })
1251 } else {
1252 contacts.push(Contact::Incoming {
1253 user_id: user_id_b,
1254 should_notify,
1255 });
1256 }
1257 } else if accepted {
1258 contacts.push(Contact::Accepted {
1259 user_id: user_id_a,
1260 should_notify: should_notify && !a_to_b,
1261 });
1262 } else if a_to_b {
1263 contacts.push(Contact::Incoming {
1264 user_id: user_id_a,
1265 should_notify,
1266 });
1267 } else {
1268 contacts.push(Contact::Outgoing { user_id: user_id_a });
1269 }
1270 }
1271
1272 contacts.sort_unstable_by_key(|contact| contact.user_id());
1273
1274 Ok(contacts)
1275 })
1276 }
1277
1278 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1279 test_support!(self, {
1280 let (id_a, id_b) = if user_id_1 < user_id_2 {
1281 (user_id_1, user_id_2)
1282 } else {
1283 (user_id_2, user_id_1)
1284 };
1285
1286 let query = "
1287 SELECT 1 FROM contacts
1288 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
1289 LIMIT 1
1290 ";
1291 Ok(sqlx::query_scalar::<_, i32>(query)
1292 .bind(id_a.0)
1293 .bind(id_b.0)
1294 .fetch_optional(&self.pool)
1295 .await?
1296 .is_some())
1297 })
1298 }
1299
    /// Records a contact request from `sender_id` to `receiver_id`.
    ///
    /// Each pair of users occupies a single row keyed by `(user_id_a, user_id_b)`
    /// with the smaller id in `user_id_a`; `a_to_b` records whether `user_id_a`
    /// is the sender. If a still-pending request already exists in the opposite
    /// direction, the upsert turns that row into an accepted contact instead of
    /// inserting a new one.
    ///
    /// Errors with "contact already requested" when no row was inserted or
    /// updated (e.g. a repeated request in the same direction, or a pair that is
    /// already accepted).
    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
        test_support!(self, {
            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
                (sender_id, receiver_id, true)
            } else {
                (receiver_id, sender_id, false)
            };
            // On conflict the row is only updated while it is still pending. Since the
            // conflict target already implies `contacts.user_id_a = excluded.user_id_a`,
            // the second disjunct reads "the existing request went the other way".
            // NOTE(review): the first disjunct can only hold when user_id_a = user_id_b.
            let query = "
                INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
                VALUES ($1, $2, $3, FALSE, TRUE)
                ON CONFLICT (user_id_a, user_id_b) DO UPDATE
                SET
                    accepted = TRUE,
                    should_notify = FALSE
                WHERE
                    NOT contacts.accepted AND
                    ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
                    (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
            ";
            let result = sqlx::query(query)
                .bind(id_a.0)
                .bind(id_b.0)
                .bind(a_to_b)
                .execute(&self.pool)
                .await?;

            // Exactly one row is touched on a fresh insert or a mutual-request acceptance.
            if result.rows_affected() == 1 {
                Ok(())
            } else {
                Err(anyhow!("contact already requested"))?
            }
        })
    }
1333
1334 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1335 test_support!(self, {
1336 let (id_a, id_b) = if responder_id < requester_id {
1337 (responder_id, requester_id)
1338 } else {
1339 (requester_id, responder_id)
1340 };
1341 let query = "
1342 DELETE FROM contacts
1343 WHERE user_id_a = $1 AND user_id_b = $2;
1344 ";
1345 let result = sqlx::query(query)
1346 .bind(id_a.0)
1347 .bind(id_b.0)
1348 .execute(&self.pool)
1349 .await?;
1350
1351 if result.rows_affected() == 1 {
1352 Ok(())
1353 } else {
1354 Err(anyhow!("no such contact"))?
1355 }
1356 })
1357 }
1358
1359 async fn dismiss_contact_notification(
1360 &self,
1361 user_id: UserId,
1362 contact_user_id: UserId,
1363 ) -> Result<()> {
1364 test_support!(self, {
1365 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1366 (user_id, contact_user_id, true)
1367 } else {
1368 (contact_user_id, user_id, false)
1369 };
1370
1371 let query = "
1372 UPDATE contacts
1373 SET should_notify = FALSE
1374 WHERE
1375 user_id_a = $1 AND user_id_b = $2 AND
1376 (
1377 (a_to_b = $3 AND accepted) OR
1378 (a_to_b != $3 AND NOT accepted)
1379 );
1380 ";
1381
1382 let result = sqlx::query(query)
1383 .bind(id_a.0)
1384 .bind(id_b.0)
1385 .bind(a_to_b)
1386 .execute(&self.pool)
1387 .await?;
1388
1389 if result.rows_affected() == 0 {
1390 Err(anyhow!("no such contact request"))?;
1391 }
1392
1393 Ok(())
1394 })
1395 }
1396
1397 async fn respond_to_contact_request(
1398 &self,
1399 responder_id: UserId,
1400 requester_id: UserId,
1401 accept: bool,
1402 ) -> Result<()> {
1403 test_support!(self, {
1404 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1405 (responder_id, requester_id, false)
1406 } else {
1407 (requester_id, responder_id, true)
1408 };
1409 let result = if accept {
1410 let query = "
1411 UPDATE contacts
1412 SET accepted = TRUE, should_notify = TRUE
1413 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1414 ";
1415 sqlx::query(query)
1416 .bind(id_a.0)
1417 .bind(id_b.0)
1418 .bind(a_to_b)
1419 .execute(&self.pool)
1420 .await?
1421 } else {
1422 let query = "
1423 DELETE FROM contacts
1424 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1425 ";
1426 sqlx::query(query)
1427 .bind(id_a.0)
1428 .bind(id_b.0)
1429 .bind(a_to_b)
1430 .execute(&self.pool)
1431 .await?
1432 };
1433 if result.rows_affected() == 1 {
1434 Ok(())
1435 } else {
1436 Err(anyhow!("no such contact request"))?
1437 }
1438 })
1439 }
1440
1441 // access tokens
1442
1443 async fn create_access_token_hash(
1444 &self,
1445 user_id: UserId,
1446 access_token_hash: &str,
1447 max_access_token_count: usize,
1448 ) -> Result<()> {
1449 test_support!(self, {
1450 let insert_query = "
1451 INSERT INTO access_tokens (user_id, hash)
1452 VALUES ($1, $2);
1453 ";
1454 let cleanup_query = "
1455 DELETE FROM access_tokens
1456 WHERE id IN (
1457 SELECT id from access_tokens
1458 WHERE user_id = $1
1459 ORDER BY id DESC
1460 OFFSET $3
1461 )
1462 ";
1463
1464 let mut tx = self.pool.begin().await?;
1465 sqlx::query(insert_query)
1466 .bind(user_id.0)
1467 .bind(access_token_hash)
1468 .execute(&mut tx)
1469 .await?;
1470 sqlx::query(cleanup_query)
1471 .bind(user_id.0)
1472 .bind(access_token_hash)
1473 .bind(max_access_token_count as i32)
1474 .execute(&mut tx)
1475 .await?;
1476 Ok(tx.commit().await?)
1477 })
1478 }
1479
1480 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1481 test_support!(self, {
1482 let query = "
1483 SELECT hash
1484 FROM access_tokens
1485 WHERE user_id = $1
1486 ORDER BY id DESC
1487 ";
1488 Ok(sqlx::query_scalar(query)
1489 .bind(user_id.0)
1490 .fetch_all(&self.pool)
1491 .await?)
1492 })
1493 }
1494
1495 // orgs
1496
1497 #[allow(unused)] // Help rust-analyzer
1498 #[cfg(any(test, feature = "seed-support"))]
1499 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1500 test_support!(self, {
1501 let query = "
1502 SELECT *
1503 FROM orgs
1504 WHERE slug = $1
1505 ";
1506 Ok(sqlx::query_as(query)
1507 .bind(slug)
1508 .fetch_optional(&self.pool)
1509 .await?)
1510 })
1511 }
1512
1513 #[cfg(any(test, feature = "seed-support"))]
1514 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1515 test_support!(self, {
1516 let query = "
1517 INSERT INTO orgs (name, slug)
1518 VALUES ($1, $2)
1519 RETURNING id
1520 ";
1521 Ok(sqlx::query_scalar(query)
1522 .bind(name)
1523 .bind(slug)
1524 .fetch_one(&self.pool)
1525 .await
1526 .map(OrgId)?)
1527 })
1528 }
1529
1530 #[cfg(any(test, feature = "seed-support"))]
1531 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1532 test_support!(self, {
1533 let query = "
1534 INSERT INTO org_memberships (org_id, user_id, admin)
1535 VALUES ($1, $2, $3)
1536 ON CONFLICT DO NOTHING
1537 ";
1538 Ok(sqlx::query(query)
1539 .bind(org_id.0)
1540 .bind(user_id.0)
1541 .bind(is_admin)
1542 .execute(&self.pool)
1543 .await
1544 .map(drop)?)
1545 })
1546 }
1547
1548 // channels
1549
1550 #[cfg(any(test, feature = "seed-support"))]
1551 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1552 test_support!(self, {
1553 let query = "
1554 INSERT INTO channels (owner_id, owner_is_user, name)
1555 VALUES ($1, false, $2)
1556 RETURNING id
1557 ";
1558 Ok(sqlx::query_scalar(query)
1559 .bind(org_id.0)
1560 .bind(name)
1561 .fetch_one(&self.pool)
1562 .await
1563 .map(ChannelId)?)
1564 })
1565 }
1566
1567 #[allow(unused)] // Help rust-analyzer
1568 #[cfg(any(test, feature = "seed-support"))]
1569 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1570 test_support!(self, {
1571 let query = "
1572 SELECT *
1573 FROM channels
1574 WHERE
1575 channels.owner_is_user = false AND
1576 channels.owner_id = $1
1577 ";
1578 Ok(sqlx::query_as(query)
1579 .bind(org_id.0)
1580 .fetch_all(&self.pool)
1581 .await?)
1582 })
1583 }
1584
1585 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1586 test_support!(self, {
1587 let query = "
1588 SELECT
1589 channels.*
1590 FROM
1591 channel_memberships, channels
1592 WHERE
1593 channel_memberships.user_id = $1 AND
1594 channel_memberships.channel_id = channels.id
1595 ";
1596 Ok(sqlx::query_as(query)
1597 .bind(user_id.0)
1598 .fetch_all(&self.pool)
1599 .await?)
1600 })
1601 }
1602
1603 async fn can_user_access_channel(
1604 &self,
1605 user_id: UserId,
1606 channel_id: ChannelId,
1607 ) -> Result<bool> {
1608 test_support!(self, {
1609 let query = "
1610 SELECT id
1611 FROM channel_memberships
1612 WHERE user_id = $1 AND channel_id = $2
1613 LIMIT 1
1614 ";
1615 Ok(sqlx::query_scalar::<_, i32>(query)
1616 .bind(user_id.0)
1617 .bind(channel_id.0)
1618 .fetch_optional(&self.pool)
1619 .await
1620 .map(|e| e.is_some())?)
1621 })
1622 }
1623
1624 #[cfg(any(test, feature = "seed-support"))]
1625 async fn add_channel_member(
1626 &self,
1627 channel_id: ChannelId,
1628 user_id: UserId,
1629 is_admin: bool,
1630 ) -> Result<()> {
1631 test_support!(self, {
1632 let query = "
1633 INSERT INTO channel_memberships (channel_id, user_id, admin)
1634 VALUES ($1, $2, $3)
1635 ON CONFLICT DO NOTHING
1636 ";
1637 Ok(sqlx::query(query)
1638 .bind(channel_id.0)
1639 .bind(user_id.0)
1640 .bind(is_admin)
1641 .execute(&self.pool)
1642 .await
1643 .map(drop)?)
1644 })
1645 }
1646
1647 // messages
1648
1649 async fn create_channel_message(
1650 &self,
1651 channel_id: ChannelId,
1652 sender_id: UserId,
1653 body: &str,
1654 timestamp: OffsetDateTime,
1655 nonce: u128,
1656 ) -> Result<MessageId> {
1657 test_support!(self, {
1658 let query = "
1659 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1660 VALUES ($1, $2, $3, $4, $5)
1661 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1662 RETURNING id
1663 ";
1664 Ok(sqlx::query_scalar(query)
1665 .bind(channel_id.0)
1666 .bind(sender_id.0)
1667 .bind(body)
1668 .bind(timestamp)
1669 .bind(Uuid::from_u128(nonce))
1670 .fetch_one(&self.pool)
1671 .await
1672 .map(MessageId)?)
1673 })
1674 }
1675
1676 async fn get_channel_messages(
1677 &self,
1678 channel_id: ChannelId,
1679 count: usize,
1680 before_id: Option<MessageId>,
1681 ) -> Result<Vec<ChannelMessage>> {
1682 test_support!(self, {
1683 let query = r#"
1684 SELECT * FROM (
1685 SELECT
1686 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1687 FROM
1688 channel_messages
1689 WHERE
1690 channel_id = $1 AND
1691 id < $2
1692 ORDER BY id DESC
1693 LIMIT $3
1694 ) as recent_messages
1695 ORDER BY id ASC
1696 "#;
1697 Ok(sqlx::query_as(query)
1698 .bind(channel_id.0)
1699 .bind(before_id.unwrap_or(MessageId::MAX))
1700 .bind(count as i64)
1701 .fetch_all(&self.pool)
1702 .await?)
1703 })
1704 }
1705
1706 #[cfg(test)]
1707 async fn teardown(&self, url: &str) {
1708 let start = std::time::Instant::now();
1709 eprintln!("tearing down database...");
1710 test_support!(self, {
1711 use util::ResultExt;
1712
1713 let query = "
1714 SELECT pg_terminate_backend(pg_stat_activity.pid)
1715 FROM pg_stat_activity
1716 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1717 ";
1718 sqlx::query(query).execute(&self.pool).await.log_err();
1719 self.pool.close().await;
1720 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
1721 .await
1722 .log_err();
1723 eprintln!("tore down database: {:?}", start.elapsed());
1724 })
1725 }
1726
    /// Test-only downcast hook; this implementation is not a `FakeDb`, so it
    /// always returns `None`.
    #[cfg(test)]
    fn as_fake(&self) -> Option<&FakeDb> {
        None
    }
1731}
1732
/// Declares `$name` as a newtype around an `i32` database id.
///
/// The generated type is `sqlx`- and `serde`-transparent, so it is stored and
/// serialized exactly like the inner `i32`. It also gets a `MAX` constant,
/// `from_proto`/`to_proto` conversions, and a `Display` impl that forwards to
/// the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id, e.g. usable as an open upper bound in queries.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from its wire representation.
            /// NOTE(review): `as i32` silently truncates values above `i32::MAX`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts the id to its wire representation.
            /// NOTE(review): a negative id would wrap to a large `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            // Forward to the inner integer so formatter flags (width, fill, ...) apply.
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1775
id_type!(UserId);
/// A user record, loaded via `FromRow` (presumably from the `users` table).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    pub invite_code: Option<String>,
    pub invite_count: i32,
    // Updated through `Db::set_user_connected_once`.
    pub connected_once: bool,
}
1788
id_type!(ProjectId);
/// A project hosted by a user.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    // Set to true by `unregister_project`.
    pub unregistered: bool,
}
1796
/// Per-user entry returned by `Db::get_top_users_activity_summary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    // One entry per project the user was active in.
    pub project_activity: Vec<ProjectActivitySummary>,
}
1803
/// Activity of one user in one project within a summary time period.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    pub id: ProjectId,
    // Presumably the total active time in the period — confirm against the query.
    pub duration: Duration,
    pub max_collaborators: usize,
}
1810
/// A contiguous stretch of activity in one project, as produced by
/// `Db::get_user_activity_timeline` (nearby activity records are coalesced).
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    pub project_id: ProjectId,
    #[serde(with = "time::serde::iso8601")]
    pub start: OffsetDateTime,
    #[serde(with = "time::serde::iso8601")]
    pub end: OffsetDateTime,
    // File extension -> count recorded for it during this period.
    pub extensions: HashMap<String, usize>,
}
1820
id_type!(OrgId);
/// An organization; `slug` is the lookup key used by `find_org_by_slug`.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
1828
id_type!(ChannelId);
/// A chat channel.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    // An org id when `owner_is_user` is false; presumably a user id otherwise.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1837
id_type!(MessageId);
/// A message posted to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    pub sent_at: OffsetDateTime,
    // Client-supplied value that makes message creation idempotent.
    pub nonce: Uuid,
}
1848
/// A contact relationship as seen by one user; `user_id` is always the other user.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both users are contacts; `should_notify` flags an unseen acceptance.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A pending request this user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request `user_id` sent to this user.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1863
1864impl Contact {
1865 pub fn user_id(&self) -> UserId {
1866 match self {
1867 Contact::Accepted { user_id, .. } => *user_id,
1868 Contact::Outgoing { user_id } => *user_id,
1869 Contact::Incoming { user_id, .. } => *user_id,
1870 }
1871 }
1872}
1873
/// A pending contact request received from `requester_id`.
/// Ordering compares `requester_id` first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1879
/// A waitlist signup passed to `Db::create_signup` (presumably deserialized
/// from a signup request body — confirm against the caller).
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
1890
/// Aggregate waitlist counts returned by `Db::get_waitlist_summary`.
///
/// Every field is `#[sqlx(default)]`, so a column missing from the row decodes as 0.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1904
/// An invite addressed to `email_address`, identified by its confirmation code.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1910
/// GitHub identity and invite allowance used when creating a user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
1917
/// Outcome of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    // Presumably set when the user signed up through another user's invite.
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
1925
/// Generates a random 16-character invite code using `nanoid`'s default alphabet.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}
1929
/// Generates a random 64-character email confirmation code using `nanoid`'s default alphabet.
fn random_email_confirmation_code() -> String {
    nanoid::nanoid!(64)
}
1933
1934#[cfg(test)]
1935pub use test::*;
1936
1937#[cfg(test)]
1938mod test {
1939 use super::*;
1940 use anyhow::anyhow;
1941 use collections::BTreeMap;
1942 use gpui::executor::Background;
1943 use parking_lot::Mutex;
1944 use rand::prelude::*;
1945 use sqlx::{migrate::MigrateDatabase, Sqlite};
1946 use std::sync::Arc;
1947 use util::post_inc;
1948
    /// In-memory stand-in for the database, used by tests.
    ///
    /// Each table is an independently locked collection; id counters are kept alongside.
    pub struct FakeDb {
        // Used to inject random delays into each operation.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        // (project, worktree id, extension) -> count recorded for that extension.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // (org, user) -> is_admin.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        // (channel, user) -> is_admin.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        pub contacts: Mutex<Vec<FakeContact>>,
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
1966
    /// One contact relationship in `FakeDb`, stored in the direction it was requested.
    #[derive(Debug)]
    pub struct FakeContact {
        pub requester_id: UserId,
        pub responder_id: UserId,
        pub accepted: bool,
        pub should_notify: bool,
    }
1974
1975 impl FakeDb {
1976 pub fn new(background: Arc<Background>) -> Self {
1977 Self {
1978 background,
1979 users: Default::default(),
1980 next_user_id: Mutex::new(0),
1981 projects: Default::default(),
1982 worktree_extensions: Default::default(),
1983 next_project_id: Mutex::new(1),
1984 orgs: Default::default(),
1985 next_org_id: Mutex::new(1),
1986 org_memberships: Default::default(),
1987 channels: Default::default(),
1988 next_channel_id: Mutex::new(1),
1989 channel_memberships: Default::default(),
1990 channel_messages: Default::default(),
1991 next_channel_message_id: Mutex::new(1),
1992 contacts: Default::default(),
1993 }
1994 }
1995 }
1996
1997 #[async_trait]
1998 impl Db for FakeDb {
1999 async fn create_user(
2000 &self,
2001 email_address: &str,
2002 admin: bool,
2003 params: NewUserParams,
2004 ) -> Result<NewUserResult> {
2005 self.background.simulate_random_delay().await;
2006
2007 let mut users = self.users.lock();
2008 let user_id = if let Some(user) = users
2009 .values()
2010 .find(|user| user.github_login == params.github_login)
2011 {
2012 user.id
2013 } else {
2014 let id = post_inc(&mut *self.next_user_id.lock());
2015 let user_id = UserId(id);
2016 users.insert(
2017 user_id,
2018 User {
2019 id: user_id,
2020 github_login: params.github_login,
2021 github_user_id: Some(params.github_user_id),
2022 email_address: Some(email_address.to_string()),
2023 admin,
2024 invite_code: None,
2025 invite_count: 0,
2026 connected_once: false,
2027 },
2028 );
2029 user_id
2030 };
2031 Ok(NewUserResult {
2032 user_id,
2033 metrics_id: "the-metrics-id".to_string(),
2034 inviting_user_id: None,
2035 signup_device_id: None,
2036 })
2037 }
2038
        // Not implemented by the fake; panics if called.
        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }

        // Not implemented by the fake; panics if called.
        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }
2046
2047 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
2048 self.background.simulate_random_delay().await;
2049 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
2050 }
2051
2052 async fn get_user_metrics_id(&self, _id: UserId) -> Result<String> {
2053 Ok("the-metrics-id".to_string())
2054 }
2055
2056 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
2057 self.background.simulate_random_delay().await;
2058 let users = self.users.lock();
2059 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
2060 }
2061
        // Not implemented by the fake; panics if called.
        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }
2065
2066 async fn get_user_by_github_account(
2067 &self,
2068 github_login: &str,
2069 github_user_id: Option<i32>,
2070 ) -> Result<Option<User>> {
2071 self.background.simulate_random_delay().await;
2072 if let Some(github_user_id) = github_user_id {
2073 for user in self.users.lock().values_mut() {
2074 if user.github_user_id == Some(github_user_id) {
2075 user.github_login = github_login.into();
2076 return Ok(Some(user.clone()));
2077 }
2078 if user.github_login == github_login {
2079 user.github_user_id = Some(github_user_id);
2080 return Ok(Some(user.clone()));
2081 }
2082 }
2083 Ok(None)
2084 } else {
2085 Ok(self
2086 .users
2087 .lock()
2088 .values()
2089 .find(|user| user.github_login == github_login)
2090 .cloned())
2091 }
2092 }
2093
        // Not implemented by the fake; panics if called.
        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }
2097
2098 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
2099 self.background.simulate_random_delay().await;
2100 let mut users = self.users.lock();
2101 let mut user = users
2102 .get_mut(&id)
2103 .ok_or_else(|| anyhow!("user not found"))?;
2104 user.connected_once = connected_once;
2105 Ok(())
2106 }
2107
        // Not implemented by the fake; panics if called.
        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        // signups

        // Not implemented by the fake; panics if called.
        async fn create_signup(&self, _signup: Signup) -> Result<()> {
            unimplemented!()
        }

        // Not implemented by the fake; panics if called.
        async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
            unimplemented!()
        }

        // Not implemented by the fake; panics if called.
        async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<Invite>> {
            unimplemented!()
        }

        // Not implemented by the fake; panics if called.
        async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
            unimplemented!()
        }

        // Not implemented by the fake; panics if called.
        async fn create_user_from_invite(
            &self,
            _invite: &Invite,
            _user: NewUserParams,
        ) -> Result<Option<NewUserResult>> {
            unimplemented!()
        }

        // invite codes

        // Not implemented by the fake; panics if called.
        async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }
2143
        // The fake never stores invite codes, so every user reports none.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            self.background.simulate_random_delay().await;
            Ok(None)
        }
2148
        // Not implemented by the fake; panics if called.
        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }

        // Not implemented by the fake; panics if called.
        async fn create_invite_from_code(
            &self,
            _code: &str,
            _email_address: &str,
            _device_id: Option<&str>,
        ) -> Result<Invite> {
            unimplemented!()
        }
2161
2162 // projects
2163
2164 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2165 self.background.simulate_random_delay().await;
2166 if !self.users.lock().contains_key(&host_user_id) {
2167 Err(anyhow!("no such user"))?;
2168 }
2169
2170 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2171 self.projects.lock().insert(
2172 project_id,
2173 Project {
2174 id: project_id,
2175 host_user_id,
2176 unregistered: false,
2177 },
2178 );
2179 Ok(project_id)
2180 }
2181
2182 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2183 self.background.simulate_random_delay().await;
2184 self.projects
2185 .lock()
2186 .get_mut(&project_id)
2187 .ok_or_else(|| anyhow!("no such project"))?
2188 .unregistered = true;
2189 Ok(())
2190 }
2191
2192 async fn update_worktree_extensions(
2193 &self,
2194 project_id: ProjectId,
2195 worktree_id: u64,
2196 extensions: HashMap<String, u32>,
2197 ) -> Result<()> {
2198 self.background.simulate_random_delay().await;
2199 if !self.projects.lock().contains_key(&project_id) {
2200 Err(anyhow!("no such project"))?;
2201 }
2202
2203 for (extension, count) in extensions {
2204 self.worktree_extensions
2205 .lock()
2206 .insert((project_id, worktree_id, extension), count);
2207 }
2208
2209 Ok(())
2210 }
2211
        // Not implemented by the fake; panics if called.
        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }

        // Not implemented by the fake; panics if called.
        async fn record_user_activity(
            &self,
            _time_period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }

        // Not implemented by the fake; panics if called.
        async fn get_active_user_count(
            &self,
            _time_period: Range<OffsetDateTime>,
            _min_duration: Duration,
            _only_collaborative: bool,
        ) -> Result<usize> {
            unimplemented!()
        }

        // Not implemented by the fake; panics if called.
        async fn get_top_users_activity_summary(
            &self,
            _time_period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }

        // Not implemented by the fake; panics if called.
        async fn get_user_activity_timeline(
            &self,
            _time_period: Range<OffsetDateTime>,
            _user_id: UserId,
        ) -> Result<Vec<UserActivityPeriod>> {
            unimplemented!()
        }
2251
2252 // contacts
2253
2254 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2255 self.background.simulate_random_delay().await;
2256 let mut contacts = Vec::new();
2257
2258 for contact in self.contacts.lock().iter() {
2259 if contact.requester_id == id {
2260 if contact.accepted {
2261 contacts.push(Contact::Accepted {
2262 user_id: contact.responder_id,
2263 should_notify: contact.should_notify,
2264 });
2265 } else {
2266 contacts.push(Contact::Outgoing {
2267 user_id: contact.responder_id,
2268 });
2269 }
2270 } else if contact.responder_id == id {
2271 if contact.accepted {
2272 contacts.push(Contact::Accepted {
2273 user_id: contact.requester_id,
2274 should_notify: false,
2275 });
2276 } else {
2277 contacts.push(Contact::Incoming {
2278 user_id: contact.requester_id,
2279 should_notify: contact.should_notify,
2280 });
2281 }
2282 }
2283 }
2284
2285 contacts.sort_unstable_by_key(|contact| contact.user_id());
2286 Ok(contacts)
2287 }
2288
2289 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2290 self.background.simulate_random_delay().await;
2291 Ok(self.contacts.lock().iter().any(|contact| {
2292 contact.accepted
2293 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2294 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2295 }))
2296 }
2297
2298 async fn send_contact_request(
2299 &self,
2300 requester_id: UserId,
2301 responder_id: UserId,
2302 ) -> Result<()> {
2303 self.background.simulate_random_delay().await;
2304 let mut contacts = self.contacts.lock();
2305 for contact in contacts.iter_mut() {
2306 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2307 if contact.accepted {
2308 Err(anyhow!("contact already exists"))?;
2309 } else {
2310 Err(anyhow!("contact already requested"))?;
2311 }
2312 }
2313 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2314 if contact.accepted {
2315 Err(anyhow!("contact already exists"))?;
2316 } else {
2317 contact.accepted = true;
2318 contact.should_notify = false;
2319 return Ok(());
2320 }
2321 }
2322 }
2323 contacts.push(FakeContact {
2324 requester_id,
2325 responder_id,
2326 accepted: false,
2327 should_notify: true,
2328 });
2329 Ok(())
2330 }
2331
2332 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2333 self.background.simulate_random_delay().await;
2334 self.contacts.lock().retain(|contact| {
2335 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2336 });
2337 Ok(())
2338 }
2339
2340 async fn dismiss_contact_notification(
2341 &self,
2342 user_id: UserId,
2343 contact_user_id: UserId,
2344 ) -> Result<()> {
2345 self.background.simulate_random_delay().await;
2346 let mut contacts = self.contacts.lock();
2347 for contact in contacts.iter_mut() {
2348 if contact.requester_id == contact_user_id
2349 && contact.responder_id == user_id
2350 && !contact.accepted
2351 {
2352 contact.should_notify = false;
2353 return Ok(());
2354 }
2355 if contact.requester_id == user_id
2356 && contact.responder_id == contact_user_id
2357 && contact.accepted
2358 {
2359 contact.should_notify = false;
2360 return Ok(());
2361 }
2362 }
2363 Err(anyhow!("no such notification"))?
2364 }
2365
2366 async fn respond_to_contact_request(
2367 &self,
2368 responder_id: UserId,
2369 requester_id: UserId,
2370 accept: bool,
2371 ) -> Result<()> {
2372 self.background.simulate_random_delay().await;
2373 let mut contacts = self.contacts.lock();
2374 for (ix, contact) in contacts.iter_mut().enumerate() {
2375 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2376 if contact.accepted {
2377 Err(anyhow!("contact already confirmed"))?;
2378 }
2379 if accept {
2380 contact.accepted = true;
2381 contact.should_notify = true;
2382 } else {
2383 contacts.remove(ix);
2384 }
2385 return Ok(());
2386 }
2387 }
2388 Err(anyhow!("no such contact request"))?
2389 }
2390
        // Not implemented by the fake; panics if called.
        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }

        // Not implemented by the fake; panics if called.
        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }

        // Not implemented by the fake; panics if called.
        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }
2407
2408 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2409 self.background.simulate_random_delay().await;
2410 let mut orgs = self.orgs.lock();
2411 if orgs.values().any(|org| org.slug == slug) {
2412 Err(anyhow!("org already exists"))?
2413 } else {
2414 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2415 orgs.insert(
2416 org_id,
2417 Org {
2418 id: org_id,
2419 name: name.to_string(),
2420 slug: slug.to_string(),
2421 },
2422 );
2423 Ok(org_id)
2424 }
2425 }
2426
2427 async fn add_org_member(
2428 &self,
2429 org_id: OrgId,
2430 user_id: UserId,
2431 is_admin: bool,
2432 ) -> Result<()> {
2433 self.background.simulate_random_delay().await;
2434 if !self.orgs.lock().contains_key(&org_id) {
2435 Err(anyhow!("org does not exist"))?;
2436 }
2437 if !self.users.lock().contains_key(&user_id) {
2438 Err(anyhow!("user does not exist"))?;
2439 }
2440
2441 self.org_memberships
2442 .lock()
2443 .entry((org_id, user_id))
2444 .or_insert(is_admin);
2445 Ok(())
2446 }
2447
2448 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2449 self.background.simulate_random_delay().await;
2450 if !self.orgs.lock().contains_key(&org_id) {
2451 Err(anyhow!("org does not exist"))?;
2452 }
2453
2454 let mut channels = self.channels.lock();
2455 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2456 channels.insert(
2457 channel_id,
2458 Channel {
2459 id: channel_id,
2460 name: name.to_string(),
2461 owner_id: org_id.0,
2462 owner_is_user: false,
2463 },
2464 );
2465 Ok(channel_id)
2466 }
2467
2468 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2469 self.background.simulate_random_delay().await;
2470 Ok(self
2471 .channels
2472 .lock()
2473 .values()
2474 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2475 .cloned()
2476 .collect())
2477 }
2478
2479 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2480 self.background.simulate_random_delay().await;
2481 let channels = self.channels.lock();
2482 let memberships = self.channel_memberships.lock();
2483 Ok(channels
2484 .values()
2485 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2486 .cloned()
2487 .collect())
2488 }
2489
2490 async fn can_user_access_channel(
2491 &self,
2492 user_id: UserId,
2493 channel_id: ChannelId,
2494 ) -> Result<bool> {
2495 self.background.simulate_random_delay().await;
2496 Ok(self
2497 .channel_memberships
2498 .lock()
2499 .contains_key(&(channel_id, user_id)))
2500 }
2501
2502 async fn add_channel_member(
2503 &self,
2504 channel_id: ChannelId,
2505 user_id: UserId,
2506 is_admin: bool,
2507 ) -> Result<()> {
2508 self.background.simulate_random_delay().await;
2509 if !self.channels.lock().contains_key(&channel_id) {
2510 Err(anyhow!("channel does not exist"))?;
2511 }
2512 if !self.users.lock().contains_key(&user_id) {
2513 Err(anyhow!("user does not exist"))?;
2514 }
2515
2516 self.channel_memberships
2517 .lock()
2518 .entry((channel_id, user_id))
2519 .or_insert(is_admin);
2520 Ok(())
2521 }
2522
2523 async fn create_channel_message(
2524 &self,
2525 channel_id: ChannelId,
2526 sender_id: UserId,
2527 body: &str,
2528 timestamp: OffsetDateTime,
2529 nonce: u128,
2530 ) -> Result<MessageId> {
2531 self.background.simulate_random_delay().await;
2532 if !self.channels.lock().contains_key(&channel_id) {
2533 Err(anyhow!("channel does not exist"))?;
2534 }
2535 if !self.users.lock().contains_key(&sender_id) {
2536 Err(anyhow!("user does not exist"))?;
2537 }
2538
2539 let mut messages = self.channel_messages.lock();
2540 if let Some(message) = messages
2541 .values()
2542 .find(|message| message.nonce.as_u128() == nonce)
2543 {
2544 Ok(message.id)
2545 } else {
2546 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2547 messages.insert(
2548 message_id,
2549 ChannelMessage {
2550 id: message_id,
2551 channel_id,
2552 sender_id,
2553 body: body.to_string(),
2554 sent_at: timestamp,
2555 nonce: Uuid::from_u128(nonce),
2556 },
2557 );
2558 Ok(message_id)
2559 }
2560 }
2561
2562 async fn get_channel_messages(
2563 &self,
2564 channel_id: ChannelId,
2565 count: usize,
2566 before_id: Option<MessageId>,
2567 ) -> Result<Vec<ChannelMessage>> {
2568 self.background.simulate_random_delay().await;
2569 let mut messages = self
2570 .channel_messages
2571 .lock()
2572 .values()
2573 .rev()
2574 .filter(|message| {
2575 message.channel_id == channel_id
2576 && message.id < before_id.unwrap_or(MessageId::MAX)
2577 })
2578 .take(count)
2579 .cloned()
2580 .collect::<Vec<_>>();
2581 messages.sort_unstable_by_key(|message| message.id);
2582 Ok(messages)
2583 }
2584
2585 async fn teardown(&self, _: &str) {}
2586
/// Exposes this database as a [`FakeDb`]; always `Some` for the fake implementation.
#[cfg(test)]
fn as_fake(&self) -> Option<&FakeDb> {
    let fake: &FakeDb = self;
    Some(fake)
}
2591 }
2592
2593 pub struct TestDb {
2594 pub db: Option<Arc<dyn Db>>,
2595 pub url: String,
2596 }
2597
2598 impl TestDb {
2599 #[allow(clippy::await_holding_lock)]
2600 pub async fn real() -> Self {
2601 eprintln!("creating database...");
2602 let start = std::time::Instant::now();
2603 let mut rng = StdRng::from_entropy();
2604 let url = format!("/tmp/zed-test-{}", rng.gen::<u128>());
2605 Sqlite::create_database(&url).await.unwrap();
2606 let db = RealDb::new(&url, 5).await.unwrap();
2607 db.migrate(Path::new(TEST_MIGRATIONS_PATH.unwrap()), false)
2608 .await
2609 .unwrap();
2610
2611 eprintln!("created database: {:?}", start.elapsed());
2612 Self {
2613 db: Some(Arc::new(db)),
2614 url,
2615 }
2616 }
2617
2618 pub fn fake(background: Arc<Background>) -> Self {
2619 Self {
2620 db: Some(Arc::new(FakeDb::new(background))),
2621 url: Default::default(),
2622 }
2623 }
2624
2625 pub fn db(&self) -> &Arc<dyn Db> {
2626 self.db.as_ref().unwrap()
2627 }
2628 }
2629
2630 impl Drop for TestDb {
2631 fn drop(&mut self) {
2632 if let Some(db) = self.db.take() {
2633 std::fs::remove_file(&self.url).ok();
2634 }
2635 }
2636 }
2637}