1use crate::{Error, Result};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use axum::http::StatusCode;
5use collections::HashMap;
6use futures::StreamExt;
7use serde::{Deserialize, Serialize};
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9use sqlx::{
10 migrate::{Migrate as _, Migration, MigrationSource},
11 types::Uuid,
12 FromRow, QueryBuilder,
13};
14use std::{cmp, ops::Range, path::Path, time::Duration};
15use time::{OffsetDateTime, PrimitiveDateTime};
16
/// Storage interface covering users, invites and signups, projects and their
/// activity, contacts, access tokens, orgs and channels.
///
/// `PostgresDb` implements this trait on top of a Postgres connection pool.
/// The test-only `as_fake` accessor suggests a `FakeDb` implementation exists
/// as well (not visible here — confirm).
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user with the given email address and admin flag, taking the
    /// GitHub identity from `params`, and returns the new user's ids.
    async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult>;
    /// Returns one page of users; `page` is zero-based and `limit` is the page size.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Returns up to `limit` users whose GitHub login fuzzily matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Looks up a single user, returning `None` when no user has this id.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the metrics id of the given user as text.
    async fn get_user_metrics_id(&self, id: UserId) -> Result<String>;
    /// Returns the users whose ids appear in `ids`.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns users with an invite count of zero, restricted to those that
    /// were (`true`) or were not (`false`) invited by another user.
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    /// Looks up a user by GitHub login and, when available, GitHub user id.
    async fn get_user_by_github_account(
        &self,
        github_login: &str,
        github_user_id: Option<i32>,
    ) -> Result<Option<User>>;
    /// Sets or clears the admin flag of the given user.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets or clears the `connected_once` flag of the given user.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes the given user.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Sets the number of invites the given user may still send.
    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, or `None`
    /// when the user has no invite code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user that owns the given invite code.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Creates an invite for `email_address` on behalf of the owner of `code`.
    async fn create_invite_from_code(
        &self,
        code: &str,
        email_address: &str,
        device_id: Option<&str>,
    ) -> Result<Invite>;

    /// Records a new signup.
    async fn create_signup(&self, signup: Signup) -> Result<()>;
    /// Summarizes the signups that have not been sent a confirmation email yet.
    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
    /// Returns up to `count` invites that have not been sent yet.
    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
    /// Marks the given invites as sent.
    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
    /// Creates the user for a redeemed invite; yields `None` when the invite's
    /// signup is already linked to a user.
    async fn create_user_from_invite(
        &self,
        invite: &Invite,
        user: NewUserParams,
    ) -> Result<Option<NewUserResult>>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, u32>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_user_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the number of users who have been active in the given
    /// time period for at least the given time duration.
    async fn get_active_user_count(
        &self,
        time_period: Range<OffsetDateTime>,
        min_duration: Duration,
        only_collaborative: bool,
    ) -> Result<usize>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn get_top_users_activity_summary(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Get the project activity for the given user and time period.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>>;

    /// Returns the accepted, incoming and outgoing contacts of the given user.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users are accepted contacts of each other.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Records a contact request from `requester_id` to `responder_id`.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Removes the contact between the two users.
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Clears the pending notification for a contact.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    // NOTE(review): presumably accepts or declines the pending request
    // depending on `accept`; the implementation is not visible here.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    // NOTE(review): presumably stores a new token hash and keeps at most
    // `max_access_token_count` hashes per user — confirm against the impl.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    // NOTE(review): presumably returns the stored token hashes for the user.
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;

    // The org helpers below only exist in test builds or with the
    // `seed-support` feature enabled.
    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    #[cfg(any(test, feature = "seed-support"))]

    // NOTE(review): an attribute attaches to the next item regardless of blank
    // lines, so the `cfg` above gates `get_org_channels`. Confirm this is
    // intended and that implementations gate the method the same way.
    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;

    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;

    // Test-only hooks.
    #[cfg(test)]
    async fn teardown(&self, url: &str);

    #[cfg(test)]
    fn as_fake(&self) -> Option<&FakeDb>;
}
179
/// Default location of the migrations: the `migrations` directory inside this
/// crate's `CARGO_MANIFEST_DIR`, resolved at compile time. Only defined this
/// way for test builds and builds with debug assertions enabled.
#[cfg(any(test, debug_assertions))]
pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> =
    Some(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));

