1use crate::{Error, Result};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use axum::http::StatusCode;
5use collections::HashMap;
6use futures::StreamExt;
7use serde::{Deserialize, Serialize};
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9use sqlx::{types::Uuid, FromRow, QueryBuilder};
10use std::{cmp, ops::Range, time::Duration};
11use time::{OffsetDateTime, PrimitiveDateTime};
12
/// Storage interface used by the collaboration server.
///
/// `PostgresDb` implements it on top of a Postgres connection pool. Items gated
/// with `#[cfg(test)]` exist only in test builds; items gated with
/// `#[cfg(any(test, feature = "seed-support"))]` also exist when the
/// `seed-support` feature is enabled.
#[async_trait]
pub trait Db: Send + Sync {
    /// Creates a user with the given email address, admin flag and GitHub identity,
    /// returning the new user's id and metrics id.
    async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult>;
    /// Returns one page of users. `page` is zero-based and `limit` is the page size.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Returns up to `limit` users whose GitHub login fuzzily matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    /// Returns the user with the given id, if any.
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    /// Returns the user's metrics id rendered as a string.
    async fn get_user_metrics_id(&self, id: UserId) -> Result<String>;
    /// Returns the users whose ids appear in `ids`. Unknown ids are skipped.
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns users with an invite count of zero. When `invited_by_another_user`
    /// is true, only users with an inviter are returned; otherwise only users
    /// without one.
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    /// Finds a user by GitHub account. When `github_user_id` is given, the stored
    /// login / GitHub user id of the matched user is updated to the given values.
    async fn get_user_by_github_account(
        &self,
        github_login: &str,
        github_user_id: Option<i32>,
    ) -> Result<Option<User>>;
    /// Sets the user's admin flag.
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    /// Sets the user's `connected_once` flag.
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes the user together with their access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Sets how many invites the user may send, giving them an invite code if
    /// `count > 0` and they do not have one yet.
    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, or `None` when the
    /// user has no invite code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user who owns the given invite code.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Records a signup for `email_address` that was invited through `code`,
    /// returning the invite (email and confirmation code) to send.
    async fn create_invite_from_code(
        &self,
        code: &str,
        email_address: &str,
        device_id: Option<&str>,
    ) -> Result<Invite>;