/// Builds without `test` and without debug assertions have no default
/// migrations path.
#[cfg(not(any(test, debug_assertions)))]
pub const DEFAULT_MIGRATIONS_PATH: Option<&'static str> = None;
186
/// [`Db`] implementation backed by a Postgres connection pool.
pub struct PostgresDb {
    // Pool that the queries of this type are executed against.
    pool: sqlx::PgPool,
}
190
191impl PostgresDb {
192 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
193 let pool = DbOptions::new()
194 .max_connections(max_connections)
195 .connect(url)
196 .await
197 .context("failed to connect to postgres database")?;
198 Ok(Self { pool })
199 }
200
201 pub async fn migrate(
202 &self,
203 migrations_path: &Path,
204 ignore_checksum_mismatch: bool,
205 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
206 let migrations = MigrationSource::resolve(migrations_path)
207 .await
208 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
209
210 let mut conn = self.pool.acquire().await?;
211
212 conn.ensure_migrations_table().await?;
213 let applied_migrations: HashMap<_, _> = conn
214 .list_applied_migrations()
215 .await?
216 .into_iter()
217 .map(|m| (m.version, m))
218 .collect();
219
220 let mut new_migrations = Vec::new();
221 for migration in migrations {
222 match applied_migrations.get(&migration.version) {
223 Some(applied_migration) => {
224 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
225 {
226 Err(anyhow!(
227 "checksum mismatch for applied migration {}",
228 migration.description
229 ))?;
230 }
231 }
232 None => {
233 let elapsed = conn.apply(&migration).await?;
234 new_migrations.push((migration, elapsed));
235 }
236 }
237 }
238
239 Ok(new_migrations)
240 }
241
242 pub fn fuzzy_like_string(string: &str) -> String {
243 let mut result = String::with_capacity(string.len() * 2 + 1);
244 for c in string.chars() {
245 if c.is_alphanumeric() {
246 result.push('%');
247 result.push(c);
248 }
249 }
250 result.push('%');
251 result
252 }
253}
254
255#[async_trait]
256impl Db for PostgresDb {
257 // users
258
259 async fn create_user(
260 &self,
261 email_address: &str,
262 admin: bool,
263 params: NewUserParams,
264 ) -> Result<NewUserResult> {
265 let query = "
266 INSERT INTO users (email_address, github_login, github_user_id, admin)
267 VALUES ($1, $2, $3, $4)
268 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
269 RETURNING id, metrics_id::text
270 ";
271 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
272 .bind(email_address)
273 .bind(params.github_login)
274 .bind(params.github_user_id)
275 .bind(admin)
276 .fetch_one(&self.pool)
277 .await?;
278 Ok(NewUserResult {
279 user_id,
280 metrics_id,
281 signup_device_id: None,
282 inviting_user_id: None,
283 })
284 }
285
286 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
287 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
288 Ok(sqlx::query_as(query)
289 .bind(limit as i32)
290 .bind((page * limit) as i32)
291 .fetch_all(&self.pool)
292 .await?)
293 }
294
295 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
296 let like_string = Self::fuzzy_like_string(name_query);
297 let query = "
298 SELECT users.*
299 FROM users
300 WHERE github_login ILIKE $1
301 ORDER BY github_login <-> $2
302 LIMIT $3
303 ";
304 Ok(sqlx::query_as(query)
305 .bind(like_string)
306 .bind(name_query)
307 .bind(limit as i32)
308 .fetch_all(&self.pool)
309 .await?)
310 }
311
312 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
313 let users = self.get_users_by_ids(vec![id]).await?;
314 Ok(users.into_iter().next())
315 }
316
317 async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
318 let query = "
319 SELECT metrics_id::text
320 FROM users
321 WHERE id = $1
322 ";
323 Ok(sqlx::query_scalar(query)
324 .bind(id)
325 .fetch_one(&self.pool)
326 .await?)
327 }
328
329 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
330 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
331 let query = "
332 SELECT users.*
333 FROM users
334 WHERE users.id = ANY ($1)
335 ";
336 Ok(sqlx::query_as(query)
337 .bind(&ids)
338 .fetch_all(&self.pool)
339 .await?)
340 }
341
342 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
343 let query = format!(
344 "
345 SELECT users.*
346 FROM users
347 WHERE invite_count = 0
348 AND inviter_id IS{} NULL
349 ",
350 if invited_by_another_user { " NOT" } else { "" }
351 );
352
353 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
354 }
355
356 async fn get_user_by_github_account(
357 &self,
358 github_login: &str,
359 github_user_id: Option<i32>,
360 ) -> Result<Option<User>> {
361 if let Some(github_user_id) = github_user_id {
362 let mut user = sqlx::query_as::<_, User>(
363 "
364 UPDATE users
365 SET github_login = $1
366 WHERE github_user_id = $2
367 RETURNING *
368 ",
369 )
370 .bind(github_login)
371 .bind(github_user_id)
372 .fetch_optional(&self.pool)
373 .await?;
374
375 if user.is_none() {
376 user = sqlx::query_as::<_, User>(
377 "
378 UPDATE users
379 SET github_user_id = $1
380 WHERE github_login = $2
381 RETURNING *
382 ",
383 )
384 .bind(github_user_id)
385 .bind(github_login)
386 .fetch_optional(&self.pool)
387 .await?;
388 }
389
390 Ok(user)
391 } else {
392 Ok(sqlx::query_as(
393 "
394 SELECT * FROM users
395 WHERE github_login = $1
396 LIMIT 1
397 ",
398 )
399 .bind(github_login)
400 .fetch_optional(&self.pool)
401 .await?)
402 }
403 }
404
405 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
406 let query = "UPDATE users SET admin = $1 WHERE id = $2";
407 Ok(sqlx::query(query)
408 .bind(is_admin)
409 .bind(id.0)
410 .execute(&self.pool)
411 .await
412 .map(drop)?)
413 }
414
415 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
416 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
417 Ok(sqlx::query(query)
418 .bind(connected_once)
419 .bind(id.0)
420 .execute(&self.pool)
421 .await
422 .map(drop)?)
423 }
424
425 async fn destroy_user(&self, id: UserId) -> Result<()> {
426 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
427 sqlx::query(query)
428 .bind(id.0)
429 .execute(&self.pool)
430 .await
431 .map(drop)?;
432 let query = "DELETE FROM users WHERE id = $1;";
433 Ok(sqlx::query(query)
434 .bind(id.0)
435 .execute(&self.pool)
436 .await
437 .map(drop)?)
438 }
439
440 // signups
441
442 async fn create_signup(&self, signup: Signup) -> Result<()> {
443 sqlx::query(
444 "
445 INSERT INTO signups
446 (
447 email_address,
448 email_confirmation_code,
449 email_confirmation_sent,
450 platform_linux,
451 platform_mac,
452 platform_windows,
453 platform_unknown,
454 editor_features,
455 programming_languages,
456 device_id
457 )
458 VALUES
459 ($1, $2, 'f', $3, $4, $5, 'f', $6, $7, $8)
460 RETURNING id
461 ",
462 )
463 .bind(&signup.email_address)
464 .bind(&random_email_confirmation_code())
465 .bind(&signup.platform_linux)
466 .bind(&signup.platform_mac)
467 .bind(&signup.platform_windows)
468 .bind(&signup.editor_features)
469 .bind(&signup.programming_languages)
470 .bind(&signup.device_id)
471 .execute(&self.pool)
472 .await?;
473 Ok(())
474 }
475
476 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
477 Ok(sqlx::query_as(
478 "
479 SELECT
480 COUNT(*) as count,
481 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
482 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
483 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
484 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
485 FROM (
486 SELECT *
487 FROM signups
488 WHERE
489 NOT email_confirmation_sent
490 ) AS unsent
491 ",
492 )
493 .fetch_one(&self.pool)
494 .await?)
495 }
496
497 async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
498 Ok(sqlx::query_as(
499 "
500 SELECT
501 email_address, email_confirmation_code
502 FROM signups
503 WHERE
504 NOT email_confirmation_sent AND
505 (platform_mac OR platform_unknown)
506 LIMIT $1
507 ",
508 )
509 .bind(count as i32)
510 .fetch_all(&self.pool)
511 .await?)
512 }
513
514 async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
515 sqlx::query(
516 "
517 UPDATE signups
518 SET email_confirmation_sent = 't'
519 WHERE email_address = ANY ($1)
520 ",
521 )
522 .bind(
523 &invites
524 .iter()
525 .map(|s| s.email_address.as_str())
526 .collect::<Vec<_>>(),
527 )
528 .execute(&self.pool)
529 .await?;
530 Ok(())
531 }
532
533 async fn create_user_from_invite(
534 &self,
535 invite: &Invite,
536 user: NewUserParams,
537 ) -> Result<Option<NewUserResult>> {
538 let mut tx = self.pool.begin().await?;
539
540 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
541 i32,
542 Option<UserId>,
543 Option<UserId>,
544 Option<String>,
545 ) = sqlx::query_as(
546 "
547 SELECT id, user_id, inviting_user_id, device_id
548 FROM signups
549 WHERE
550 email_address = $1 AND
551 email_confirmation_code = $2
552 ",
553 )
554 .bind(&invite.email_address)
555 .bind(&invite.email_confirmation_code)
556 .fetch_optional(&mut tx)
557 .await?
558 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
559
560 if existing_user_id.is_some() {
561 return Ok(None);
562 }
563
564 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
565 "
566 INSERT INTO users
567 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
568 VALUES
569 ($1, $2, $3, 'f', $4, $5)
570 RETURNING id, metrics_id::text
571 ",
572 )
573 .bind(&invite.email_address)
574 .bind(&user.github_login)
575 .bind(&user.github_user_id)
576 .bind(&user.invite_count)
577 .bind(random_invite_code())
578 .fetch_one(&mut tx)
579 .await?;
580
581 sqlx::query(
582 "
583 UPDATE signups
584 SET user_id = $1
585 WHERE id = $2
586 ",
587 )
588 .bind(&user_id)
589 .bind(&signup_id)
590 .execute(&mut tx)
591 .await?;
592
593 if let Some(inviting_user_id) = inviting_user_id {
594 let id: Option<UserId> = sqlx::query_scalar(
595 "
596 UPDATE users
597 SET invite_count = invite_count - 1
598 WHERE id = $1 AND invite_count > 0
599 RETURNING id
600 ",
601 )
602 .bind(&inviting_user_id)
603 .fetch_optional(&mut tx)
604 .await?;
605
606 if id.is_none() {
607 Err(Error::Http(
608 StatusCode::UNAUTHORIZED,
609 "no invites remaining".to_string(),
610 ))?;
611 }
612
613 sqlx::query(
614 "
615 INSERT INTO contacts
616 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
617 VALUES
618 ($1, $2, 't', 't', 't')
619 ",
620 )
621 .bind(inviting_user_id)
622 .bind(user_id)
623 .execute(&mut tx)
624 .await?;
625 }
626
627 tx.commit().await?;
628 Ok(Some(NewUserResult {
629 user_id,
630 metrics_id,
631 inviting_user_id,
632 signup_device_id,
633 }))
634 }
635
636 // invite codes
637
638 async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
639 let mut tx = self.pool.begin().await?;
640 if count > 0 {
641 sqlx::query(
642 "
643 UPDATE users
644 SET invite_code = $1
645 WHERE id = $2 AND invite_code IS NULL
646 ",
647 )
648 .bind(random_invite_code())
649 .bind(id)
650 .execute(&mut tx)
651 .await?;
652 }
653
654 sqlx::query(
655 "
656 UPDATE users
657 SET invite_count = $1
658 WHERE id = $2
659 ",
660 )
661 .bind(count as i32)
662 .bind(id)
663 .execute(&mut tx)
664 .await?;
665 tx.commit().await?;
666 Ok(())
667 }
668
669 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
670 let result: Option<(String, i32)> = sqlx::query_as(
671 "
672 SELECT invite_code, invite_count
673 FROM users
674 WHERE id = $1 AND invite_code IS NOT NULL
675 ",
676 )
677 .bind(id)
678 .fetch_optional(&self.pool)
679 .await?;
680 if let Some((code, count)) = result {
681 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
682 } else {
683 Ok(None)
684 }
685 }
686
687 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
688 sqlx::query_as(
689 "
690 SELECT *
691 FROM users
692 WHERE invite_code = $1
693 ",
694 )
695 .bind(code)
696 .fetch_optional(&self.pool)
697 .await?
698 .ok_or_else(|| {
699 Error::Http(
700 StatusCode::NOT_FOUND,
701 "that invite code does not exist".to_string(),
702 )
703 })
704 }
705
706 async fn create_invite_from_code(
707 &self,
708 code: &str,
709 email_address: &str,
710 device_id: Option<&str>,
711 ) -> Result<Invite> {
712 let mut tx = self.pool.begin().await?;
713
714 let existing_user: Option<UserId> = sqlx::query_scalar(
715 "
716 SELECT id
717 FROM users
718 WHERE email_address = $1
719 ",
720 )
721 .bind(email_address)
722 .fetch_optional(&mut tx)
723 .await?;
724 if existing_user.is_some() {
725 Err(anyhow!("email address is already in use"))?;
726 }
727
728 let row: Option<(UserId, i32)> = sqlx::query_as(
729 "
730 SELECT id, invite_count
731 FROM users
732 WHERE invite_code = $1
733 ",
734 )
735 .bind(code)
736 .fetch_optional(&mut tx)
737 .await?;
738
739 let (inviter_id, invite_count) = match row {
740 Some(row) => row,
741 None => Err(Error::Http(
742 StatusCode::NOT_FOUND,
743 "invite code not found".to_string(),
744 ))?,
745 };
746
747 if invite_count == 0 {
748 Err(Error::Http(
749 StatusCode::UNAUTHORIZED,
750 "no invites remaining".to_string(),
751 ))?;
752 }
753
754 let email_confirmation_code: String = sqlx::query_scalar(
755 "
756 INSERT INTO signups
757 (
758 email_address,
759 email_confirmation_code,
760 email_confirmation_sent,
761 inviting_user_id,
762 platform_linux,
763 platform_mac,
764 platform_windows,
765 platform_unknown,
766 device_id
767 )
768 VALUES
769 ($1, $2, 'f', $3, 'f', 'f', 'f', 't', $4)
770 ON CONFLICT (email_address)
771 DO UPDATE SET
772 inviting_user_id = excluded.inviting_user_id
773 RETURNING email_confirmation_code
774 ",
775 )
776 .bind(&email_address)
777 .bind(&random_email_confirmation_code())
778 .bind(&inviter_id)
779 .bind(&device_id)
780 .fetch_one(&mut tx)
781 .await?;
782
783 tx.commit().await?;
784
785 Ok(Invite {
786 email_address: email_address.into(),
787 email_confirmation_code,
788 })
789 }
790
791 // projects
792
793 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
794 Ok(sqlx::query_scalar(
795 "
796 INSERT INTO projects(host_user_id)
797 VALUES ($1)
798 RETURNING id
799 ",
800 )
801 .bind(host_user_id)
802 .fetch_one(&self.pool)
803 .await
804 .map(ProjectId)?)
805 }
806
807 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
808 sqlx::query(
809 "
810 UPDATE projects
811 SET unregistered = 't'
812 WHERE id = $1
813 ",
814 )
815 .bind(project_id)
816 .execute(&self.pool)
817 .await?;
818 Ok(())
819 }
820
821 async fn update_worktree_extensions(
822 &self,
823 project_id: ProjectId,
824 worktree_id: u64,
825 extensions: HashMap<String, u32>,
826 ) -> Result<()> {
827 if extensions.is_empty() {
828 return Ok(());
829 }
830
831 let mut query = QueryBuilder::new(
832 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
833 );
834 query.push_values(extensions, |mut query, (extension, count)| {
835 query
836 .push_bind(project_id)
837 .push_bind(worktree_id as i32)
838 .push_bind(extension)
839 .push_bind(count as i32);
840 });
841 query.push(
842 "
843 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
844 count = excluded.count
845 ",
846 );
847 query.build().execute(&self.pool).await?;
848
849 Ok(())
850 }
851
852 async fn get_project_extensions(
853 &self,
854 project_id: ProjectId,
855 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
856 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
857 struct WorktreeExtension {
858 worktree_id: i32,
859 extension: String,
860 count: i32,
861 }
862
863 let query = "
864 SELECT worktree_id, extension, count
865 FROM worktree_extensions
866 WHERE project_id = $1
867 ";
868 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
869 .bind(&project_id)
870 .fetch_all(&self.pool)
871 .await?;
872
873 let mut extension_counts = HashMap::default();
874 for count in counts {
875 extension_counts
876 .entry(count.worktree_id as u64)
877 .or_insert_with(HashMap::default)
878 .insert(count.extension, count.count as usize);
879 }
880 Ok(extension_counts)
881 }
882
883 async fn record_user_activity(
884 &self,
885 time_period: Range<OffsetDateTime>,
886 projects: &[(UserId, ProjectId)],
887 ) -> Result<()> {
888 let query = "
889 INSERT INTO project_activity_periods
890 (ended_at, duration_millis, user_id, project_id)
891 VALUES
892 ($1, $2, $3, $4);
893 ";
894
895 let mut tx = self.pool.begin().await?;
896 let duration_millis =
897 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
898 for (user_id, project_id) in projects {
899 sqlx::query(query)
900 .bind(time_period.end)
901 .bind(duration_millis)
902 .bind(user_id)
903 .bind(project_id)
904 .execute(&mut tx)
905 .await?;
906 }
907 tx.commit().await?;
908
909 Ok(())
910 }
911
912 async fn get_active_user_count(
913 &self,
914 time_period: Range<OffsetDateTime>,
915 min_duration: Duration,
916 only_collaborative: bool,
917 ) -> Result<usize> {
918 let mut with_clause = String::new();
919 with_clause.push_str("WITH\n");
920 with_clause.push_str(
921 "
922 project_durations AS (
923 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
924 FROM project_activity_periods
925 WHERE $1 < ended_at AND ended_at <= $2
926 GROUP BY user_id, project_id
927 ),
928 ",
929 );
930 with_clause.push_str(
931 "
932 project_collaborators as (
933 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
934 FROM project_durations
935 GROUP BY project_id
936 ),
937 ",
938 );
939
940 if only_collaborative {
941 with_clause.push_str(
942 "
943 user_durations AS (
944 SELECT user_id, SUM(project_duration) as total_duration
945 FROM project_durations, project_collaborators
946 WHERE
947 project_durations.project_id = project_collaborators.project_id AND
948 max_collaborators > 1
949 GROUP BY user_id
950 ORDER BY total_duration DESC
951 LIMIT $3
952 )
953 ",
954 );
955 } else {
956 with_clause.push_str(
957 "
958 user_durations AS (
959 SELECT user_id, SUM(project_duration) as total_duration
960 FROM project_durations
961 GROUP BY user_id
962 ORDER BY total_duration DESC
963 LIMIT $3
964 )
965 ",
966 );
967 }
968
969 let query = format!(
970 "
971 {with_clause}
972 SELECT count(user_durations.user_id)
973 FROM user_durations
974 WHERE user_durations.total_duration >= $3
975 "
976 );
977
978 let count: i64 = sqlx::query_scalar(&query)
979 .bind(time_period.start)
980 .bind(time_period.end)
981 .bind(min_duration.as_millis() as i64)
982 .fetch_one(&self.pool)
983 .await?;
984 Ok(count as usize)
985 }
986
987 async fn get_top_users_activity_summary(
988 &self,
989 time_period: Range<OffsetDateTime>,
990 max_user_count: usize,
991 ) -> Result<Vec<UserActivitySummary>> {
992 let query = "
993 WITH
994 project_durations AS (
995 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
996 FROM project_activity_periods
997 WHERE $1 < ended_at AND ended_at <= $2
998 GROUP BY user_id, project_id
999 ),
1000 user_durations AS (
1001 SELECT user_id, SUM(project_duration) as total_duration
1002 FROM project_durations
1003 GROUP BY user_id
1004 ORDER BY total_duration DESC
1005 LIMIT $3
1006 ),
1007 project_collaborators as (
1008 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
1009 FROM project_durations
1010 GROUP BY project_id
1011 )
1012 SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
1013 FROM user_durations, project_durations, project_collaborators, users
1014 WHERE
1015 user_durations.user_id = project_durations.user_id AND
1016 user_durations.user_id = users.id AND
1017 project_durations.project_id = project_collaborators.project_id
1018 ORDER BY total_duration DESC, user_id ASC, project_id ASC
1019 ";
1020
1021 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
1022 .bind(time_period.start)
1023 .bind(time_period.end)
1024 .bind(max_user_count as i32)
1025 .fetch(&self.pool);
1026
1027 let mut result = Vec::<UserActivitySummary>::new();
1028 while let Some(row) = rows.next().await {
1029 let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
1030 let project_id = project_id;
1031 let duration = Duration::from_millis(duration_millis as u64);
1032 let project_activity = ProjectActivitySummary {
1033 id: project_id,
1034 duration,
1035 max_collaborators: project_collaborators as usize,
1036 };
1037 if let Some(last_summary) = result.last_mut() {
1038 if last_summary.id == user_id {
1039 last_summary.project_activity.push(project_activity);
1040 continue;
1041 }
1042 }
1043 result.push(UserActivitySummary {
1044 id: user_id,
1045 project_activity: vec![project_activity],
1046 github_login,
1047 });
1048 }
1049
1050 Ok(result)
1051 }
1052
1053 async fn get_user_activity_timeline(
1054 &self,
1055 time_period: Range<OffsetDateTime>,
1056 user_id: UserId,
1057 ) -> Result<Vec<UserActivityPeriod>> {
1058 const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
1059
1060 let query = "
1061 SELECT
1062 project_activity_periods.ended_at,
1063 project_activity_periods.duration_millis,
1064 project_activity_periods.project_id,
1065 worktree_extensions.extension,
1066 worktree_extensions.count
1067 FROM project_activity_periods
1068 LEFT OUTER JOIN
1069 worktree_extensions
1070 ON
1071 project_activity_periods.project_id = worktree_extensions.project_id
1072 WHERE
1073 project_activity_periods.user_id = $1 AND
1074 $2 < project_activity_periods.ended_at AND
1075 project_activity_periods.ended_at <= $3
1076 ORDER BY project_activity_periods.id ASC
1077 ";
1078
1079 let mut rows = sqlx::query_as::<
1080 _,
1081 (
1082 PrimitiveDateTime,
1083 i32,
1084 ProjectId,
1085 Option<String>,
1086 Option<i32>,
1087 ),
1088 >(query)
1089 .bind(user_id)
1090 .bind(time_period.start)
1091 .bind(time_period.end)
1092 .fetch(&self.pool);
1093
1094 let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
1095 while let Some(row) = rows.next().await {
1096 let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
1097 let ended_at = ended_at.assume_utc();
1098 let duration = Duration::from_millis(duration_millis as u64);
1099 let started_at = ended_at - duration;
1100 let project_time_periods = time_periods.entry(project_id).or_default();
1101
1102 if let Some(prev_duration) = project_time_periods.last_mut() {
1103 if started_at <= prev_duration.end + COALESCE_THRESHOLD
1104 && ended_at >= prev_duration.start
1105 {
1106 prev_duration.end = cmp::max(prev_duration.end, ended_at);
1107 } else {
1108 project_time_periods.push(UserActivityPeriod {
1109 project_id,
1110 start: started_at,
1111 end: ended_at,
1112 extensions: Default::default(),
1113 });
1114 }
1115 } else {
1116 project_time_periods.push(UserActivityPeriod {
1117 project_id,
1118 start: started_at,
1119 end: ended_at,
1120 extensions: Default::default(),
1121 });
1122 }
1123
1124 if let Some((extension, extension_count)) = extension.zip(extension_count) {
1125 project_time_periods
1126 .last_mut()
1127 .unwrap()
1128 .extensions
1129 .insert(extension, extension_count as usize);
1130 }
1131 }
1132
1133 let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
1134 durations.sort_unstable_by_key(|duration| duration.start);
1135 Ok(durations)
1136 }
1137
1138 // contacts
1139
1140 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1141 let query = "
1142 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1143 FROM contacts
1144 WHERE user_id_a = $1 OR user_id_b = $1;
1145 ";
1146
1147 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1148 .bind(user_id)
1149 .fetch(&self.pool);
1150
1151 let mut contacts = Vec::new();
1152 while let Some(row) = rows.next().await {
1153 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1154
1155 if user_id_a == user_id {
1156 if accepted {
1157 contacts.push(Contact::Accepted {
1158 user_id: user_id_b,
1159 should_notify: should_notify && a_to_b,
1160 });
1161 } else if a_to_b {
1162 contacts.push(Contact::Outgoing { user_id: user_id_b })
1163 } else {
1164 contacts.push(Contact::Incoming {
1165 user_id: user_id_b,
1166 should_notify,
1167 });
1168 }
1169 } else if accepted {
1170 contacts.push(Contact::Accepted {
1171 user_id: user_id_a,
1172 should_notify: should_notify && !a_to_b,
1173 });
1174 } else if a_to_b {
1175 contacts.push(Contact::Incoming {
1176 user_id: user_id_a,
1177 should_notify,
1178 });
1179 } else {
1180 contacts.push(Contact::Outgoing { user_id: user_id_a });
1181 }
1182 }
1183
1184 contacts.sort_unstable_by_key(|contact| contact.user_id());
1185
1186 Ok(contacts)
1187 }
1188
1189 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1190 let (id_a, id_b) = if user_id_1 < user_id_2 {
1191 (user_id_1, user_id_2)
1192 } else {
1193 (user_id_2, user_id_1)
1194 };
1195
1196 let query = "
1197 SELECT 1 FROM contacts
1198 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
1199 LIMIT 1
1200 ";
1201 Ok(sqlx::query_scalar::<_, i32>(query)
1202 .bind(id_a.0)
1203 .bind(id_b.0)
1204 .fetch_optional(&self.pool)
1205 .await?
1206 .is_some())
1207 }
1208
1209 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1210 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1211 (sender_id, receiver_id, true)
1212 } else {
1213 (receiver_id, sender_id, false)
1214 };
1215 let query = "
1216 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1217 VALUES ($1, $2, $3, 'f', 't')
1218 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1219 SET
1220 accepted = 't',
1221 should_notify = 'f'
1222 WHERE
1223 NOT contacts.accepted AND
1224 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1225 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1226 ";
1227 let result = sqlx::query(query)
1228 .bind(id_a.0)
1229 .bind(id_b.0)
1230 .bind(a_to_b)
1231 .execute(&self.pool)
1232 .await?;
1233
1234 if result.rows_affected() == 1 {
1235 Ok(())
1236 } else {
1237 Err(anyhow!("contact already requested"))?
1238 }
1239 }
1240
1241 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1242 let (id_a, id_b) = if responder_id < requester_id {
1243 (responder_id, requester_id)
1244 } else {
1245 (requester_id, responder_id)
1246 };
1247 let query = "
1248 DELETE FROM contacts
1249 WHERE user_id_a = $1 AND user_id_b = $2;
1250 ";
1251 let result = sqlx::query(query)
1252 .bind(id_a.0)
1253 .bind(id_b.0)
1254 .execute(&self.pool)
1255 .await?;
1256
1257 if result.rows_affected() == 1 {
1258 Ok(())
1259 } else {
1260 Err(anyhow!("no such contact"))?
1261 }
1262 }
1263
1264 async fn dismiss_contact_notification(
1265 &self,
1266 user_id: UserId,
1267 contact_user_id: UserId,
1268 ) -> Result<()> {
1269 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1270 (user_id, contact_user_id, true)
1271 } else {
1272 (contact_user_id, user_id, false)
1273 };
1274
1275 let query = "
1276 UPDATE contacts
1277 SET should_notify = 'f'
1278 WHERE
1279 user_id_a = $1 AND user_id_b = $2 AND
1280 (
1281 (a_to_b = $3 AND accepted) OR
1282 (a_to_b != $3 AND NOT accepted)
1283 );
1284 ";
1285
1286 let result = sqlx::query(query)
1287 .bind(id_a.0)
1288 .bind(id_b.0)
1289 .bind(a_to_b)
1290 .execute(&self.pool)
1291 .await?;
1292
1293 if result.rows_affected() == 0 {
1294 Err(anyhow!("no such contact request"))?;
1295 }
1296
1297 Ok(())
1298 }
1299
1300 async fn respond_to_contact_request(
1301 &self,
1302 responder_id: UserId,
1303 requester_id: UserId,
1304 accept: bool,
1305 ) -> Result<()> {
1306 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1307 (responder_id, requester_id, false)
1308 } else {
1309 (requester_id, responder_id, true)
1310 };
1311 let result = if accept {
1312 let query = "
1313 UPDATE contacts
1314 SET accepted = 't', should_notify = 't'
1315 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1316 ";
1317 sqlx::query(query)
1318 .bind(id_a.0)
1319 .bind(id_b.0)
1320 .bind(a_to_b)
1321 .execute(&self.pool)
1322 .await?
1323 } else {
1324 let query = "
1325 DELETE FROM contacts
1326 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1327 ";
1328 sqlx::query(query)
1329 .bind(id_a.0)
1330 .bind(id_b.0)
1331 .bind(a_to_b)
1332 .execute(&self.pool)
1333 .await?
1334 };
1335 if result.rows_affected() == 1 {
1336 Ok(())
1337 } else {
1338 Err(anyhow!("no such contact request"))?
1339 }
1340 }
1341
1342 // access tokens
1343
1344 async fn create_access_token_hash(
1345 &self,
1346 user_id: UserId,
1347 access_token_hash: &str,
1348 max_access_token_count: usize,
1349 ) -> Result<()> {
1350 let insert_query = "
1351 INSERT INTO access_tokens (user_id, hash)
1352 VALUES ($1, $2);
1353 ";
1354 let cleanup_query = "
1355 DELETE FROM access_tokens
1356 WHERE id IN (
1357 SELECT id from access_tokens
1358 WHERE user_id = $1
1359 ORDER BY id DESC
1360 OFFSET $3
1361 )
1362 ";
1363
1364 let mut tx = self.pool.begin().await?;
1365 sqlx::query(insert_query)
1366 .bind(user_id.0)
1367 .bind(access_token_hash)
1368 .execute(&mut tx)
1369 .await?;
1370 sqlx::query(cleanup_query)
1371 .bind(user_id.0)
1372 .bind(access_token_hash)
1373 .bind(max_access_token_count as i32)
1374 .execute(&mut tx)
1375 .await?;
1376 Ok(tx.commit().await?)
1377 }
1378
1379 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1380 let query = "
1381 SELECT hash
1382 FROM access_tokens
1383 WHERE user_id = $1
1384 ORDER BY id DESC
1385 ";
1386 Ok(sqlx::query_scalar(query)
1387 .bind(user_id.0)
1388 .fetch_all(&self.pool)
1389 .await?)
1390 }
1391
1392 // orgs
1393
/// Looks up the org whose `slug` column equals `slug`, if there is one.
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
    let query = "
        SELECT *
        FROM orgs
        WHERE slug = $1
    ";
    let org = sqlx::query_as::<_, Org>(query)
        .bind(slug)
        .fetch_optional(&self.pool)
        .await?;
    Ok(org)
}
1407
/// Inserts a new org row and returns the id generated for it.
#[cfg(any(test, feature = "seed-support"))]
async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
    let query = "
        INSERT INTO orgs (name, slug)
        VALUES ($1, $2)
        RETURNING id
    ";
    let id: i32 = sqlx::query_scalar(query)
        .bind(name)
        .bind(slug)
        .fetch_one(&self.pool)
        .await?;
    Ok(OrgId(id))
}
1422
/// Adds `user_id` to `org_id` with the given admin flag.
///
/// An insert that conflicts with an existing row is silently ignored
/// (`ON CONFLICT DO NOTHING`).
#[cfg(any(test, feature = "seed-support"))]
async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
    let query = "
        INSERT INTO org_memberships (org_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
    ";
    sqlx::query(query)
        .bind(org_id.0)
        .bind(user_id.0)
        .bind(is_admin)
        .execute(&self.pool)
        .await?;
    Ok(())
}
1438
1439 // channels
1440
/// Creates a channel owned by the org `org_id` (the row is written with
/// `owner_is_user = false`) and returns the new channel's id.
#[cfg(any(test, feature = "seed-support"))]
async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
    let query = "
        INSERT INTO channels (owner_id, owner_is_user, name)
        VALUES ($1, false, $2)
        RETURNING id
    ";
    let id: i32 = sqlx::query_scalar(query)
        .bind(org_id.0)
        .bind(name)
        .fetch_one(&self.pool)
        .await?;
    Ok(ChannelId(id))
}
1455
/// Returns every channel whose owner is the org `org_id` (rows with
/// `owner_is_user = false`).
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
    let query = "
        SELECT *
        FROM channels
        WHERE
            channels.owner_is_user = false AND
            channels.owner_id = $1
    ";
    let channels = sqlx::query_as::<_, Channel>(query)
        .bind(org_id.0)
        .fetch_all(&self.pool)
        .await?;
    Ok(channels)
}
1471
1472 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1473 let query = "
1474 SELECT
1475 channels.*
1476 FROM
1477 channel_memberships, channels
1478 WHERE
1479 channel_memberships.user_id = $1 AND
1480 channel_memberships.channel_id = channels.id
1481 ";
1482 Ok(sqlx::query_as(query)
1483 .bind(user_id.0)
1484 .fetch_all(&self.pool)
1485 .await?)
1486 }
1487
1488 async fn can_user_access_channel(
1489 &self,
1490 user_id: UserId,
1491 channel_id: ChannelId,
1492 ) -> Result<bool> {
1493 let query = "
1494 SELECT id
1495 FROM channel_memberships
1496 WHERE user_id = $1 AND channel_id = $2
1497 LIMIT 1
1498 ";
1499 Ok(sqlx::query_scalar::<_, i32>(query)
1500 .bind(user_id.0)
1501 .bind(channel_id.0)
1502 .fetch_optional(&self.pool)
1503 .await
1504 .map(|e| e.is_some())?)
1505 }
1506
/// Adds `user_id` to `channel_id` with the given admin flag.
///
/// An insert that conflicts with an existing row is silently ignored
/// (`ON CONFLICT DO NOTHING`).
#[cfg(any(test, feature = "seed-support"))]
async fn add_channel_member(
    &self,
    channel_id: ChannelId,
    user_id: UserId,
    is_admin: bool,
) -> Result<()> {
    let query = "
        INSERT INTO channel_memberships (channel_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
    ";
    sqlx::query(query)
        .bind(channel_id.0)
        .bind(user_id.0)
        .bind(is_admin)
        .execute(&self.pool)
        .await?;
    Ok(())
}
1527
1528 // messages
1529
1530 async fn create_channel_message(
1531 &self,
1532 channel_id: ChannelId,
1533 sender_id: UserId,
1534 body: &str,
1535 timestamp: OffsetDateTime,
1536 nonce: u128,
1537 ) -> Result<MessageId> {
1538 let query = "
1539 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1540 VALUES ($1, $2, $3, $4, $5)
1541 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1542 RETURNING id
1543 ";
1544 Ok(sqlx::query_scalar(query)
1545 .bind(channel_id.0)
1546 .bind(sender_id.0)
1547 .bind(body)
1548 .bind(timestamp)
1549 .bind(Uuid::from_u128(nonce))
1550 .fetch_one(&self.pool)
1551 .await
1552 .map(MessageId)?)
1553 }
1554
1555 async fn get_channel_messages(
1556 &self,
1557 channel_id: ChannelId,
1558 count: usize,
1559 before_id: Option<MessageId>,
1560 ) -> Result<Vec<ChannelMessage>> {
1561 let query = r#"
1562 SELECT * FROM (
1563 SELECT
1564 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1565 FROM
1566 channel_messages
1567 WHERE
1568 channel_id = $1 AND
1569 id < $2
1570 ORDER BY id DESC
1571 LIMIT $3
1572 ) as recent_messages
1573 ORDER BY id ASC
1574 "#;
1575 Ok(sqlx::query_as(query)
1576 .bind(channel_id.0)
1577 .bind(before_id.unwrap_or(MessageId::MAX))
1578 .bind(count as i64)
1579 .fetch_all(&self.pool)
1580 .await?)
1581 }
1582
/// Test-only cleanup: terminates the other backends connected to the current
/// database, closes this pool, and then drops the database at `url`.
///
/// Failures of the terminate query and of the drop are passed to `log_err` rather
/// than propagated.
#[cfg(test)]
async fn teardown(&self, url: &str) {
    use sqlx::migrate::MigrateDatabase;
    use util::ResultExt;

    let terminate_other_backends = "
        SELECT pg_terminate_backend(pg_stat_activity.pid)
        FROM pg_stat_activity
        WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
    ";
    sqlx::query(terminate_other_backends)
        .execute(&self.pool)
        .await
        .log_err();
    self.pool.close().await;
    <sqlx::Postgres as MigrateDatabase>::drop_database(url)
        .await
        .log_err();
}
1598
/// Test-only downcast hook. This implementation always returns `None`; `FakeDb`
/// is the implementation that returns itself.
#[cfg(test)]
fn as_fake(&self) -> Option<&FakeDb> {
    Option::None
}
1603}
1604
// Declares a transparent `i32` newtype named `$name` for use as a database id.
//
// The generated type is `Copy`, ordered and hashable, maps transparently to its
// inner integer for both sqlx and serde, and displays as the bare number.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Converts a `u64` into an id.
            ///
            /// NOTE: this is a plain `as` cast, so values above `i32::MAX` are
            /// truncated rather than rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts the id into a `u64`.
            ///
            /// NOTE: this is a plain `as` cast, so a negative id becomes a very
            /// large `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1647