    /// Records a waitlist signup.
    async fn create_signup(&self, signup: Signup) -> Result<()>;
    /// Summarizes, per platform, the signups that have not been sent an invite yet.
    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
    /// Returns up to `count` signups that have not been sent an invite yet and that
    /// reported a Mac or unknown platform.
    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<UnsentInvite>>;
    /// Marks the signups for the given invites as having been sent an invite.
    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
    /// Creates a user from a previously sent invite.
    async fn create_user_from_invite(
        &self,
        invite: &Invite,
        user: NewUserParams,
    ) -> Result<NewUserResult>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, u32>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_user_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the number of users who have been active in the given
    /// time period for at least the given time duration.
    async fn get_active_user_count(
        &self,
        time_period: Range<OffsetDateTime>,
        min_duration: Duration,
        only_collaborative: bool,
    ) -> Result<usize>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn get_top_users_activity_summary(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Get the project activity for the given user and time period.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>>;

    /// Returns the accepted contacts and pending contact requests involving the user.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users are accepted contacts.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    /// Sends a contact request from `requester_id` to `responder_id`.
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Removes the contact or pending request between the two users.
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    /// Clears the pending notification about the contact relationship between the
    /// two users.
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    /// Accepts or declines the contact request sent by `requester_id` to
    /// `responder_id`.
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;

    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    #[cfg(any(test, feature = "seed-support"))]
    // NOTE(review): despite the blank line below, this attribute still applies to
    // `get_org_channels`, so that method only exists in test / seed-support builds.
    // Confirm this is intentional (implementations must be gated the same way).

    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;

    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;

    #[cfg(test)]
    async fn teardown(&self, url: &str);

    #[cfg(test)]
    fn as_fake(&self) -> Option<&FakeDb>;
}
175
176pub struct PostgresDb {
177 pool: sqlx::PgPool,
178}
179
180impl PostgresDb {
181 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
182 let pool = DbOptions::new()
183 .max_connections(max_connections)
184 .connect(url)
185 .await
186 .context("failed to connect to postgres database")?;
187 Ok(Self { pool })
188 }
189
190 pub fn fuzzy_like_string(string: &str) -> String {
191 let mut result = String::with_capacity(string.len() * 2 + 1);
192 for c in string.chars() {
193 if c.is_alphanumeric() {
194 result.push('%');
195 result.push(c);
196 }
197 }
198 result.push('%');
199 result
200 }
201}
202
203#[async_trait]
204impl Db for PostgresDb {
205 // users
206
207 async fn create_user(
208 &self,
209 email_address: &str,
210 admin: bool,
211 params: NewUserParams,
212 ) -> Result<NewUserResult> {
213 let query = "
214 INSERT INTO users (email_address, github_login, github_user_id, admin)
215 VALUES ($1, $2, $3, $4)
216 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
217 RETURNING id, metrics_id::text
218 ";
219 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
220 .bind(email_address)
221 .bind(params.github_login)
222 .bind(params.github_user_id)
223 .bind(admin)
224 .fetch_one(&self.pool)
225 .await?;
226 Ok(NewUserResult {
227 user_id,
228 metrics_id,
229 signup_device_id: None,
230 inviting_user_id: None,
231 })
232 }
233
234 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
235 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
236 Ok(sqlx::query_as(query)
237 .bind(limit as i32)
238 .bind((page * limit) as i32)
239 .fetch_all(&self.pool)
240 .await?)
241 }
242
243 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
244 let like_string = Self::fuzzy_like_string(name_query);
245 let query = "
246 SELECT users.*
247 FROM users
248 WHERE github_login ILIKE $1
249 ORDER BY github_login <-> $2
250 LIMIT $3
251 ";
252 Ok(sqlx::query_as(query)
253 .bind(like_string)
254 .bind(name_query)
255 .bind(limit as i32)
256 .fetch_all(&self.pool)
257 .await?)
258 }
259
260 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
261 let users = self.get_users_by_ids(vec![id]).await?;
262 Ok(users.into_iter().next())
263 }
264
265 async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
266 let query = "
267 SELECT metrics_id::text
268 FROM users
269 WHERE id = $1
270 ";
271 Ok(sqlx::query_scalar(query)
272 .bind(id)
273 .fetch_one(&self.pool)
274 .await?)
275 }
276
277 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
278 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
279 let query = "
280 SELECT users.*
281 FROM users
282 WHERE users.id = ANY ($1)
283 ";
284 Ok(sqlx::query_as(query)
285 .bind(&ids)
286 .fetch_all(&self.pool)
287 .await?)
288 }
289
290 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
291 let query = format!(
292 "
293 SELECT users.*
294 FROM users
295 WHERE invite_count = 0
296 AND inviter_id IS{} NULL
297 ",
298 if invited_by_another_user { " NOT" } else { "" }
299 );
300
301 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
302 }
303
304 async fn get_user_by_github_account(
305 &self,
306 github_login: &str,
307 github_user_id: Option<i32>,
308 ) -> Result<Option<User>> {
309 if let Some(github_user_id) = github_user_id {
310 let mut user = sqlx::query_as::<_, User>(
311 "
312 UPDATE users
313 SET github_login = $1
314 WHERE github_user_id = $2
315 RETURNING *
316 ",
317 )
318 .bind(github_login)
319 .bind(github_user_id)
320 .fetch_optional(&self.pool)
321 .await?;
322
323 if user.is_none() {
324 user = sqlx::query_as::<_, User>(
325 "
326 UPDATE users
327 SET github_user_id = $1
328 WHERE github_login = $2
329 RETURNING *
330 ",
331 )
332 .bind(github_user_id)
333 .bind(github_login)
334 .fetch_optional(&self.pool)
335 .await?;
336 }
337
338 Ok(user)
339 } else {
340 Ok(sqlx::query_as(
341 "
342 SELECT * FROM users
343 WHERE github_login = $1
344 LIMIT 1
345 ",
346 )
347 .bind(github_login)
348 .fetch_optional(&self.pool)
349 .await?)
350 }
351 }
352
353 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
354 let query = "UPDATE users SET admin = $1 WHERE id = $2";
355 Ok(sqlx::query(query)
356 .bind(is_admin)
357 .bind(id.0)
358 .execute(&self.pool)
359 .await
360 .map(drop)?)
361 }
362
363 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
364 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
365 Ok(sqlx::query(query)
366 .bind(connected_once)
367 .bind(id.0)
368 .execute(&self.pool)
369 .await
370 .map(drop)?)
371 }
372
373 async fn destroy_user(&self, id: UserId) -> Result<()> {
374 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
375 sqlx::query(query)
376 .bind(id.0)
377 .execute(&self.pool)
378 .await
379 .map(drop)?;
380 let query = "DELETE FROM users WHERE id = $1;";
381 Ok(sqlx::query(query)
382 .bind(id.0)
383 .execute(&self.pool)
384 .await
385 .map(drop)?)
386 }
387
388 // signups
389
390 async fn create_signup(&self, signup: Signup) -> Result<()> {
391 sqlx::query(
392 "
393 INSERT INTO signups
394 (
395 email_address,
396 email_confirmation_code,
397 email_confirmation_sent,
398 platform_linux,
399 platform_mac,
400 platform_windows,
401 platform_unknown,
402 editor_features,
403 programming_languages,
404 device_id
405 )
406 VALUES
407 ($1, $2, 'f', $3, $4, $5, 'f', $6, $7, $8)
408 RETURNING id
409 ",
410 )
411 .bind(&signup.email_address)
412 .bind(&random_email_confirmation_code())
413 .bind(&signup.platform_linux)
414 .bind(&signup.platform_mac)
415 .bind(&signup.platform_windows)
416 .bind(&signup.editor_features)
417 .bind(&signup.programming_languages)
418 .bind(&signup.device_id)
419 .execute(&self.pool)
420 .await?;
421 Ok(())
422 }
423
424 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
425 Ok(sqlx::query_as(
426 "
427 SELECT
428 COUNT(*) as count,
429 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
430 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
431 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
432 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
433 FROM (
434 SELECT *
435 FROM signups
436 WHERE
437 NOT email_confirmation_sent
438 ) AS unsent
439 ",
440 )
441 .fetch_one(&self.pool)
442 .await?)
443 }
444
445 async fn get_unsent_invites(&self, count: usize) -> Result<Vec<UnsentInvite>> {
446 let rows: Vec<(String, String, bool)> = sqlx::query_as(
447 "
448 SELECT
449 email_address, email_confirmation_code, platform_unknown
450 FROM signups
451 WHERE
452 NOT email_confirmation_sent AND
453 (platform_mac OR platform_unknown)
454 LIMIT $1
455 ",
456 )
457 .bind(count as i32)
458 .fetch_all(&self.pool)
459 .await?;
460 Ok(rows
461 .into_iter()
462 .map(
463 |(email_address, email_confirmation_code, platform_unknown)| UnsentInvite {
464 invite: Invite {
465 email_address,
466 email_confirmation_code,
467 },
468 platform_unknown,
469 },
470 )
471 .collect())
472 }
473
474 async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
475 sqlx::query(
476 "
477 UPDATE signups
478 SET email_confirmation_sent = 't'
479 WHERE email_address = ANY ($1)
480 ",
481 )
482 .bind(
483 &invites
484 .iter()
485 .map(|s| s.email_address.as_str())
486 .collect::<Vec<_>>(),
487 )
488 .execute(&self.pool)
489 .await?;
490 Ok(())
491 }
492
    /// Redeems an invite by creating a user for it, all within one transaction:
    ///
    /// 1. Looks up the signup matching the invite's email and confirmation code.
    /// 2. Inserts the user with a fresh invite code.
    /// 3. Links the signup to the new user.
    /// 4. If the signup has an inviting user, consumes one of that user's invites
    ///    and creates an accepted contact between the inviter and the new user.
    ///
    /// # Errors
    /// - `404 NOT_FOUND` if no signup matches the invite.
    /// - `422 UNPROCESSABLE_ENTITY` if the signup was already redeemed.
    /// - `401 UNAUTHORIZED` if the inviting user has no invites remaining.
    ///
    /// On any error the transaction is returned early without being committed.
    async fn create_user_from_invite(
        &self,
        invite: &Invite,
        user: NewUserParams,
    ) -> Result<NewUserResult> {
        let mut tx = self.pool.begin().await?;

        // Find the signup this invite was issued for.
        let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
            i32,
            Option<UserId>,
            Option<UserId>,
            Option<String>,
        ) = sqlx::query_as(
            "
            SELECT id, user_id, inviting_user_id, device_id
            FROM signups
            WHERE
                email_address = $1 AND
                email_confirmation_code = $2
            ",
        )
        .bind(&invite.email_address)
        .bind(&invite.email_confirmation_code)
        .fetch_optional(&mut tx)
        .await?
        .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;

        // A signup that already points at a user has been redeemed before.
        if existing_user_id.is_some() {
            Err(Error::Http(
                StatusCode::UNPROCESSABLE_ENTITY,
                "invitation already redeemed".to_string(),
            ))?;
        }

        // Create the user; the new user is never an admin.
        let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
            "
            INSERT INTO users
            (email_address, github_login, github_user_id, admin, invite_count, invite_code)
            VALUES
            ($1, $2, $3, 'f', $4, $5)
            RETURNING id, metrics_id::text
            ",
        )
        .bind(&invite.email_address)
        .bind(&user.github_login)
        .bind(&user.github_user_id)
        .bind(&user.invite_count)
        .bind(random_invite_code())
        .fetch_one(&mut tx)
        .await?;

        // Mark the signup as redeemed by linking it to the new user.
        sqlx::query(
            "
            UPDATE signups
            SET user_id = $1
            WHERE id = $2
            ",
        )
        .bind(&user_id)
        .bind(&signup_id)
        .execute(&mut tx)
        .await?;

        if let Some(inviting_user_id) = inviting_user_id {
            // Consume one invite; the `invite_count > 0` guard makes the update
            // match no row (and so return no id) when none are left.
            let id: Option<UserId> = sqlx::query_scalar(
                "
                    UPDATE users
                    SET invite_count = invite_count - 1
                    WHERE id = $1 AND invite_count > 0
                    RETURNING id
                ",
            )
            .bind(&inviting_user_id)
            .fetch_optional(&mut tx)
            .await?;

            if id.is_none() {
                Err(Error::Http(
                    StatusCode::UNAUTHORIZED,
                    "no invites remaining".to_string(),
                ))?;
            }

            // Make the inviter and the new user accepted contacts.
            // NOTE(review): `user_id_a` is the inviter here without the id ordering
            // the contact methods apply (smaller id in `user_id_a`); confirm this is
            // safe for the new, higher user id.
            sqlx::query(
                "
                INSERT INTO contacts
                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
                VALUES
                    ($1, $2, 't', 't', 't')
                ",
            )
            .bind(inviting_user_id)
            .bind(user_id)
            .execute(&mut tx)
            .await?;
        }

        tx.commit().await?;
        Ok(NewUserResult {
            user_id,
            metrics_id,
            inviting_user_id,
            signup_device_id,
        })
    }