1648id_type!(UserId);
1649#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1650pub struct User {
1651 pub id: UserId,
1652 pub github_login: String,
1653 pub github_user_id: Option<i32>,
1654 pub email_address: Option<String>,
1655 pub admin: bool,
1656 pub invite_code: Option<String>,
1657 pub invite_count: i32,
1658 pub connected_once: bool,
1659}
1660
1661id_type!(ProjectId);
1662#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1663pub struct Project {
1664 pub id: ProjectId,
1665 pub host_user_id: UserId,
1666 pub unregistered: bool,
1667}
1668
1669#[derive(Clone, Debug, PartialEq, Serialize)]
1670pub struct UserActivitySummary {
1671 pub id: UserId,
1672 pub github_login: String,
1673 pub project_activity: Vec<ProjectActivitySummary>,
1674}
1675
1676#[derive(Clone, Debug, PartialEq, Serialize)]
1677pub struct ProjectActivitySummary {
1678 pub id: ProjectId,
1679 pub duration: Duration,
1680 pub max_collaborators: usize,
1681}
1682
1683#[derive(Clone, Debug, PartialEq, Serialize)]
1684pub struct UserActivityPeriod {
1685 pub project_id: ProjectId,
1686 #[serde(with = "time::serde::iso8601")]
1687 pub start: OffsetDateTime,
1688 #[serde(with = "time::serde::iso8601")]
1689 pub end: OffsetDateTime,
1690 pub extensions: HashMap<String, usize>,
1691}
1692
1693id_type!(OrgId);
1694#[derive(FromRow)]
1695pub struct Org {
1696 pub id: OrgId,
1697 pub name: String,
1698 pub slug: String,
1699}
1700
1701id_type!(ChannelId);
1702#[derive(Clone, Debug, FromRow, Serialize)]
1703pub struct Channel {
1704 pub id: ChannelId,
1705 pub name: String,
1706 pub owner_id: i32,
1707 pub owner_is_user: bool,
1708}
1709
1710id_type!(MessageId);
1711#[derive(Clone, Debug, FromRow)]
1712pub struct ChannelMessage {
1713 pub id: MessageId,
1714 pub channel_id: ChannelId,
1715 pub sender_id: UserId,
1716 pub body: String,
1717 pub sent_at: OffsetDateTime,
1718 pub nonce: Uuid,
1719}
1720
1721#[derive(Clone, Debug, PartialEq, Eq)]
1722pub enum Contact {
1723 Accepted {
1724 user_id: UserId,
1725 should_notify: bool,
1726 },
1727 Outgoing {
1728 user_id: UserId,
1729 },
1730 Incoming {
1731 user_id: UserId,
1732 should_notify: bool,
1733 },
1734}
1735
1736impl Contact {
1737 pub fn user_id(&self) -> UserId {
1738 match self {
1739 Contact::Accepted { user_id, .. } => *user_id,
1740 Contact::Outgoing { user_id } => *user_id,
1741 Contact::Incoming { user_id, .. } => *user_id,
1742 }
1743 }
1744}
1745
1746#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
1747pub struct IncomingContactRequest {
1748 pub requester_id: UserId,
1749 pub should_notify: bool,
1750}
1751
1752#[derive(Clone, Deserialize)]
1753pub struct Signup {
1754 pub email_address: String,
1755 pub platform_mac: bool,
1756 pub platform_windows: bool,
1757 pub platform_linux: bool,
1758 pub editor_features: Vec<String>,
1759 pub programming_languages: Vec<String>,
1760 pub device_id: Option<String>,
1761}
1762
1763#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
1764pub struct WaitlistSummary {
1765 #[sqlx(default)]
1766 pub count: i64,
1767 #[sqlx(default)]
1768 pub linux_count: i64,
1769 #[sqlx(default)]
1770 pub mac_count: i64,
1771 #[sqlx(default)]
1772 pub windows_count: i64,
1773 #[sqlx(default)]
1774 pub unknown_count: i64,
1775}
1776
1777#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
1778pub struct Invite {
1779 pub email_address: String,
1780 pub email_confirmation_code: String,
1781}
1782
1783#[derive(Debug, Serialize, Deserialize)]
1784pub struct NewUserParams {
1785 pub github_login: String,
1786 pub github_user_id: i32,
1787 pub invite_count: i32,
1788}
1789
1790#[derive(Debug)]
1791pub struct NewUserResult {
1792 pub user_id: UserId,
1793 pub metrics_id: String,
1794 pub inviting_user_id: Option<UserId>,
1795 pub signup_device_id: Option<String>,
1796}
1797
1798fn random_invite_code() -> String {
1799 nanoid::nanoid!(16)
1800}
1801
1802fn random_email_confirmation_code() -> String {
1803 nanoid::nanoid!(64)
1804}
1805
1806#[cfg(test)]
1807pub use test::*;
1808
1809#[cfg(test)]
1810mod test {
1811 use super::*;
1812 use anyhow::anyhow;
1813 use collections::BTreeMap;
1814 use gpui::executor::Background;
1815 use lazy_static::lazy_static;
1816 use parking_lot::Mutex;
1817 use rand::prelude::*;
1818 use sqlx::{migrate::MigrateDatabase, Postgres};
1819 use std::sync::Arc;
1820 use util::post_inc;
1821
1822 pub struct FakeDb {
1823 background: Arc<Background>,
1824 pub users: Mutex<BTreeMap<UserId, User>>,
1825 pub projects: Mutex<BTreeMap<ProjectId, Project>>,
1826 pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
1827 pub orgs: Mutex<BTreeMap<OrgId, Org>>,
1828 pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
1829 pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
1830 pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
1831 pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
1832 pub contacts: Mutex<Vec<FakeContact>>,
1833 next_channel_message_id: Mutex<i32>,
1834 next_user_id: Mutex<i32>,
1835 next_org_id: Mutex<i32>,
1836 next_channel_id: Mutex<i32>,
1837 next_project_id: Mutex<i32>,
1838 }
1839
1840 #[derive(Debug)]
1841 pub struct FakeContact {
1842 pub requester_id: UserId,
1843 pub responder_id: UserId,
1844 pub accepted: bool,
1845 pub should_notify: bool,
1846 }
1847
1848 impl FakeDb {
1849 pub fn new(background: Arc<Background>) -> Self {
1850 Self {
1851 background,
1852 users: Default::default(),
1853 next_user_id: Mutex::new(0),
1854 projects: Default::default(),
1855 worktree_extensions: Default::default(),
1856 next_project_id: Mutex::new(1),
1857 orgs: Default::default(),
1858 next_org_id: Mutex::new(1),
1859 org_memberships: Default::default(),
1860 channels: Default::default(),
1861 next_channel_id: Mutex::new(1),
1862 channel_memberships: Default::default(),
1863 channel_messages: Default::default(),
1864 next_channel_message_id: Mutex::new(1),
1865 contacts: Default::default(),
1866 }
1867 }
1868 }
1869
1870 #[async_trait]
1871 impl Db for FakeDb {
1872 async fn create_user(
1873 &self,
1874 email_address: &str,
1875 admin: bool,
1876 params: NewUserParams,
1877 ) -> Result<NewUserResult> {
1878 self.background.simulate_random_delay().await;
1879
1880 let mut users = self.users.lock();
1881 let user_id = if let Some(user) = users
1882 .values()
1883 .find(|user| user.github_login == params.github_login)
1884 {
1885 user.id
1886 } else {
1887 let id = post_inc(&mut *self.next_user_id.lock());
1888 let user_id = UserId(id);
1889 users.insert(
1890 user_id,
1891 User {
1892 id: user_id,
1893 github_login: params.github_login,
1894 github_user_id: Some(params.github_user_id),
1895 email_address: Some(email_address.to_string()),
1896 admin,
1897 invite_code: None,
1898 invite_count: 0,
1899 connected_once: false,
1900 },
1901 );
1902 user_id
1903 };
1904 Ok(NewUserResult {
1905 user_id,
1906 metrics_id: "the-metrics-id".to_string(),
1907 inviting_user_id: None,
1908 signup_device_id: None,
1909 })
1910 }
1911
1912 async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
1913 unimplemented!()
1914 }
1915
1916 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1917 unimplemented!()
1918 }
1919
1920 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1921 self.background.simulate_random_delay().await;
1922 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1923 }
1924
1925 async fn get_user_metrics_id(&self, _id: UserId) -> Result<String> {
1926 Ok("the-metrics-id".to_string())
1927 }
1928
1929 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1930 self.background.simulate_random_delay().await;
1931 let users = self.users.lock();
1932 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1933 }
1934
1935 async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
1936 unimplemented!()
1937 }
1938
1939 async fn get_user_by_github_account(
1940 &self,
1941 github_login: &str,
1942 github_user_id: Option<i32>,
1943 ) -> Result<Option<User>> {
1944 self.background.simulate_random_delay().await;
1945 if let Some(github_user_id) = github_user_id {
1946 for user in self.users.lock().values_mut() {
1947 if user.github_user_id == Some(github_user_id) {
1948 user.github_login = github_login.into();
1949 return Ok(Some(user.clone()));
1950 }
1951 if user.github_login == github_login {
1952 user.github_user_id = Some(github_user_id);
1953 return Ok(Some(user.clone()));
1954 }
1955 }
1956 Ok(None)
1957 } else {
1958 Ok(self
1959 .users
1960 .lock()
1961 .values()
1962 .find(|user| user.github_login == github_login)
1963 .cloned())
1964 }
1965 }
1966
1967 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1968 unimplemented!()
1969 }
1970
1971 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1972 self.background.simulate_random_delay().await;
1973 let mut users = self.users.lock();
1974 let mut user = users
1975 .get_mut(&id)
1976 .ok_or_else(|| anyhow!("user not found"))?;
1977 user.connected_once = connected_once;
1978 Ok(())
1979 }
1980
1981 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1982 unimplemented!()
1983 }
1984
1985 // signups
1986
1987 async fn create_signup(&self, _signup: Signup) -> Result<()> {
1988 unimplemented!()
1989 }
1990
1991 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
1992 unimplemented!()
1993 }
1994
1995 async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<Invite>> {
1996 unimplemented!()
1997 }
1998
1999 async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
2000 unimplemented!()
2001 }
2002
2003 async fn create_user_from_invite(
2004 &self,
2005 _invite: &Invite,
2006 _user: NewUserParams,
2007 ) -> Result<Option<NewUserResult>> {
2008 unimplemented!()
2009 }
2010
2011 // invite codes
2012
2013 async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
2014 unimplemented!()
2015 }
2016
2017 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
2018 self.background.simulate_random_delay().await;
2019 Ok(None)
2020 }
2021
2022 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
2023 unimplemented!()
2024 }
2025
2026 async fn create_invite_from_code(
2027 &self,
2028 _code: &str,
2029 _email_address: &str,
2030 _device_id: Option<&str>,
2031 ) -> Result<Invite> {
2032 unimplemented!()
2033 }
2034
2035 // projects
2036
2037 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2038 self.background.simulate_random_delay().await;
2039 if !self.users.lock().contains_key(&host_user_id) {
2040 Err(anyhow!("no such user"))?;
2041 }
2042
2043 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2044 self.projects.lock().insert(
2045 project_id,
2046 Project {
2047 id: project_id,
2048 host_user_id,
2049 unregistered: false,
2050 },
2051 );
2052 Ok(project_id)
2053 }
2054
2055 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2056 self.background.simulate_random_delay().await;
2057 self.projects
2058 .lock()
2059 .get_mut(&project_id)
2060 .ok_or_else(|| anyhow!("no such project"))?
2061 .unregistered = true;
2062 Ok(())
2063 }
2064
2065 async fn update_worktree_extensions(
2066 &self,
2067 project_id: ProjectId,
2068 worktree_id: u64,
2069 extensions: HashMap<String, u32>,
2070 ) -> Result<()> {
2071 self.background.simulate_random_delay().await;
2072 if !self.projects.lock().contains_key(&project_id) {
2073 Err(anyhow!("no such project"))?;
2074 }
2075
2076 for (extension, count) in extensions {
2077 self.worktree_extensions
2078 .lock()
2079 .insert((project_id, worktree_id, extension), count);
2080 }
2081
2082 Ok(())
2083 }
2084
2085 async fn get_project_extensions(
2086 &self,
2087 _project_id: ProjectId,
2088 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
2089 unimplemented!()
2090 }
2091
2092 async fn record_user_activity(
2093 &self,
2094 _time_period: Range<OffsetDateTime>,
2095 _active_projects: &[(UserId, ProjectId)],
2096 ) -> Result<()> {
2097 unimplemented!()
2098 }
2099
2100 async fn get_active_user_count(
2101 &self,
2102 _time_period: Range<OffsetDateTime>,
2103 _min_duration: Duration,
2104 _only_collaborative: bool,
2105 ) -> Result<usize> {
2106 unimplemented!()
2107 }
2108
2109 async fn get_top_users_activity_summary(
2110 &self,
2111 _time_period: Range<OffsetDateTime>,
2112 _limit: usize,
2113 ) -> Result<Vec<UserActivitySummary>> {
2114 unimplemented!()
2115 }
2116
2117 async fn get_user_activity_timeline(
2118 &self,
2119 _time_period: Range<OffsetDateTime>,
2120 _user_id: UserId,
2121 ) -> Result<Vec<UserActivityPeriod>> {
2122 unimplemented!()
2123 }
2124
2125 // contacts
2126
2127 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2128 self.background.simulate_random_delay().await;
2129 let mut contacts = Vec::new();
2130
2131 for contact in self.contacts.lock().iter() {
2132 if contact.requester_id == id {
2133 if contact.accepted {
2134 contacts.push(Contact::Accepted {
2135 user_id: contact.responder_id,
2136 should_notify: contact.should_notify,
2137 });
2138 } else {
2139 contacts.push(Contact::Outgoing {
2140 user_id: contact.responder_id,
2141 });
2142 }
2143 } else if contact.responder_id == id {
2144 if contact.accepted {
2145 contacts.push(Contact::Accepted {
2146 user_id: contact.requester_id,
2147 should_notify: false,
2148 });
2149 } else {
2150 contacts.push(Contact::Incoming {
2151 user_id: contact.requester_id,
2152 should_notify: contact.should_notify,
2153 });
2154 }
2155 }
2156 }
2157
2158 contacts.sort_unstable_by_key(|contact| contact.user_id());
2159 Ok(contacts)
2160 }
2161
2162 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2163 self.background.simulate_random_delay().await;
2164 Ok(self.contacts.lock().iter().any(|contact| {
2165 contact.accepted
2166 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2167 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2168 }))
2169 }
2170
2171 async fn send_contact_request(
2172 &self,
2173 requester_id: UserId,
2174 responder_id: UserId,
2175 ) -> Result<()> {
2176 self.background.simulate_random_delay().await;
2177 let mut contacts = self.contacts.lock();
2178 for contact in contacts.iter_mut() {
2179 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2180 if contact.accepted {
2181 Err(anyhow!("contact already exists"))?;
2182 } else {
2183 Err(anyhow!("contact already requested"))?;
2184 }
2185 }
2186 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2187 if contact.accepted {
2188 Err(anyhow!("contact already exists"))?;
2189 } else {
2190 contact.accepted = true;
2191 contact.should_notify = false;
2192 return Ok(());
2193 }
2194 }
2195 }
2196 contacts.push(FakeContact {
2197 requester_id,
2198 responder_id,
2199 accepted: false,
2200 should_notify: true,
2201 });
2202 Ok(())
2203 }
2204
2205 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2206 self.background.simulate_random_delay().await;
2207 self.contacts.lock().retain(|contact| {
2208 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2209 });
2210 Ok(())
2211 }
2212
2213 async fn dismiss_contact_notification(
2214 &self,
2215 user_id: UserId,
2216 contact_user_id: UserId,
2217 ) -> Result<()> {
2218 self.background.simulate_random_delay().await;
2219 let mut contacts = self.contacts.lock();
2220 for contact in contacts.iter_mut() {
2221 if contact.requester_id == contact_user_id
2222 && contact.responder_id == user_id
2223 && !contact.accepted
2224 {
2225 contact.should_notify = false;
2226 return Ok(());
2227 }
2228 if contact.requester_id == user_id
2229 && contact.responder_id == contact_user_id
2230 && contact.accepted
2231 {
2232 contact.should_notify = false;
2233 return Ok(());
2234 }
2235 }
2236 Err(anyhow!("no such notification"))?
2237 }
2238
2239 async fn respond_to_contact_request(
2240 &self,
2241 responder_id: UserId,
2242 requester_id: UserId,
2243 accept: bool,
2244 ) -> Result<()> {
2245 self.background.simulate_random_delay().await;
2246 let mut contacts = self.contacts.lock();
2247 for (ix, contact) in contacts.iter_mut().enumerate() {
2248 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2249 if contact.accepted {
2250 Err(anyhow!("contact already confirmed"))?;
2251 }
2252 if accept {
2253 contact.accepted = true;
2254 contact.should_notify = true;
2255 } else {
2256 contacts.remove(ix);
2257 }
2258 return Ok(());
2259 }
2260 }
2261 Err(anyhow!("no such contact request"))?
2262 }
2263
2264 async fn create_access_token_hash(
2265 &self,
2266 _user_id: UserId,
2267 _access_token_hash: &str,
2268 _max_access_token_count: usize,
2269 ) -> Result<()> {
2270 unimplemented!()
2271 }
2272
2273 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
2274 unimplemented!()
2275 }
2276
2277 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2278 unimplemented!()
2279 }
2280
2281 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2282 self.background.simulate_random_delay().await;
2283 let mut orgs = self.orgs.lock();
2284 if orgs.values().any(|org| org.slug == slug) {
2285 Err(anyhow!("org already exists"))?
2286 } else {
2287 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2288 orgs.insert(
2289 org_id,
2290 Org {
2291 id: org_id,
2292 name: name.to_string(),
2293 slug: slug.to_string(),
2294 },
2295 );
2296 Ok(org_id)
2297 }
2298 }
2299
2300 async fn add_org_member(
2301 &self,
2302 org_id: OrgId,
2303 user_id: UserId,
2304 is_admin: bool,
2305 ) -> Result<()> {
2306 self.background.simulate_random_delay().await;
2307 if !self.orgs.lock().contains_key(&org_id) {
2308 Err(anyhow!("org does not exist"))?;
2309 }
2310 if !self.users.lock().contains_key(&user_id) {
2311 Err(anyhow!("user does not exist"))?;
2312 }
2313
2314 self.org_memberships
2315 .lock()
2316 .entry((org_id, user_id))
2317 .or_insert(is_admin);
2318 Ok(())
2319 }
2320
2321 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2322 self.background.simulate_random_delay().await;
2323 if !self.orgs.lock().contains_key(&org_id) {
2324 Err(anyhow!("org does not exist"))?;
2325 }
2326
2327 let mut channels = self.channels.lock();
2328 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2329 channels.insert(
2330 channel_id,
2331 Channel {
2332 id: channel_id,
2333 name: name.to_string(),
2334 owner_id: org_id.0,
2335 owner_is_user: false,
2336 },
2337 );
2338 Ok(channel_id)
2339 }
2340
2341 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2342 self.background.simulate_random_delay().await;
2343 Ok(self
2344 .channels
2345 .lock()
2346 .values()
2347 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2348 .cloned()
2349 .collect())
2350 }
2351
2352 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2353 self.background.simulate_random_delay().await;
2354 let channels = self.channels.lock();
2355 let memberships = self.channel_memberships.lock();
2356 Ok(channels
2357 .values()
2358 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2359 .cloned()
2360 .collect())
2361 }
2362
2363 async fn can_user_access_channel(
2364 &self,
2365 user_id: UserId,
2366 channel_id: ChannelId,
2367 ) -> Result<bool> {
2368 self.background.simulate_random_delay().await;
2369 Ok(self
2370 .channel_memberships
2371 .lock()
2372 .contains_key(&(channel_id, user_id)))
2373 }
2374
2375 async fn add_channel_member(
2376 &self,
2377 channel_id: ChannelId,
2378 user_id: UserId,
2379 is_admin: bool,
2380 ) -> Result<()> {
2381 self.background.simulate_random_delay().await;
2382 if !self.channels.lock().contains_key(&channel_id) {
2383 Err(anyhow!("channel does not exist"))?;
2384 }
2385 if !self.users.lock().contains_key(&user_id) {
2386 Err(anyhow!("user does not exist"))?;
2387 }
2388
2389 self.channel_memberships
2390 .lock()
2391 .entry((channel_id, user_id))
2392 .or_insert(is_admin);
2393 Ok(())
2394 }
2395
2396 async fn create_channel_message(
2397 &self,
2398 channel_id: ChannelId,
2399 sender_id: UserId,
2400 body: &str,
2401 timestamp: OffsetDateTime,
2402 nonce: u128,
2403 ) -> Result<MessageId> {
2404 self.background.simulate_random_delay().await;
2405 if !self.channels.lock().contains_key(&channel_id) {
2406 Err(anyhow!("channel does not exist"))?;
2407 }
2408 if !self.users.lock().contains_key(&sender_id) {
2409 Err(anyhow!("user does not exist"))?;
2410 }
2411
2412 let mut messages = self.channel_messages.lock();
2413 if let Some(message) = messages
2414 .values()
2415 .find(|message| message.nonce.as_u128() == nonce)
2416 {
2417 Ok(message.id)
2418 } else {
2419 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2420 messages.insert(
2421 message_id,
2422 ChannelMessage {
2423 id: message_id,
2424 channel_id,
2425 sender_id,
2426 body: body.to_string(),
2427 sent_at: timestamp,
2428 nonce: Uuid::from_u128(nonce),
2429 },
2430 );
2431 Ok(message_id)
2432 }
2433 }
2434
2435 async fn get_channel_messages(
2436 &self,
2437 channel_id: ChannelId,
2438 count: usize,
2439 before_id: Option<MessageId>,
2440 ) -> Result<Vec<ChannelMessage>> {
2441 self.background.simulate_random_delay().await;
2442 let mut messages = self
2443 .channel_messages
2444 .lock()
2445 .values()
2446 .rev()
2447 .filter(|message| {
2448 message.channel_id == channel_id
2449 && message.id < before_id.unwrap_or(MessageId::MAX)
2450 })
2451 .take(count)
2452 .cloned()
2453 .collect::<Vec<_>>();
2454 messages.sort_unstable_by_key(|message| message.id);
2455 Ok(messages)
2456 }
2457
2458 async fn teardown(&self, _: &str) {}
2459
/// Gives tests direct access to the fake's in-memory state.
#[cfg(test)]
fn as_fake(&self) -> Option<&FakeDb> {
    Option::Some(self)
}
2464 }
2465
2466 pub struct TestDb {
2467 pub db: Option<Arc<dyn Db>>,
2468 pub url: String,
2469 }
2470
2471 impl TestDb {
2472 #[allow(clippy::await_holding_lock)]
2473 pub async fn postgres() -> Self {
2474 lazy_static! {
2475 static ref LOCK: Mutex<()> = Mutex::new(());
2476 }
2477
2478 let _guard = LOCK.lock();
2479 let mut rng = StdRng::from_entropy();
2480 let name = format!("zed-test-{}", rng.gen::<u128>());
2481 let url = format!("postgres://postgres@localhost/{}", name);
2482 Postgres::create_database(&url)
2483 .await
2484 .expect("failed to create test db");
2485 let db = PostgresDb::new(&url, 5).await.unwrap();
2486 db.migrate(Path::new(DEFAULT_MIGRATIONS_PATH.unwrap()), false)
2487 .await
2488 .unwrap();
2489 Self {
2490 db: Some(Arc::new(db)),
2491 url,
2492 }
2493 }
2494
2495 pub fn fake(background: Arc<Background>) -> Self {
2496 Self {
2497 db: Some(Arc::new(FakeDb::new(background))),
2498 url: Default::default(),
2499 }
2500 }
2501
2502 pub fn db(&self) -> &Arc<dyn Db> {
2503 self.db.as_ref().unwrap()
2504 }
2505 }
2506
2507 impl Drop for TestDb {
2508 fn drop(&mut self) {
2509 if let Some(db) = self.db.take() {
2510 futures::executor::block_on(db.teardown(&self.url));
2511 }
2512 }
2513 }
2514}