598
599 // invite codes
600
601 async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
602 let mut tx = self.pool.begin().await?;
603 if count > 0 {
604 sqlx::query(
605 "
606 UPDATE users
607 SET invite_code = $1
608 WHERE id = $2 AND invite_code IS NULL
609 ",
610 )
611 .bind(random_invite_code())
612 .bind(id)
613 .execute(&mut tx)
614 .await?;
615 }
616
617 sqlx::query(
618 "
619 UPDATE users
620 SET invite_count = $1
621 WHERE id = $2
622 ",
623 )
624 .bind(count as i32)
625 .bind(id)
626 .execute(&mut tx)
627 .await?;
628 tx.commit().await?;
629 Ok(())
630 }
631
632 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
633 let result: Option<(String, i32)> = sqlx::query_as(
634 "
635 SELECT invite_code, invite_count
636 FROM users
637 WHERE id = $1 AND invite_code IS NOT NULL
638 ",
639 )
640 .bind(id)
641 .fetch_optional(&self.pool)
642 .await?;
643 if let Some((code, count)) = result {
644 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
645 } else {
646 Ok(None)
647 }
648 }
649
650 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
651 sqlx::query_as(
652 "
653 SELECT *
654 FROM users
655 WHERE invite_code = $1
656 ",
657 )
658 .bind(code)
659 .fetch_optional(&self.pool)
660 .await?
661 .ok_or_else(|| {
662 Error::Http(
663 StatusCode::NOT_FOUND,
664 "that invite code does not exist".to_string(),
665 )
666 })
667 }
668
669 async fn create_invite_from_code(
670 &self,
671 code: &str,
672 email_address: &str,
673 device_id: Option<&str>,
674 ) -> Result<Invite> {
675 let mut tx = self.pool.begin().await?;
676
677 let existing_user: Option<UserId> = sqlx::query_scalar(
678 "
679 SELECT id
680 FROM users
681 WHERE email_address = $1
682 ",
683 )
684 .bind(email_address)
685 .fetch_optional(&mut tx)
686 .await?;
687 if existing_user.is_some() {
688 Err(anyhow!("email address is already in use"))?;
689 }
690
691 let row: Option<(UserId, i32)> = sqlx::query_as(
692 "
693 SELECT id, invite_count
694 FROM users
695 WHERE invite_code = $1
696 ",
697 )
698 .bind(code)
699 .fetch_optional(&mut tx)
700 .await?;
701
702 let (inviter_id, invite_count) = match row {
703 Some(row) => row,
704 None => Err(Error::Http(
705 StatusCode::NOT_FOUND,
706 "invite code not found".to_string(),
707 ))?,
708 };
709
710 if invite_count == 0 {
711 Err(Error::Http(
712 StatusCode::UNAUTHORIZED,
713 "no invites remaining".to_string(),
714 ))?;
715 }
716
717 let email_confirmation_code: String = sqlx::query_scalar(
718 "
719 INSERT INTO signups
720 (
721 email_address,
722 email_confirmation_code,
723 email_confirmation_sent,
724 inviting_user_id,
725 platform_linux,
726 platform_mac,
727 platform_windows,
728 platform_unknown,
729 device_id
730 )
731 VALUES
732 ($1, $2, 'f', $3, 'f', 'f', 'f', 't', $4)
733 ON CONFLICT (email_address)
734 DO UPDATE SET
735 inviting_user_id = excluded.inviting_user_id
736 RETURNING email_confirmation_code
737 ",
738 )
739 .bind(&email_address)
740 .bind(&random_email_confirmation_code())
741 .bind(&inviter_id)
742 .bind(&device_id)
743 .fetch_one(&mut tx)
744 .await?;
745
746 tx.commit().await?;
747
748 Ok(Invite {
749 email_address: email_address.into(),
750 email_confirmation_code,
751 })
752 }
753
754 // projects
755
756 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
757 Ok(sqlx::query_scalar(
758 "
759 INSERT INTO projects(host_user_id)
760 VALUES ($1)
761 RETURNING id
762 ",
763 )
764 .bind(host_user_id)
765 .fetch_one(&self.pool)
766 .await
767 .map(ProjectId)?)
768 }
769
770 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
771 sqlx::query(
772 "
773 UPDATE projects
774 SET unregistered = 't'
775 WHERE id = $1
776 ",
777 )
778 .bind(project_id)
779 .execute(&self.pool)
780 .await?;
781 Ok(())
782 }
783
784 async fn update_worktree_extensions(
785 &self,
786 project_id: ProjectId,
787 worktree_id: u64,
788 extensions: HashMap<String, u32>,
789 ) -> Result<()> {
790 if extensions.is_empty() {
791 return Ok(());
792 }
793
794 let mut query = QueryBuilder::new(
795 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
796 );
797 query.push_values(extensions, |mut query, (extension, count)| {
798 query
799 .push_bind(project_id)
800 .push_bind(worktree_id as i32)
801 .push_bind(extension)
802 .push_bind(count as i32);
803 });
804 query.push(
805 "
806 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
807 count = excluded.count
808 ",
809 );
810 query.build().execute(&self.pool).await?;
811
812 Ok(())
813 }
814
815 async fn get_project_extensions(
816 &self,
817 project_id: ProjectId,
818 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
819 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
820 struct WorktreeExtension {
821 worktree_id: i32,
822 extension: String,
823 count: i32,
824 }
825
826 let query = "
827 SELECT worktree_id, extension, count
828 FROM worktree_extensions
829 WHERE project_id = $1
830 ";
831 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
832 .bind(&project_id)
833 .fetch_all(&self.pool)
834 .await?;
835
836 let mut extension_counts = HashMap::default();
837 for count in counts {
838 extension_counts
839 .entry(count.worktree_id as u64)
840 .or_insert_with(HashMap::default)
841 .insert(count.extension, count.count as usize);
842 }
843 Ok(extension_counts)
844 }
845
846 async fn record_user_activity(
847 &self,
848 time_period: Range<OffsetDateTime>,
849 projects: &[(UserId, ProjectId)],
850 ) -> Result<()> {
851 let query = "
852 INSERT INTO project_activity_periods
853 (ended_at, duration_millis, user_id, project_id)
854 VALUES
855 ($1, $2, $3, $4);
856 ";
857
858 let mut tx = self.pool.begin().await?;
859 let duration_millis =
860 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
861 for (user_id, project_id) in projects {
862 sqlx::query(query)
863 .bind(time_period.end)
864 .bind(duration_millis)
865 .bind(user_id)
866 .bind(project_id)
867 .execute(&mut tx)
868 .await?;
869 }
870 tx.commit().await?;
871
872 Ok(())
873 }
874
875 async fn get_active_user_count(
876 &self,
877 time_period: Range<OffsetDateTime>,
878 min_duration: Duration,
879 only_collaborative: bool,
880 ) -> Result<usize> {
881 let mut with_clause = String::new();
882 with_clause.push_str("WITH\n");
883 with_clause.push_str(
884 "
885 project_durations AS (
886 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
887 FROM project_activity_periods
888 WHERE $1 < ended_at AND ended_at <= $2
889 GROUP BY user_id, project_id
890 ),
891 ",
892 );
893 with_clause.push_str(
894 "
895 project_collaborators as (
896 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
897 FROM project_durations
898 GROUP BY project_id
899 ),
900 ",
901 );
902
903 if only_collaborative {
904 with_clause.push_str(
905 "
906 user_durations AS (
907 SELECT user_id, SUM(project_duration) as total_duration
908 FROM project_durations, project_collaborators
909 WHERE
910 project_durations.project_id = project_collaborators.project_id AND
911 max_collaborators > 1
912 GROUP BY user_id
913 ORDER BY total_duration DESC
914 LIMIT $3
915 )
916 ",
917 );
918 } else {
919 with_clause.push_str(
920 "
921 user_durations AS (
922 SELECT user_id, SUM(project_duration) as total_duration
923 FROM project_durations
924 GROUP BY user_id
925 ORDER BY total_duration DESC
926 LIMIT $3
927 )
928 ",
929 );
930 }
931
932 let query = format!(
933 "
934 {with_clause}
935 SELECT count(user_durations.user_id)
936 FROM user_durations
937 WHERE user_durations.total_duration >= $3
938 "
939 );
940
941 let count: i64 = sqlx::query_scalar(&query)
942 .bind(time_period.start)
943 .bind(time_period.end)
944 .bind(min_duration.as_millis() as i64)
945 .fetch_one(&self.pool)
946 .await?;
947 Ok(count as usize)
948 }
949
950 async fn get_top_users_activity_summary(
951 &self,
952 time_period: Range<OffsetDateTime>,
953 max_user_count: usize,
954 ) -> Result<Vec<UserActivitySummary>> {
955 let query = "
956 WITH
957 project_durations AS (
958 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
959 FROM project_activity_periods
960 WHERE $1 < ended_at AND ended_at <= $2
961 GROUP BY user_id, project_id
962 ),
963 user_durations AS (
964 SELECT user_id, SUM(project_duration) as total_duration
965 FROM project_durations
966 GROUP BY user_id
967 ORDER BY total_duration DESC
968 LIMIT $3
969 ),
970 project_collaborators as (
971 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
972 FROM project_durations
973 GROUP BY project_id
974 )
975 SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
976 FROM user_durations, project_durations, project_collaborators, users
977 WHERE
978 user_durations.user_id = project_durations.user_id AND
979 user_durations.user_id = users.id AND
980 project_durations.project_id = project_collaborators.project_id
981 ORDER BY total_duration DESC, user_id ASC, project_id ASC
982 ";
983
984 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
985 .bind(time_period.start)
986 .bind(time_period.end)
987 .bind(max_user_count as i32)
988 .fetch(&self.pool);
989
990 let mut result = Vec::<UserActivitySummary>::new();
991 while let Some(row) = rows.next().await {
992 let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
993 let project_id = project_id;
994 let duration = Duration::from_millis(duration_millis as u64);
995 let project_activity = ProjectActivitySummary {
996 id: project_id,
997 duration,
998 max_collaborators: project_collaborators as usize,
999 };
1000 if let Some(last_summary) = result.last_mut() {
1001 if last_summary.id == user_id {
1002 last_summary.project_activity.push(project_activity);
1003 continue;
1004 }
1005 }
1006 result.push(UserActivitySummary {
1007 id: user_id,
1008 project_activity: vec![project_activity],
1009 github_login,
1010 });
1011 }
1012
1013 Ok(result)
1014 }
1015
    /// Returns the user's activity periods within `time_period`, sorted by start
    /// time.
    ///
    /// Consecutive periods in the same project are merged when the next one starts
    /// no more than `COALESCE_THRESHOLD` after the previous one ends. Each returned
    /// period also carries the project's file counts per extension.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>> {
        // Maximum gap between two periods of the same project that still merges them.
        const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);

        // The LEFT OUTER JOIN is on `project_id` only, so each activity period yields
        // one row per `worktree_extensions` row of its project (or a single row with
        // NULL extension columns when the project has none).
        let query = "
            SELECT
                project_activity_periods.ended_at,
                project_activity_periods.duration_millis,
                project_activity_periods.project_id,
                worktree_extensions.extension,
                worktree_extensions.count
            FROM project_activity_periods
            LEFT OUTER JOIN
                worktree_extensions
            ON
                project_activity_periods.project_id = worktree_extensions.project_id
            WHERE
                project_activity_periods.user_id = $1 AND
                $2 < project_activity_periods.ended_at AND
                project_activity_periods.ended_at <= $3
            ORDER BY project_activity_periods.id ASC
        ";

        let mut rows = sqlx::query_as::<
            _,
            (
                PrimitiveDateTime,
                i32,
                ProjectId,
                Option<String>,
                Option<i32>,
            ),
        >(query)
        .bind(user_id)
        .bind(time_period.start)
        .bind(time_period.end)
        .fetch(&self.pool);

        // Periods are built separately per project, then flattened at the end.
        let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
        while let Some(row) = rows.next().await {
            let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
            // The stored timestamp has no offset; it is interpreted as UTC.
            let ended_at = ended_at.assume_utc();
            let duration = Duration::from_millis(duration_millis as u64);
            let started_at = ended_at - duration;
            let project_time_periods = time_periods.entry(project_id).or_default();

            // Extend the project's latest period when this row overlaps it or starts
            // within the threshold after it; otherwise start a new period. Extra
            // extension rows for the same activity period always take the merge path.
            if let Some(prev_duration) = project_time_periods.last_mut() {
                if started_at <= prev_duration.end + COALESCE_THRESHOLD
                    && ended_at >= prev_duration.start
                {
                    prev_duration.end = cmp::max(prev_duration.end, ended_at);
                } else {
                    project_time_periods.push(UserActivityPeriod {
                        project_id,
                        start: started_at,
                        end: ended_at,
                        extensions: Default::default(),
                    });
                }
            } else {
                project_time_periods.push(UserActivityPeriod {
                    project_id,
                    start: started_at,
                    end: ended_at,
                    extensions: Default::default(),
                });
            }

            // Attach the extension count to the period this row landed in.
            // NOTE(review): the same extension in several worktrees overwrites rather
            // than sums the count; confirm that is intended.
            if let Some((extension, extension_count)) = extension.zip(extension_count) {
                project_time_periods
                    .last_mut()
                    .unwrap()
                    .extensions
                    .insert(extension, extension_count as usize);
            }
        }

        let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
        durations.sort_unstable_by_key(|duration| duration.start);
        Ok(durations)
    }
1100
1101 // contacts
1102
1103 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1104 let query = "
1105 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1106 FROM contacts
1107 WHERE user_id_a = $1 OR user_id_b = $1;
1108 ";
1109
1110 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1111 .bind(user_id)
1112 .fetch(&self.pool);
1113
1114 let mut contacts = Vec::new();
1115 while let Some(row) = rows.next().await {
1116 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1117
1118 if user_id_a == user_id {
1119 if accepted {
1120 contacts.push(Contact::Accepted {
1121 user_id: user_id_b,
1122 should_notify: should_notify && a_to_b,
1123 });
1124 } else if a_to_b {
1125 contacts.push(Contact::Outgoing { user_id: user_id_b })
1126 } else {
1127 contacts.push(Contact::Incoming {
1128 user_id: user_id_b,
1129 should_notify,
1130 });
1131 }
1132 } else if accepted {
1133 contacts.push(Contact::Accepted {
1134 user_id: user_id_a,
1135 should_notify: should_notify && !a_to_b,
1136 });
1137 } else if a_to_b {
1138 contacts.push(Contact::Incoming {
1139 user_id: user_id_a,
1140 should_notify,
1141 });
1142 } else {
1143 contacts.push(Contact::Outgoing { user_id: user_id_a });
1144 }
1145 }
1146
1147 contacts.sort_unstable_by_key(|contact| contact.user_id());
1148
1149 Ok(contacts)
1150 }
1151
1152 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1153 let (id_a, id_b) = if user_id_1 < user_id_2 {
1154 (user_id_1, user_id_2)
1155 } else {
1156 (user_id_2, user_id_1)
1157 };
1158
1159 let query = "
1160 SELECT 1 FROM contacts
1161 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
1162 LIMIT 1
1163 ";
1164 Ok(sqlx::query_scalar::<_, i32>(query)
1165 .bind(id_a.0)
1166 .bind(id_b.0)
1167 .fetch_optional(&self.pool)
1168 .await?
1169 .is_some())
1170 }
1171
1172 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1173 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1174 (sender_id, receiver_id, true)
1175 } else {
1176 (receiver_id, sender_id, false)
1177 };
1178 let query = "
1179 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1180 VALUES ($1, $2, $3, 'f', 't')
1181 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1182 SET
1183 accepted = 't',
1184 should_notify = 'f'
1185 WHERE
1186 NOT contacts.accepted AND
1187 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1188 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1189 ";
1190 let result = sqlx::query(query)
1191 .bind(id_a.0)
1192 .bind(id_b.0)
1193 .bind(a_to_b)
1194 .execute(&self.pool)
1195 .await?;
1196
1197 if result.rows_affected() == 1 {
1198 Ok(())
1199 } else {
1200 Err(anyhow!("contact already requested"))?
1201 }
1202 }
1203
1204 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1205 let (id_a, id_b) = if responder_id < requester_id {
1206 (responder_id, requester_id)
1207 } else {
1208 (requester_id, responder_id)
1209 };
1210 let query = "
1211 DELETE FROM contacts
1212 WHERE user_id_a = $1 AND user_id_b = $2;
1213 ";
1214 let result = sqlx::query(query)
1215 .bind(id_a.0)
1216 .bind(id_b.0)
1217 .execute(&self.pool)
1218 .await?;
1219
1220 if result.rows_affected() == 1 {
1221 Ok(())
1222 } else {
1223 Err(anyhow!("no such contact"))?
1224 }
1225 }
1226
1227 async fn dismiss_contact_notification(
1228 &self,
1229 user_id: UserId,
1230 contact_user_id: UserId,
1231 ) -> Result<()> {
1232 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1233 (user_id, contact_user_id, true)
1234 } else {
1235 (contact_user_id, user_id, false)
1236 };
1237
1238 let query = "
1239 UPDATE contacts
1240 SET should_notify = 'f'
1241 WHERE
1242 user_id_a = $1 AND user_id_b = $2 AND
1243 (
1244 (a_to_b = $3 AND accepted) OR
1245 (a_to_b != $3 AND NOT accepted)
1246 );
1247 ";
1248
1249 let result = sqlx::query(query)
1250 .bind(id_a.0)
1251 .bind(id_b.0)
1252 .bind(a_to_b)
1253 .execute(&self.pool)
1254 .await?;
1255
1256 if result.rows_affected() == 0 {
1257 Err(anyhow!("no such contact request"))?;
1258 }
1259
1260 Ok(())
1261 }
1262
1263 async fn respond_to_contact_request(
1264 &self,
1265 responder_id: UserId,
1266 requester_id: UserId,
1267 accept: bool,
1268 ) -> Result<()> {
1269 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1270 (responder_id, requester_id, false)
1271 } else {
1272 (requester_id, responder_id, true)
1273 };
1274 let result = if accept {
1275 let query = "
1276 UPDATE contacts
1277 SET accepted = 't', should_notify = 't'
1278 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1279 ";
1280 sqlx::query(query)
1281 .bind(id_a.0)
1282 .bind(id_b.0)
1283 .bind(a_to_b)
1284 .execute(&self.pool)
1285 .await?
1286 } else {
1287 let query = "
1288 DELETE FROM contacts
1289 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1290 ";
1291 sqlx::query(query)
1292 .bind(id_a.0)
1293 .bind(id_b.0)
1294 .bind(a_to_b)
1295 .execute(&self.pool)
1296 .await?
1297 };
1298 if result.rows_affected() == 1 {
1299 Ok(())
1300 } else {
1301 Err(anyhow!("no such contact request"))?
1302 }
1303 }
1304
1305 // access tokens
1306
1307 async fn create_access_token_hash(
1308 &self,
1309 user_id: UserId,
1310 access_token_hash: &str,
1311 max_access_token_count: usize,
1312 ) -> Result<()> {
1313 let insert_query = "
1314 INSERT INTO access_tokens (user_id, hash)
1315 VALUES ($1, $2);
1316 ";
1317 let cleanup_query = "
1318 DELETE FROM access_tokens
1319 WHERE id IN (
1320 SELECT id from access_tokens
1321 WHERE user_id = $1
1322 ORDER BY id DESC
1323 OFFSET $3
1324 )
1325 ";
1326
1327 let mut tx = self.pool.begin().await?;
1328 sqlx::query(insert_query)
1329 .bind(user_id.0)
1330 .bind(access_token_hash)
1331 .execute(&mut tx)
1332 .await?;
1333 sqlx::query(cleanup_query)
1334 .bind(user_id.0)
1335 .bind(access_token_hash)
1336 .bind(max_access_token_count as i32)
1337 .execute(&mut tx)
1338 .await?;
1339 Ok(tx.commit().await?)
1340 }
1341
1342 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1343 let query = "
1344 SELECT hash
1345 FROM access_tokens
1346 WHERE user_id = $1
1347 ORDER BY id DESC
1348 ";
1349 Ok(sqlx::query_scalar(query)
1350 .bind(user_id.0)
1351 .fetch_all(&self.pool)
1352 .await?)
1353 }
1354
1355 // orgs
1356
1357 #[allow(unused)] // Help rust-analyzer
1358 #[cfg(any(test, feature = "seed-support"))]
1359 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1360 let query = "
1361 SELECT *
1362 FROM orgs
1363 WHERE slug = $1
1364 ";
1365 Ok(sqlx::query_as(query)
1366 .bind(slug)
1367 .fetch_optional(&self.pool)
1368 .await?)
1369 }
1370
1371 #[cfg(any(test, feature = "seed-support"))]
1372 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1373 let query = "
1374 INSERT INTO orgs (name, slug)
1375 VALUES ($1, $2)
1376 RETURNING id
1377 ";
1378 Ok(sqlx::query_scalar(query)
1379 .bind(name)
1380 .bind(slug)
1381 .fetch_one(&self.pool)
1382 .await
1383 .map(OrgId)?)
1384 }
1385
1386 #[cfg(any(test, feature = "seed-support"))]
1387 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1388 let query = "
1389 INSERT INTO org_memberships (org_id, user_id, admin)
1390 VALUES ($1, $2, $3)
1391 ON CONFLICT DO NOTHING
1392 ";
1393 Ok(sqlx::query(query)
1394 .bind(org_id.0)
1395 .bind(user_id.0)
1396 .bind(is_admin)
1397 .execute(&self.pool)
1398 .await
1399 .map(drop)?)
1400 }
1401
1402 // channels
1403
1404 #[cfg(any(test, feature = "seed-support"))]
1405 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1406 let query = "
1407 INSERT INTO channels (owner_id, owner_is_user, name)
1408 VALUES ($1, false, $2)
1409 RETURNING id
1410 ";
1411 Ok(sqlx::query_scalar(query)
1412 .bind(org_id.0)
1413 .bind(name)
1414 .fetch_one(&self.pool)
1415 .await
1416 .map(ChannelId)?)
1417 }
1418
1419 #[allow(unused)] // Help rust-analyzer
1420 #[cfg(any(test, feature = "seed-support"))]
1421 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1422 let query = "
1423 SELECT *
1424 FROM channels
1425 WHERE
1426 channels.owner_is_user = false AND
1427 channels.owner_id = $1
1428 ";
1429 Ok(sqlx::query_as(query)
1430 .bind(org_id.0)
1431 .fetch_all(&self.pool)
1432 .await?)
1433 }
1434
1435 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1436 let query = "
1437 SELECT
1438 channels.*
1439 FROM
1440 channel_memberships, channels
1441 WHERE
1442 channel_memberships.user_id = $1 AND
1443 channel_memberships.channel_id = channels.id
1444 ";
1445 Ok(sqlx::query_as(query)
1446 .bind(user_id.0)
1447 .fetch_all(&self.pool)
1448 .await?)
1449 }
1450
1451 async fn can_user_access_channel(
1452 &self,
1453 user_id: UserId,
1454 channel_id: ChannelId,
1455 ) -> Result<bool> {
1456 let query = "
1457 SELECT id
1458 FROM channel_memberships
1459 WHERE user_id = $1 AND channel_id = $2
1460 LIMIT 1
1461 ";
1462 Ok(sqlx::query_scalar::<_, i32>(query)
1463 .bind(user_id.0)
1464 .bind(channel_id.0)
1465 .fetch_optional(&self.pool)
1466 .await
1467 .map(|e| e.is_some())?)
1468 }
1469
1470 #[cfg(any(test, feature = "seed-support"))]
1471 async fn add_channel_member(
1472 &self,
1473 channel_id: ChannelId,
1474 user_id: UserId,
1475 is_admin: bool,
1476 ) -> Result<()> {
1477 let query = "
1478 INSERT INTO channel_memberships (channel_id, user_id, admin)
1479 VALUES ($1, $2, $3)
1480 ON CONFLICT DO NOTHING
1481 ";
1482 Ok(sqlx::query(query)
1483 .bind(channel_id.0)
1484 .bind(user_id.0)
1485 .bind(is_admin)
1486 .execute(&self.pool)
1487 .await
1488 .map(drop)?)
1489 }
1490
1491 // messages
1492
1493 async fn create_channel_message(
1494 &self,
1495 channel_id: ChannelId,
1496 sender_id: UserId,
1497 body: &str,
1498 timestamp: OffsetDateTime,
1499 nonce: u128,
1500 ) -> Result<MessageId> {
1501 let query = "
1502 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1503 VALUES ($1, $2, $3, $4, $5)
1504 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1505 RETURNING id
1506 ";
1507 Ok(sqlx::query_scalar(query)
1508 .bind(channel_id.0)
1509 .bind(sender_id.0)
1510 .bind(body)
1511 .bind(timestamp)
1512 .bind(Uuid::from_u128(nonce))
1513 .fetch_one(&self.pool)
1514 .await
1515 .map(MessageId)?)
1516 }
1517
1518 async fn get_channel_messages(
1519 &self,
1520 channel_id: ChannelId,
1521 count: usize,
1522 before_id: Option<MessageId>,
1523 ) -> Result<Vec<ChannelMessage>> {
1524 let query = r#"
1525 SELECT * FROM (
1526 SELECT
1527 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1528 FROM
1529 channel_messages
1530 WHERE
1531 channel_id = $1 AND
1532 id < $2
1533 ORDER BY id DESC
1534 LIMIT $3
1535 ) as recent_messages
1536 ORDER BY id ASC
1537 "#;
1538 Ok(sqlx::query_as(query)
1539 .bind(channel_id.0)
1540 .bind(before_id.unwrap_or(MessageId::MAX))
1541 .bind(count as i64)
1542 .fetch_all(&self.pool)
1543 .await?)
1544 }
1545
1546 #[cfg(test)]
1547 async fn teardown(&self, url: &str) {
1548 use util::ResultExt;
1549
1550 let query = "
1551 SELECT pg_terminate_backend(pg_stat_activity.pid)
1552 FROM pg_stat_activity
1553 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1554 ";
1555 sqlx::query(query).execute(&self.pool).await.log_err();
1556 self.pool.close().await;
1557 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1558 .await
1559 .log_err();
1560 }
1561
1562 #[cfg(test)]
1563 fn as_fake(&self) -> Option<&FakeDb> {
1564 None
1565 }
1566}
1567
/// Declares a strongly-typed newtype around an `i32` database id.
///
/// The generated type is `Copy` and ordered, and encodes exactly like the inner
/// integer for both sqlx and serde (`transparent`). It also gets `MAX`,
/// protocol conversions, and a `Display` impl that forwards to the inner `i32`.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from its protocol (`u64`) form using an `as` cast, so values
            /// that do not fit in an `i32` are truncated.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                Self(raw)
            }

            /// Converts the id to its protocol (`u64`) form using an `as` cast.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                let Self(raw) = self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Forward to the integer's Display so width/fill flags are honored.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1610
// `UserId`: newtype id for users (see `id_type!`).
id_type!(UserId);
/// A user account, decodable from a query row via `sqlx::FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// GitHub username.
    pub github_login: String,
    /// Numeric GitHub user id, when known.
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    /// Whether the user has admin rights (see `set_user_is_admin`).
    pub admin: bool,
    /// This user's invite code, if one exists.
    pub invite_code: Option<String>,
    /// Invite count for this user (see `set_invite_count_for_user`).
    pub invite_count: i32,
    /// Flag written by `set_user_connected_once`.
    pub connected_once: bool,
}
1623
// `ProjectId`: newtype id for projects (see `id_type!`).
id_type!(ProjectId);
/// A project registered by a host user.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user who registered the project.
    pub host_user_id: UserId,
    /// Set once the project has been unregistered.
    pub unregistered: bool,
}
1631
/// Per-user entry returned by `get_top_users_activity_summary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// Activity broken down by project.
    pub project_activity: Vec<ProjectActivitySummary>,
}
1638
/// Aggregated activity for one project within a [`UserActivitySummary`].
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    pub id: ProjectId,
    /// Total time attributed to the project.
    pub duration: Duration,
    // NOTE(review): presumably the peak number of simultaneous collaborators — confirm
    // against the query that fills it.
    pub max_collaborators: usize,
}
1645
/// One period of activity in a project, as returned by `get_user_activity_timeline`.
/// Timestamps serialize as ISO 8601.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    pub project_id: ProjectId,
    #[serde(with = "time::serde::iso8601")]
    pub start: OffsetDateTime,
    #[serde(with = "time::serde::iso8601")]
    pub end: OffsetDateTime,
    // NOTE(review): presumably file-extension -> count for the project's worktrees;
    // confirm against the query that fills it.
    pub extensions: HashMap<String, usize>,
}
1655
// `OrgId`: newtype id for organizations (see `id_type!`).
id_type!(OrgId);
/// An organization, looked up by `slug` in `find_org_by_slug`.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
1663
// `ChannelId`: newtype id for channels (see `id_type!`).
id_type!(ChannelId);
/// A chat channel.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Owning org's id when `owner_is_user` is false (as written by
    /// `create_org_channel`); presumably a user id otherwise — confirm.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1672
// `MessageId`: newtype id for channel messages (see `id_type!`).
id_type!(MessageId);
/// A message posted to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    pub sent_at: OffsetDateTime,
    /// Client-supplied value; `create_channel_message` deduplicates messages on it.
    pub nonce: Uuid,
}
1683
/// A contact relationship as seen from one user's perspective; `user_id` is always
/// the *other* user.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both users have agreed to be contacts.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A request this user sent that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// A request this user received that is still pending.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1698
1699impl Contact {
1700 pub fn user_id(&self) -> UserId {
1701 match self {
1702 Contact::Accepted { user_id, .. } => *user_id,
1703 Contact::Outgoing { user_id } => *user_id,
1704 Contact::Incoming { user_id, .. } => *user_id,
1705 }
1706 }
1707}
1708
/// A pending contact request received from `requester_id`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    /// Whether the recipient should still be notified about this request.
    pub should_notify: bool,
}
1714
/// A waitlist signup, deserialized from the request payload passed to `create_signup`.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    // Platforms the person selected.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
1725
/// Aggregate counts returned by `get_waitlist_summary`.
///
/// Every field is `#[sqlx(default)]`, so a column missing from the row decodes as 0.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1739
/// An invitation addressed to `email_address`, identified by its confirmation code.
#[derive(Clone, FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1745
/// An invite that has not been sent yet, as returned by `get_unsent_invites`.
#[derive(Serialize, Debug, PartialEq)]
pub struct UnsentInvite {
    // `flatten` serializes the invite's fields alongside `platform_unknown` in one object.
    #[serde(flatten)]
    pub invite: Invite,
    pub platform_unknown: bool,
}
1752
/// GitHub identity and invite allowance used when creating a user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
1759
/// Outcome of `create_user` / `create_user_from_invite`.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// The user whose invite led to this account, if any.
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
1767
1768fn random_invite_code() -> String {
1769 nanoid::nanoid!(16)
1770}
1771
1772fn random_email_confirmation_code() -> String {
1773 nanoid::nanoid!(64)
1774}
1775
1776#[cfg(test)]
1777pub use test::*;
1778
1779#[cfg(test)]
1780mod test {
1781 use super::*;
1782 use anyhow::anyhow;
1783 use collections::BTreeMap;
1784 use gpui::executor::Background;
1785 use lazy_static::lazy_static;
1786 use parking_lot::Mutex;
1787 use rand::prelude::*;
1788 use sqlx::{
1789 migrate::{MigrateDatabase, Migrator},
1790 Postgres,
1791 };
1792 use std::{path::Path, sync::Arc};
1793 use util::post_inc;
1794
    /// In-memory stand-in for the Postgres-backed [`Db`], used by tests.
    ///
    /// Each "table" lives behind a `parking_lot::Mutex`; the `next_*` counters hand
    /// out ids via `post_inc`.
    pub struct FakeDb {
        /// Executor whose `simulate_random_delay` is awaited by most methods.
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        /// Extension counts keyed by `(project id, worktree id, extension)`.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// Org memberships; the value is the member's `is_admin` flag.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// Channel memberships; the value is the member's `is_admin` flag.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Contact relationships; `send_contact_request` keeps at most one entry per pair.
        pub contacts: Mutex<Vec<FakeContact>>,
        // Next ids to hand out for each table.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
1812
    /// One contact relationship in [`FakeDb`], stored in request direction.
    #[derive(Debug)]
    pub struct FakeContact {
        /// The user who sent the request.
        pub requester_id: UserId,
        /// The user who received the request.
        pub responder_id: UserId,
        /// Whether the request has been accepted.
        pub accepted: bool,
        /// Whether a notification is still pending for this relationship.
        pub should_notify: bool,
    }
1820
1821 impl FakeDb {
1822 pub fn new(background: Arc<Background>) -> Self {
1823 Self {
1824 background,
1825 users: Default::default(),
1826 next_user_id: Mutex::new(0),
1827 projects: Default::default(),
1828 worktree_extensions: Default::default(),
1829 next_project_id: Mutex::new(1),
1830 orgs: Default::default(),
1831 next_org_id: Mutex::new(1),
1832 org_memberships: Default::default(),
1833 channels: Default::default(),
1834 next_channel_id: Mutex::new(1),
1835 channel_memberships: Default::default(),
1836 channel_messages: Default::default(),
1837 next_channel_message_id: Mutex::new(1),
1838 contacts: Default::default(),
1839 }
1840 }
1841 }
1842
1843 #[async_trait]
1844 impl Db for FakeDb {
1845 async fn create_user(
1846 &self,
1847 email_address: &str,
1848 admin: bool,
1849 params: NewUserParams,
1850 ) -> Result<NewUserResult> {
1851 self.background.simulate_random_delay().await;
1852
1853 let mut users = self.users.lock();
1854 let user_id = if let Some(user) = users
1855 .values()
1856 .find(|user| user.github_login == params.github_login)
1857 {
1858 user.id
1859 } else {
1860 let id = post_inc(&mut *self.next_user_id.lock());
1861 let user_id = UserId(id);
1862 users.insert(
1863 user_id,
1864 User {
1865 id: user_id,
1866 github_login: params.github_login,
1867 github_user_id: Some(params.github_user_id),
1868 email_address: Some(email_address.to_string()),
1869 admin,
1870 invite_code: None,
1871 invite_count: 0,
1872 connected_once: false,
1873 },
1874 );
1875 user_id
1876 };
1877 Ok(NewUserResult {
1878 user_id,
1879 metrics_id: "the-metrics-id".to_string(),
1880 inviting_user_id: None,
1881 signup_device_id: None,
1882 })
1883 }
1884
1885 async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
1886 unimplemented!()
1887 }
1888
1889 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1890 unimplemented!()
1891 }
1892
1893 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1894 self.background.simulate_random_delay().await;
1895 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1896 }
1897
1898 async fn get_user_metrics_id(&self, _id: UserId) -> Result<String> {
1899 Ok("the-metrics-id".to_string())
1900 }
1901
1902 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1903 self.background.simulate_random_delay().await;
1904 let users = self.users.lock();
1905 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1906 }
1907
1908 async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
1909 unimplemented!()
1910 }
1911
1912 async fn get_user_by_github_account(
1913 &self,
1914 github_login: &str,
1915 github_user_id: Option<i32>,
1916 ) -> Result<Option<User>> {
1917 self.background.simulate_random_delay().await;
1918 if let Some(github_user_id) = github_user_id {
1919 for user in self.users.lock().values_mut() {
1920 if user.github_user_id == Some(github_user_id) {
1921 user.github_login = github_login.into();
1922 return Ok(Some(user.clone()));
1923 }
1924 if user.github_login == github_login {
1925 user.github_user_id = Some(github_user_id);
1926 return Ok(Some(user.clone()));
1927 }
1928 }
1929 Ok(None)
1930 } else {
1931 Ok(self
1932 .users
1933 .lock()
1934 .values()
1935 .find(|user| user.github_login == github_login)
1936 .cloned())
1937 }
1938 }
1939
1940 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1941 unimplemented!()
1942 }
1943
1944 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1945 self.background.simulate_random_delay().await;
1946 let mut users = self.users.lock();
1947 let mut user = users
1948 .get_mut(&id)
1949 .ok_or_else(|| anyhow!("user not found"))?;
1950 user.connected_once = connected_once;
1951 Ok(())
1952 }
1953
1954 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1955 unimplemented!()
1956 }
1957
1958 // signups
1959
1960 async fn create_signup(&self, _signup: Signup) -> Result<()> {
1961 unimplemented!()
1962 }
1963
1964 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
1965 unimplemented!()
1966 }
1967
1968 async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<UnsentInvite>> {
1969 unimplemented!()
1970 }
1971
1972 async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
1973 unimplemented!()
1974 }
1975
1976 async fn create_user_from_invite(
1977 &self,
1978 _invite: &Invite,
1979 _user: NewUserParams,
1980 ) -> Result<NewUserResult> {
1981 unimplemented!()
1982 }
1983
1984 // invite codes
1985
1986 async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
1987 unimplemented!()
1988 }
1989
1990 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
1991 self.background.simulate_random_delay().await;
1992 Ok(None)
1993 }
1994
1995 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
1996 unimplemented!()
1997 }
1998
1999 async fn create_invite_from_code(
2000 &self,
2001 _code: &str,
2002 _email_address: &str,
2003 _device_id: Option<&str>,
2004 ) -> Result<Invite> {
2005 unimplemented!()
2006 }
2007
2008 // projects
2009
2010 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
2011 self.background.simulate_random_delay().await;
2012 if !self.users.lock().contains_key(&host_user_id) {
2013 Err(anyhow!("no such user"))?;
2014 }
2015
2016 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
2017 self.projects.lock().insert(
2018 project_id,
2019 Project {
2020 id: project_id,
2021 host_user_id,
2022 unregistered: false,
2023 },
2024 );
2025 Ok(project_id)
2026 }
2027
2028 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
2029 self.background.simulate_random_delay().await;
2030 self.projects
2031 .lock()
2032 .get_mut(&project_id)
2033 .ok_or_else(|| anyhow!("no such project"))?
2034 .unregistered = true;
2035 Ok(())
2036 }
2037
2038 async fn update_worktree_extensions(
2039 &self,
2040 project_id: ProjectId,
2041 worktree_id: u64,
2042 extensions: HashMap<String, u32>,
2043 ) -> Result<()> {
2044 self.background.simulate_random_delay().await;
2045 if !self.projects.lock().contains_key(&project_id) {
2046 Err(anyhow!("no such project"))?;
2047 }
2048
2049 for (extension, count) in extensions {
2050 self.worktree_extensions
2051 .lock()
2052 .insert((project_id, worktree_id, extension), count);
2053 }
2054
2055 Ok(())
2056 }
2057
2058 async fn get_project_extensions(
2059 &self,
2060 _project_id: ProjectId,
2061 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
2062 unimplemented!()
2063 }
2064
2065 async fn record_user_activity(
2066 &self,
2067 _time_period: Range<OffsetDateTime>,
2068 _active_projects: &[(UserId, ProjectId)],
2069 ) -> Result<()> {
2070 unimplemented!()
2071 }
2072
2073 async fn get_active_user_count(
2074 &self,
2075 _time_period: Range<OffsetDateTime>,
2076 _min_duration: Duration,
2077 _only_collaborative: bool,
2078 ) -> Result<usize> {
2079 unimplemented!()
2080 }
2081
2082 async fn get_top_users_activity_summary(
2083 &self,
2084 _time_period: Range<OffsetDateTime>,
2085 _limit: usize,
2086 ) -> Result<Vec<UserActivitySummary>> {
2087 unimplemented!()
2088 }
2089
2090 async fn get_user_activity_timeline(
2091 &self,
2092 _time_period: Range<OffsetDateTime>,
2093 _user_id: UserId,
2094 ) -> Result<Vec<UserActivityPeriod>> {
2095 unimplemented!()
2096 }
2097
2098 // contacts
2099
2100 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2101 self.background.simulate_random_delay().await;
2102 let mut contacts = Vec::new();
2103
2104 for contact in self.contacts.lock().iter() {
2105 if contact.requester_id == id {
2106 if contact.accepted {
2107 contacts.push(Contact::Accepted {
2108 user_id: contact.responder_id,
2109 should_notify: contact.should_notify,
2110 });
2111 } else {
2112 contacts.push(Contact::Outgoing {
2113 user_id: contact.responder_id,
2114 });
2115 }
2116 } else if contact.responder_id == id {
2117 if contact.accepted {
2118 contacts.push(Contact::Accepted {
2119 user_id: contact.requester_id,
2120 should_notify: false,
2121 });
2122 } else {
2123 contacts.push(Contact::Incoming {
2124 user_id: contact.requester_id,
2125 should_notify: contact.should_notify,
2126 });
2127 }
2128 }
2129 }
2130
2131 contacts.sort_unstable_by_key(|contact| contact.user_id());
2132 Ok(contacts)
2133 }
2134
2135 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2136 self.background.simulate_random_delay().await;
2137 Ok(self.contacts.lock().iter().any(|contact| {
2138 contact.accepted
2139 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2140 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2141 }))
2142 }
2143
2144 async fn send_contact_request(
2145 &self,
2146 requester_id: UserId,
2147 responder_id: UserId,
2148 ) -> Result<()> {
2149 self.background.simulate_random_delay().await;
2150 let mut contacts = self.contacts.lock();
2151 for contact in contacts.iter_mut() {
2152 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2153 if contact.accepted {
2154 Err(anyhow!("contact already exists"))?;
2155 } else {
2156 Err(anyhow!("contact already requested"))?;
2157 }
2158 }
2159 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2160 if contact.accepted {
2161 Err(anyhow!("contact already exists"))?;
2162 } else {
2163 contact.accepted = true;
2164 contact.should_notify = false;
2165 return Ok(());
2166 }
2167 }
2168 }
2169 contacts.push(FakeContact {
2170 requester_id,
2171 responder_id,
2172 accepted: false,
2173 should_notify: true,
2174 });
2175 Ok(())
2176 }
2177
2178 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2179 self.background.simulate_random_delay().await;
2180 self.contacts.lock().retain(|contact| {
2181 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2182 });
2183 Ok(())
2184 }
2185
2186 async fn dismiss_contact_notification(
2187 &self,
2188 user_id: UserId,
2189 contact_user_id: UserId,
2190 ) -> Result<()> {
2191 self.background.simulate_random_delay().await;
2192 let mut contacts = self.contacts.lock();
2193 for contact in contacts.iter_mut() {
2194 if contact.requester_id == contact_user_id
2195 && contact.responder_id == user_id
2196 && !contact.accepted
2197 {
2198 contact.should_notify = false;
2199 return Ok(());
2200 }
2201 if contact.requester_id == user_id
2202 && contact.responder_id == contact_user_id
2203 && contact.accepted
2204 {
2205 contact.should_notify = false;
2206 return Ok(());
2207 }
2208 }
2209 Err(anyhow!("no such notification"))?
2210 }
2211
2212 async fn respond_to_contact_request(
2213 &self,
2214 responder_id: UserId,
2215 requester_id: UserId,
2216 accept: bool,
2217 ) -> Result<()> {
2218 self.background.simulate_random_delay().await;
2219 let mut contacts = self.contacts.lock();
2220 for (ix, contact) in contacts.iter_mut().enumerate() {
2221 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2222 if contact.accepted {
2223 Err(anyhow!("contact already confirmed"))?;
2224 }
2225 if accept {
2226 contact.accepted = true;
2227 contact.should_notify = true;
2228 } else {
2229 contacts.remove(ix);
2230 }
2231 return Ok(());
2232 }
2233 }
2234 Err(anyhow!("no such contact request"))?
2235 }
2236
2237 async fn create_access_token_hash(
2238 &self,
2239 _user_id: UserId,
2240 _access_token_hash: &str,
2241 _max_access_token_count: usize,
2242 ) -> Result<()> {
2243 unimplemented!()
2244 }
2245
2246 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
2247 unimplemented!()
2248 }
2249
2250 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2251 unimplemented!()
2252 }
2253
2254 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2255 self.background.simulate_random_delay().await;
2256 let mut orgs = self.orgs.lock();
2257 if orgs.values().any(|org| org.slug == slug) {
2258 Err(anyhow!("org already exists"))?
2259 } else {
2260 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2261 orgs.insert(
2262 org_id,
2263 Org {
2264 id: org_id,
2265 name: name.to_string(),
2266 slug: slug.to_string(),
2267 },
2268 );
2269 Ok(org_id)
2270 }
2271 }
2272
2273 async fn add_org_member(
2274 &self,
2275 org_id: OrgId,
2276 user_id: UserId,
2277 is_admin: bool,
2278 ) -> Result<()> {
2279 self.background.simulate_random_delay().await;
2280 if !self.orgs.lock().contains_key(&org_id) {
2281 Err(anyhow!("org does not exist"))?;
2282 }
2283 if !self.users.lock().contains_key(&user_id) {
2284 Err(anyhow!("user does not exist"))?;
2285 }
2286
2287 self.org_memberships
2288 .lock()
2289 .entry((org_id, user_id))
2290 .or_insert(is_admin);
2291 Ok(())
2292 }
2293
2294 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2295 self.background.simulate_random_delay().await;
2296 if !self.orgs.lock().contains_key(&org_id) {
2297 Err(anyhow!("org does not exist"))?;
2298 }
2299
2300 let mut channels = self.channels.lock();
2301 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2302 channels.insert(
2303 channel_id,
2304 Channel {
2305 id: channel_id,
2306 name: name.to_string(),
2307 owner_id: org_id.0,
2308 owner_is_user: false,
2309 },
2310 );
2311 Ok(channel_id)
2312 }
2313
2314 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2315 self.background.simulate_random_delay().await;
2316 Ok(self
2317 .channels
2318 .lock()
2319 .values()
2320 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2321 .cloned()
2322 .collect())
2323 }
2324
2325 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2326 self.background.simulate_random_delay().await;
2327 let channels = self.channels.lock();
2328 let memberships = self.channel_memberships.lock();
2329 Ok(channels
2330 .values()
2331 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2332 .cloned()
2333 .collect())
2334 }
2335
2336 async fn can_user_access_channel(
2337 &self,
2338 user_id: UserId,
2339 channel_id: ChannelId,
2340 ) -> Result<bool> {
2341 self.background.simulate_random_delay().await;
2342 Ok(self
2343 .channel_memberships
2344 .lock()
2345 .contains_key(&(channel_id, user_id)))
2346 }
2347
2348 async fn add_channel_member(
2349 &self,
2350 channel_id: ChannelId,
2351 user_id: UserId,
2352 is_admin: bool,
2353 ) -> Result<()> {
2354 self.background.simulate_random_delay().await;
2355 if !self.channels.lock().contains_key(&channel_id) {
2356 Err(anyhow!("channel does not exist"))?;
2357 }
2358 if !self.users.lock().contains_key(&user_id) {
2359 Err(anyhow!("user does not exist"))?;
2360 }
2361
2362 self.channel_memberships
2363 .lock()
2364 .entry((channel_id, user_id))
2365 .or_insert(is_admin);
2366 Ok(())
2367 }
2368
2369 async fn create_channel_message(
2370 &self,
2371 channel_id: ChannelId,
2372 sender_id: UserId,
2373 body: &str,
2374 timestamp: OffsetDateTime,
2375 nonce: u128,
2376 ) -> Result<MessageId> {
2377 self.background.simulate_random_delay().await;
2378 if !self.channels.lock().contains_key(&channel_id) {
2379 Err(anyhow!("channel does not exist"))?;
2380 }
2381 if !self.users.lock().contains_key(&sender_id) {
2382 Err(anyhow!("user does not exist"))?;
2383 }
2384
2385 let mut messages = self.channel_messages.lock();
2386 if let Some(message) = messages
2387 .values()
2388 .find(|message| message.nonce.as_u128() == nonce)
2389 {
2390 Ok(message.id)
2391 } else {
2392 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2393 messages.insert(
2394 message_id,
2395 ChannelMessage {
2396 id: message_id,
2397 channel_id,
2398 sender_id,
2399 body: body.to_string(),
2400 sent_at: timestamp,
2401 nonce: Uuid::from_u128(nonce),
2402 },
2403 );
2404 Ok(message_id)
2405 }
2406 }
2407
2408 async fn get_channel_messages(
2409 &self,
2410 channel_id: ChannelId,
2411 count: usize,
2412 before_id: Option<MessageId>,
2413 ) -> Result<Vec<ChannelMessage>> {
2414 self.background.simulate_random_delay().await;
2415 let mut messages = self
2416 .channel_messages
2417 .lock()
2418 .values()
2419 .rev()
2420 .filter(|message| {
2421 message.channel_id == channel_id
2422 && message.id < before_id.unwrap_or(MessageId::MAX)
2423 })
2424 .take(count)
2425 .cloned()
2426 .collect::<Vec<_>>();
2427 messages.sort_unstable_by_key(|message| message.id);
2428 Ok(messages)
2429 }
2430
2431 async fn teardown(&self, _: &str) {}
2432
2433 #[cfg(test)]
2434 fn as_fake(&self) -> Option<&FakeDb> {
2435 Some(self)
2436 }
2437 }
2438
    /// A database handle for a single test; `Drop` tears the database down.
    pub struct TestDb {
        /// Set to `Some` by both constructors and taken by `Drop` for teardown.
        pub db: Option<Arc<dyn Db>>,
        /// URL of the Postgres test database; empty for the fake.
        pub url: String,
    }
2443
2444 impl TestDb {
2445 #[allow(clippy::await_holding_lock)]
2446 pub async fn postgres() -> Self {
2447 lazy_static! {
2448 static ref LOCK: Mutex<()> = Mutex::new(());
2449 }
2450
2451 let _guard = LOCK.lock();
2452 let mut rng = StdRng::from_entropy();
2453 let name = format!("zed-test-{}", rng.gen::<u128>());
2454 let url = format!("postgres://postgres@localhost/{}", name);
2455 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2456 Postgres::create_database(&url)
2457 .await
2458 .expect("failed to create test db");
2459 let db = PostgresDb::new(&url, 5).await.unwrap();
2460 let migrator = Migrator::new(migrations_path).await.unwrap();
2461 migrator.run(&db.pool).await.unwrap();
2462 Self {
2463 db: Some(Arc::new(db)),
2464 url,
2465 }
2466 }
2467
2468 pub fn fake(background: Arc<Background>) -> Self {
2469 Self {
2470 db: Some(Arc::new(FakeDb::new(background))),
2471 url: Default::default(),
2472 }
2473 }
2474
2475 pub fn db(&self) -> &Arc<dyn Db> {
2476 self.db.as_ref().unwrap()
2477 }
2478 }
2479
2480 impl Drop for TestDb {
2481 fn drop(&mut self) {
2482 if let Some(db) = self.db.take() {
2483 futures::executor::block_on(db.teardown(&self.url));
2484 }
2485 }
2486 }
2487}