1use crate::{Error, Result};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use axum::http::StatusCode;
5use collections::HashMap;
6use futures::StreamExt;
7use serde::{Deserialize, Serialize};
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9use sqlx::{types::Uuid, FromRow, QueryBuilder};
10use std::{cmp, ops::Range, time::Duration};
11use time::{OffsetDateTime, PrimitiveDateTime};
12
/// Database interface covering users, invites and signups, projects and
/// activity metrics, contacts, access tokens, orgs and channels.
///
/// `PostgresDb` implements it on top of a `sqlx` Postgres pool. Methods are
/// async through `#[async_trait]`, and implementors must be `Send + Sync`.
/// Declarations marked `#[cfg(any(test, feature = "seed-support"))]` only exist
/// in test / seed-support builds; `#[cfg(test)]` ones only exist in tests.
#[async_trait]
pub trait Db: Send + Sync {
    // ---- users ----

    /// Create a user with the given email and GitHub identity, returning its id.
    async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<UserId>;
    /// Return one page of users: `page` is zero-based, `limit` is the page size.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Search users by GitHub login, returning at most `limit` matches.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Users with zero invites left, filtered by whether they have an inviter.
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    /// Look up a user by GitHub login and, when `github_user_id` is given,
    /// reconcile the stored login / GitHub user id with it.
    async fn get_user_by_github_account(
        &self,
        github_login: &str,
        github_user_id: Option<i32>,
    ) -> Result<Option<User>>;
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    // ---- invite codes ----

    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
    /// The user's invite code and remaining invite count, if a code exists.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    async fn create_invite_from_code(
        &self,
        code: &str,
        email_address: &str,
        device_id: Option<&str>,
    ) -> Result<Invite>;

    // ---- signups ----

    async fn create_signup(&self, signup: Signup) -> Result<()>;
    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
    async fn create_user_from_invite(
        &self,
        invite: &Invite,
        user: NewUserParams,
    ) -> Result<NewUserResult>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, u32>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_user_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the number of users who have been active in the given
    /// time period for at least the given time duration.
    async fn get_active_user_count(
        &self,
        time_period: Range<OffsetDateTime>,
        min_duration: Duration,
        only_collaborative: bool,
    ) -> Result<usize>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn get_top_users_activity_summary(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Get the project activity for the given user and time period.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>>;

    // ---- contacts ----

    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    // ---- access tokens ----

    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;

    // ---- orgs and channels ----

    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    #[cfg(any(test, feature = "seed-support"))]
    // NOTE(review): despite the blank line below, this attribute applies to
    // `get_org_channels`, so that method only exists in test / seed-support
    // builds and every implementor must carry the same `cfg`. Confirm this is
    // intended rather than a leftover.

    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;

    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    /// Up to `count` messages from `channel_id`, optionally restricted to
    /// those before `before_id`.
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;

    #[cfg(test)]
    async fn teardown(&self, url: &str);

    // Test-only escape hatch to reach a `FakeDb` implementation, if this is one.
    #[cfg(test)]
    fn as_fake(&self) -> Option<&FakeDb>;
}
174
175pub struct PostgresDb {
176 pool: sqlx::PgPool,
177}
178
179impl PostgresDb {
180 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
181 let pool = DbOptions::new()
182 .max_connections(max_connections)
183 .connect(url)
184 .await
185 .context("failed to connect to postgres database")?;
186 Ok(Self { pool })
187 }
188
189 pub fn fuzzy_like_string(string: &str) -> String {
190 let mut result = String::with_capacity(string.len() * 2 + 1);
191 for c in string.chars() {
192 if c.is_alphanumeric() {
193 result.push('%');
194 result.push(c);
195 }
196 }
197 result.push('%');
198 result
199 }
200}
201
202#[async_trait]
203impl Db for PostgresDb {
204 // users
205
206 async fn create_user(
207 &self,
208 email_address: &str,
209 admin: bool,
210 params: NewUserParams,
211 ) -> Result<UserId> {
212 let query = "
213 INSERT INTO users (email_address, github_login, github_user_id, admin)
214 VALUES ($1, $2, $3, $4)
215 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
216 RETURNING id
217 ";
218 Ok(sqlx::query_scalar(query)
219 .bind(email_address)
220 .bind(params.github_login)
221 .bind(params.github_user_id)
222 .bind(admin)
223 .fetch_one(&self.pool)
224 .await
225 .map(UserId)?)
226 }
227
228 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
229 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
230 Ok(sqlx::query_as(query)
231 .bind(limit as i32)
232 .bind((page * limit) as i32)
233 .fetch_all(&self.pool)
234 .await?)
235 }
236
237 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
238 let like_string = Self::fuzzy_like_string(name_query);
239 let query = "
240 SELECT users.*
241 FROM users
242 WHERE github_login ILIKE $1
243 ORDER BY github_login <-> $2
244 LIMIT $3
245 ";
246 Ok(sqlx::query_as(query)
247 .bind(like_string)
248 .bind(name_query)
249 .bind(limit as i32)
250 .fetch_all(&self.pool)
251 .await?)
252 }
253
254 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
255 let users = self.get_users_by_ids(vec![id]).await?;
256 Ok(users.into_iter().next())
257 }
258
259 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
260 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
261 let query = "
262 SELECT users.*
263 FROM users
264 WHERE users.id = ANY ($1)
265 ";
266 Ok(sqlx::query_as(query)
267 .bind(&ids)
268 .fetch_all(&self.pool)
269 .await?)
270 }
271
272 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
273 let query = format!(
274 "
275 SELECT users.*
276 FROM users
277 WHERE invite_count = 0
278 AND inviter_id IS{} NULL
279 ",
280 if invited_by_another_user { " NOT" } else { "" }
281 );
282
283 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
284 }
285
286 async fn get_user_by_github_account(
287 &self,
288 github_login: &str,
289 github_user_id: Option<i32>,
290 ) -> Result<Option<User>> {
291 if let Some(github_user_id) = github_user_id {
292 let mut user = sqlx::query_as::<_, User>(
293 "
294 UPDATE users
295 SET github_login = $1
296 WHERE github_user_id = $2
297 RETURNING *
298 ",
299 )
300 .bind(github_login)
301 .bind(github_user_id)
302 .fetch_optional(&self.pool)
303 .await?;
304
305 if user.is_none() {
306 user = sqlx::query_as::<_, User>(
307 "
308 UPDATE users
309 SET github_user_id = $1
310 WHERE github_login = $2
311 RETURNING *
312 ",
313 )
314 .bind(github_user_id)
315 .bind(github_login)
316 .fetch_optional(&self.pool)
317 .await?;
318 }
319
320 Ok(user)
321 } else {
322 Ok(sqlx::query_as(
323 "
324 SELECT * FROM users
325 WHERE github_login = $1
326 LIMIT 1
327 ",
328 )
329 .bind(github_login)
330 .fetch_optional(&self.pool)
331 .await?)
332 }
333 }
334
335 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
336 let query = "UPDATE users SET admin = $1 WHERE id = $2";
337 Ok(sqlx::query(query)
338 .bind(is_admin)
339 .bind(id.0)
340 .execute(&self.pool)
341 .await
342 .map(drop)?)
343 }
344
345 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
346 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
347 Ok(sqlx::query(query)
348 .bind(connected_once)
349 .bind(id.0)
350 .execute(&self.pool)
351 .await
352 .map(drop)?)
353 }
354
355 async fn destroy_user(&self, id: UserId) -> Result<()> {
356 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
357 sqlx::query(query)
358 .bind(id.0)
359 .execute(&self.pool)
360 .await
361 .map(drop)?;
362 let query = "DELETE FROM users WHERE id = $1;";
363 Ok(sqlx::query(query)
364 .bind(id.0)
365 .execute(&self.pool)
366 .await
367 .map(drop)?)
368 }
369
370 // signups
371
372 async fn create_signup(&self, signup: Signup) -> Result<()> {
373 sqlx::query(
374 "
375 INSERT INTO signups
376 (
377 email_address,
378 email_confirmation_code,
379 email_confirmation_sent,
380 platform_linux,
381 platform_mac,
382 platform_windows,
383 platform_unknown,
384 editor_features,
385 programming_languages,
386 device_id
387 )
388 VALUES
389 ($1, $2, 'f', $3, $4, $5, 'f', $6, $7, $8)
390 RETURNING id
391 ",
392 )
393 .bind(&signup.email_address)
394 .bind(&random_email_confirmation_code())
395 .bind(&signup.platform_linux)
396 .bind(&signup.platform_mac)
397 .bind(&signup.platform_windows)
398 .bind(&signup.editor_features)
399 .bind(&signup.programming_languages)
400 .bind(&signup.device_id)
401 .execute(&self.pool)
402 .await?;
403 Ok(())
404 }
405
406 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
407 Ok(sqlx::query_as(
408 "
409 SELECT
410 COUNT(*) as count,
411 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
412 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
413 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count
414 FROM (
415 SELECT *
416 FROM signups
417 WHERE
418 NOT email_confirmation_sent
419 ) AS unsent
420 ",
421 )
422 .fetch_one(&self.pool)
423 .await?)
424 }
425
426 async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
427 Ok(sqlx::query_as(
428 "
429 SELECT
430 email_address, email_confirmation_code
431 FROM signups
432 WHERE
433 NOT email_confirmation_sent AND
434 platform_mac
435 LIMIT $1
436 ",
437 )
438 .bind(count as i32)
439 .fetch_all(&self.pool)
440 .await?)
441 }
442
443 async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
444 sqlx::query(
445 "
446 UPDATE signups
447 SET email_confirmation_sent = 't'
448 WHERE email_address = ANY ($1)
449 ",
450 )
451 .bind(
452 &invites
453 .iter()
454 .map(|s| s.email_address.as_str())
455 .collect::<Vec<_>>(),
456 )
457 .execute(&self.pool)
458 .await?;
459 Ok(())
460 }
461
462 async fn create_user_from_invite(
463 &self,
464 invite: &Invite,
465 user: NewUserParams,
466 ) -> Result<NewUserResult> {
467 let mut tx = self.pool.begin().await?;
468
469 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
470 i32,
471 Option<UserId>,
472 Option<UserId>,
473 Option<String>,
474 ) = sqlx::query_as(
475 "
476 SELECT id, user_id, inviting_user_id, device_id
477 FROM signups
478 WHERE
479 email_address = $1 AND
480 email_confirmation_code = $2
481 ",
482 )
483 .bind(&invite.email_address)
484 .bind(&invite.email_confirmation_code)
485 .fetch_optional(&mut tx)
486 .await?
487 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
488
489 if existing_user_id.is_some() {
490 Err(Error::Http(
491 StatusCode::UNPROCESSABLE_ENTITY,
492 "invitation already redeemed".to_string(),
493 ))?;
494 }
495
496 let user_id: UserId = sqlx::query_scalar(
497 "
498 INSERT INTO users
499 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
500 VALUES
501 ($1, $2, $3, 'f', $4, $5)
502 RETURNING id
503 ",
504 )
505 .bind(&invite.email_address)
506 .bind(&user.github_login)
507 .bind(&user.github_user_id)
508 .bind(&user.invite_count)
509 .bind(random_invite_code())
510 .fetch_one(&mut tx)
511 .await?;
512
513 sqlx::query(
514 "
515 UPDATE signups
516 SET user_id = $1
517 WHERE id = $2
518 ",
519 )
520 .bind(&user_id)
521 .bind(&signup_id)
522 .execute(&mut tx)
523 .await?;
524
525 if let Some(inviting_user_id) = inviting_user_id {
526 let id: Option<UserId> = sqlx::query_scalar(
527 "
528 UPDATE users
529 SET invite_count = invite_count - 1
530 WHERE id = $1 AND invite_count > 0
531 RETURNING id
532 ",
533 )
534 .bind(&inviting_user_id)
535 .fetch_optional(&mut tx)
536 .await?;
537
538 if id.is_none() {
539 Err(Error::Http(
540 StatusCode::UNAUTHORIZED,
541 "no invites remaining".to_string(),
542 ))?;
543 }
544
545 sqlx::query(
546 "
547 INSERT INTO contacts
548 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
549 VALUES
550 ($1, $2, 't', 't', 't')
551 ",
552 )
553 .bind(inviting_user_id)
554 .bind(user_id)
555 .execute(&mut tx)
556 .await?;
557 }
558
559 tx.commit().await?;
560 Ok(NewUserResult {
561 user_id,
562 inviting_user_id,
563 signup_device_id,
564 })
565 }
566
567 // invite codes
568
569 async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
570 let mut tx = self.pool.begin().await?;
571 if count > 0 {
572 sqlx::query(
573 "
574 UPDATE users
575 SET invite_code = $1
576 WHERE id = $2 AND invite_code IS NULL
577 ",
578 )
579 .bind(random_invite_code())
580 .bind(id)
581 .execute(&mut tx)
582 .await?;
583 }
584
585 sqlx::query(
586 "
587 UPDATE users
588 SET invite_count = $1
589 WHERE id = $2
590 ",
591 )
592 .bind(count as i32)
593 .bind(id)
594 .execute(&mut tx)
595 .await?;
596 tx.commit().await?;
597 Ok(())
598 }
599
600 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
601 let result: Option<(String, i32)> = sqlx::query_as(
602 "
603 SELECT invite_code, invite_count
604 FROM users
605 WHERE id = $1 AND invite_code IS NOT NULL
606 ",
607 )
608 .bind(id)
609 .fetch_optional(&self.pool)
610 .await?;
611 if let Some((code, count)) = result {
612 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
613 } else {
614 Ok(None)
615 }
616 }
617
618 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
619 sqlx::query_as(
620 "
621 SELECT *
622 FROM users
623 WHERE invite_code = $1
624 ",
625 )
626 .bind(code)
627 .fetch_optional(&self.pool)
628 .await?
629 .ok_or_else(|| {
630 Error::Http(
631 StatusCode::NOT_FOUND,
632 "that invite code does not exist".to_string(),
633 )
634 })
635 }
636
637 async fn create_invite_from_code(
638 &self,
639 code: &str,
640 email_address: &str,
641 device_id: Option<&str>,
642 ) -> Result<Invite> {
643 let mut tx = self.pool.begin().await?;
644
645 let existing_user: Option<UserId> = sqlx::query_scalar(
646 "
647 SELECT id
648 FROM users
649 WHERE email_address = $1
650 ",
651 )
652 .bind(email_address)
653 .fetch_optional(&mut tx)
654 .await?;
655 if existing_user.is_some() {
656 Err(anyhow!("email address is already in use"))?;
657 }
658
659 let row: Option<(UserId, i32)> = sqlx::query_as(
660 "
661 SELECT id, invite_count
662 FROM users
663 WHERE invite_code = $1
664 ",
665 )
666 .bind(code)
667 .fetch_optional(&mut tx)
668 .await?;
669
670 let (inviter_id, invite_count) = match row {
671 Some(row) => row,
672 None => Err(Error::Http(
673 StatusCode::NOT_FOUND,
674 "invite code not found".to_string(),
675 ))?,
676 };
677
678 if invite_count == 0 {
679 Err(Error::Http(
680 StatusCode::UNAUTHORIZED,
681 "no invites remaining".to_string(),
682 ))?;
683 }
684
685 let email_confirmation_code: String = sqlx::query_scalar(
686 "
687 INSERT INTO signups
688 (
689 email_address,
690 email_confirmation_code,
691 email_confirmation_sent,
692 inviting_user_id,
693 platform_linux,
694 platform_mac,
695 platform_windows,
696 platform_unknown,
697 device_id
698 )
699 VALUES
700 ($1, $2, 'f', $3, 'f', 'f', 'f', 't', $4)
701 ON CONFLICT (email_address)
702 DO UPDATE SET
703 inviting_user_id = excluded.inviting_user_id
704 RETURNING email_confirmation_code
705 ",
706 )
707 .bind(&email_address)
708 .bind(&random_email_confirmation_code())
709 .bind(&inviter_id)
710 .bind(&device_id)
711 .fetch_one(&mut tx)
712 .await?;
713
714 tx.commit().await?;
715
716 Ok(Invite {
717 email_address: email_address.into(),
718 email_confirmation_code,
719 })
720 }
721
722 // projects
723
724 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
725 Ok(sqlx::query_scalar(
726 "
727 INSERT INTO projects(host_user_id)
728 VALUES ($1)
729 RETURNING id
730 ",
731 )
732 .bind(host_user_id)
733 .fetch_one(&self.pool)
734 .await
735 .map(ProjectId)?)
736 }
737
738 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
739 sqlx::query(
740 "
741 UPDATE projects
742 SET unregistered = 't'
743 WHERE id = $1
744 ",
745 )
746 .bind(project_id)
747 .execute(&self.pool)
748 .await?;
749 Ok(())
750 }
751
752 async fn update_worktree_extensions(
753 &self,
754 project_id: ProjectId,
755 worktree_id: u64,
756 extensions: HashMap<String, u32>,
757 ) -> Result<()> {
758 if extensions.is_empty() {
759 return Ok(());
760 }
761
762 let mut query = QueryBuilder::new(
763 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
764 );
765 query.push_values(extensions, |mut query, (extension, count)| {
766 query
767 .push_bind(project_id)
768 .push_bind(worktree_id as i32)
769 .push_bind(extension)
770 .push_bind(count as i32);
771 });
772 query.push(
773 "
774 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
775 count = excluded.count
776 ",
777 );
778 query.build().execute(&self.pool).await?;
779
780 Ok(())
781 }
782
783 async fn get_project_extensions(
784 &self,
785 project_id: ProjectId,
786 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
787 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
788 struct WorktreeExtension {
789 worktree_id: i32,
790 extension: String,
791 count: i32,
792 }
793
794 let query = "
795 SELECT worktree_id, extension, count
796 FROM worktree_extensions
797 WHERE project_id = $1
798 ";
799 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
800 .bind(&project_id)
801 .fetch_all(&self.pool)
802 .await?;
803
804 let mut extension_counts = HashMap::default();
805 for count in counts {
806 extension_counts
807 .entry(count.worktree_id as u64)
808 .or_insert_with(HashMap::default)
809 .insert(count.extension, count.count as usize);
810 }
811 Ok(extension_counts)
812 }
813
814 async fn record_user_activity(
815 &self,
816 time_period: Range<OffsetDateTime>,
817 projects: &[(UserId, ProjectId)],
818 ) -> Result<()> {
819 let query = "
820 INSERT INTO project_activity_periods
821 (ended_at, duration_millis, user_id, project_id)
822 VALUES
823 ($1, $2, $3, $4);
824 ";
825
826 let mut tx = self.pool.begin().await?;
827 let duration_millis =
828 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
829 for (user_id, project_id) in projects {
830 sqlx::query(query)
831 .bind(time_period.end)
832 .bind(duration_millis)
833 .bind(user_id)
834 .bind(project_id)
835 .execute(&mut tx)
836 .await?;
837 }
838 tx.commit().await?;
839
840 Ok(())
841 }
842
843 async fn get_active_user_count(
844 &self,
845 time_period: Range<OffsetDateTime>,
846 min_duration: Duration,
847 only_collaborative: bool,
848 ) -> Result<usize> {
849 let mut with_clause = String::new();
850 with_clause.push_str("WITH\n");
851 with_clause.push_str(
852 "
853 project_durations AS (
854 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
855 FROM project_activity_periods
856 WHERE $1 < ended_at AND ended_at <= $2
857 GROUP BY user_id, project_id
858 ),
859 ",
860 );
861 with_clause.push_str(
862 "
863 project_collaborators as (
864 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
865 FROM project_durations
866 GROUP BY project_id
867 ),
868 ",
869 );
870
871 if only_collaborative {
872 with_clause.push_str(
873 "
874 user_durations AS (
875 SELECT user_id, SUM(project_duration) as total_duration
876 FROM project_durations, project_collaborators
877 WHERE
878 project_durations.project_id = project_collaborators.project_id AND
879 max_collaborators > 1
880 GROUP BY user_id
881 ORDER BY total_duration DESC
882 LIMIT $3
883 )
884 ",
885 );
886 } else {
887 with_clause.push_str(
888 "
889 user_durations AS (
890 SELECT user_id, SUM(project_duration) as total_duration
891 FROM project_durations
892 GROUP BY user_id
893 ORDER BY total_duration DESC
894 LIMIT $3
895 )
896 ",
897 );
898 }
899
900 let query = format!(
901 "
902 {with_clause}
903 SELECT count(user_durations.user_id)
904 FROM user_durations
905 WHERE user_durations.total_duration >= $3
906 "
907 );
908
909 let count: i64 = sqlx::query_scalar(&query)
910 .bind(time_period.start)
911 .bind(time_period.end)
912 .bind(min_duration.as_millis() as i64)
913 .fetch_one(&self.pool)
914 .await?;
915 Ok(count as usize)
916 }
917
918 async fn get_top_users_activity_summary(
919 &self,
920 time_period: Range<OffsetDateTime>,
921 max_user_count: usize,
922 ) -> Result<Vec<UserActivitySummary>> {
923 let query = "
924 WITH
925 project_durations AS (
926 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
927 FROM project_activity_periods
928 WHERE $1 < ended_at AND ended_at <= $2
929 GROUP BY user_id, project_id
930 ),
931 user_durations AS (
932 SELECT user_id, SUM(project_duration) as total_duration
933 FROM project_durations
934 GROUP BY user_id
935 ORDER BY total_duration DESC
936 LIMIT $3
937 ),
938 project_collaborators as (
939 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
940 FROM project_durations
941 GROUP BY project_id
942 )
943 SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
944 FROM user_durations, project_durations, project_collaborators, users
945 WHERE
946 user_durations.user_id = project_durations.user_id AND
947 user_durations.user_id = users.id AND
948 project_durations.project_id = project_collaborators.project_id
949 ORDER BY total_duration DESC, user_id ASC, project_id ASC
950 ";
951
952 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
953 .bind(time_period.start)
954 .bind(time_period.end)
955 .bind(max_user_count as i32)
956 .fetch(&self.pool);
957
958 let mut result = Vec::<UserActivitySummary>::new();
959 while let Some(row) = rows.next().await {
960 let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
961 let project_id = project_id;
962 let duration = Duration::from_millis(duration_millis as u64);
963 let project_activity = ProjectActivitySummary {
964 id: project_id,
965 duration,
966 max_collaborators: project_collaborators as usize,
967 };
968 if let Some(last_summary) = result.last_mut() {
969 if last_summary.id == user_id {
970 last_summary.project_activity.push(project_activity);
971 continue;
972 }
973 }
974 result.push(UserActivitySummary {
975 id: user_id,
976 project_activity: vec![project_activity],
977 github_login,
978 });
979 }
980
981 Ok(result)
982 }
983
984 async fn get_user_activity_timeline(
985 &self,
986 time_period: Range<OffsetDateTime>,
987 user_id: UserId,
988 ) -> Result<Vec<UserActivityPeriod>> {
989 const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
990
991 let query = "
992 SELECT
993 project_activity_periods.ended_at,
994 project_activity_periods.duration_millis,
995 project_activity_periods.project_id,
996 worktree_extensions.extension,
997 worktree_extensions.count
998 FROM project_activity_periods
999 LEFT OUTER JOIN
1000 worktree_extensions
1001 ON
1002 project_activity_periods.project_id = worktree_extensions.project_id
1003 WHERE
1004 project_activity_periods.user_id = $1 AND
1005 $2 < project_activity_periods.ended_at AND
1006 project_activity_periods.ended_at <= $3
1007 ORDER BY project_activity_periods.id ASC
1008 ";
1009
1010 let mut rows = sqlx::query_as::<
1011 _,
1012 (
1013 PrimitiveDateTime,
1014 i32,
1015 ProjectId,
1016 Option<String>,
1017 Option<i32>,
1018 ),
1019 >(query)
1020 .bind(user_id)
1021 .bind(time_period.start)
1022 .bind(time_period.end)
1023 .fetch(&self.pool);
1024
1025 let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
1026 while let Some(row) = rows.next().await {
1027 let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
1028 let ended_at = ended_at.assume_utc();
1029 let duration = Duration::from_millis(duration_millis as u64);
1030 let started_at = ended_at - duration;
1031 let project_time_periods = time_periods.entry(project_id).or_default();
1032
1033 if let Some(prev_duration) = project_time_periods.last_mut() {
1034 if started_at <= prev_duration.end + COALESCE_THRESHOLD
1035 && ended_at >= prev_duration.start
1036 {
1037 prev_duration.end = cmp::max(prev_duration.end, ended_at);
1038 } else {
1039 project_time_periods.push(UserActivityPeriod {
1040 project_id,
1041 start: started_at,
1042 end: ended_at,
1043 extensions: Default::default(),
1044 });
1045 }
1046 } else {
1047 project_time_periods.push(UserActivityPeriod {
1048 project_id,
1049 start: started_at,
1050 end: ended_at,
1051 extensions: Default::default(),
1052 });
1053 }
1054
1055 if let Some((extension, extension_count)) = extension.zip(extension_count) {
1056 project_time_periods
1057 .last_mut()
1058 .unwrap()
1059 .extensions
1060 .insert(extension, extension_count as usize);
1061 }
1062 }
1063
1064 let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
1065 durations.sort_unstable_by_key(|duration| duration.start);
1066 Ok(durations)
1067 }
1068
1069 // contacts
1070
1071 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1072 let query = "
1073 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1074 FROM contacts
1075 WHERE user_id_a = $1 OR user_id_b = $1;
1076 ";
1077
1078 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1079 .bind(user_id)
1080 .fetch(&self.pool);
1081
1082 let mut contacts = vec![Contact::Accepted {
1083 user_id,
1084 should_notify: false,
1085 }];
1086 while let Some(row) = rows.next().await {
1087 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1088
1089 if user_id_a == user_id {
1090 if accepted {
1091 contacts.push(Contact::Accepted {
1092 user_id: user_id_b,
1093 should_notify: should_notify && a_to_b,
1094 });
1095 } else if a_to_b {
1096 contacts.push(Contact::Outgoing { user_id: user_id_b })
1097 } else {
1098 contacts.push(Contact::Incoming {
1099 user_id: user_id_b,
1100 should_notify,
1101 });
1102 }
1103 } else if accepted {
1104 contacts.push(Contact::Accepted {
1105 user_id: user_id_a,
1106 should_notify: should_notify && !a_to_b,
1107 });
1108 } else if a_to_b {
1109 contacts.push(Contact::Incoming {
1110 user_id: user_id_a,
1111 should_notify,
1112 });
1113 } else {
1114 contacts.push(Contact::Outgoing { user_id: user_id_a });
1115 }
1116 }
1117
1118 contacts.sort_unstable_by_key(|contact| contact.user_id());
1119
1120 Ok(contacts)
1121 }
1122
1123 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1124 let (id_a, id_b) = if user_id_1 < user_id_2 {
1125 (user_id_1, user_id_2)
1126 } else {
1127 (user_id_2, user_id_1)
1128 };
1129
1130 let query = "
1131 SELECT 1 FROM contacts
1132 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
1133 LIMIT 1
1134 ";
1135 Ok(sqlx::query_scalar::<_, i32>(query)
1136 .bind(id_a.0)
1137 .bind(id_b.0)
1138 .fetch_optional(&self.pool)
1139 .await?
1140 .is_some())
1141 }
1142
1143 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1144 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1145 (sender_id, receiver_id, true)
1146 } else {
1147 (receiver_id, sender_id, false)
1148 };
1149 let query = "
1150 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1151 VALUES ($1, $2, $3, 'f', 't')
1152 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1153 SET
1154 accepted = 't',
1155 should_notify = 'f'
1156 WHERE
1157 NOT contacts.accepted AND
1158 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1159 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1160 ";
1161 let result = sqlx::query(query)
1162 .bind(id_a.0)
1163 .bind(id_b.0)
1164 .bind(a_to_b)
1165 .execute(&self.pool)
1166 .await?;
1167
1168 if result.rows_affected() == 1 {
1169 Ok(())
1170 } else {
1171 Err(anyhow!("contact already requested"))?
1172 }
1173 }
1174
1175 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1176 let (id_a, id_b) = if responder_id < requester_id {
1177 (responder_id, requester_id)
1178 } else {
1179 (requester_id, responder_id)
1180 };
1181 let query = "
1182 DELETE FROM contacts
1183 WHERE user_id_a = $1 AND user_id_b = $2;
1184 ";
1185 let result = sqlx::query(query)
1186 .bind(id_a.0)
1187 .bind(id_b.0)
1188 .execute(&self.pool)
1189 .await?;
1190
1191 if result.rows_affected() == 1 {
1192 Ok(())
1193 } else {
1194 Err(anyhow!("no such contact"))?
1195 }
1196 }
1197
1198 async fn dismiss_contact_notification(
1199 &self,
1200 user_id: UserId,
1201 contact_user_id: UserId,
1202 ) -> Result<()> {
1203 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1204 (user_id, contact_user_id, true)
1205 } else {
1206 (contact_user_id, user_id, false)
1207 };
1208
1209 let query = "
1210 UPDATE contacts
1211 SET should_notify = 'f'
1212 WHERE
1213 user_id_a = $1 AND user_id_b = $2 AND
1214 (
1215 (a_to_b = $3 AND accepted) OR
1216 (a_to_b != $3 AND NOT accepted)
1217 );
1218 ";
1219
1220 let result = sqlx::query(query)
1221 .bind(id_a.0)
1222 .bind(id_b.0)
1223 .bind(a_to_b)
1224 .execute(&self.pool)
1225 .await?;
1226
1227 if result.rows_affected() == 0 {
1228 Err(anyhow!("no such contact request"))?;
1229 }
1230
1231 Ok(())
1232 }
1233
1234 async fn respond_to_contact_request(
1235 &self,
1236 responder_id: UserId,
1237 requester_id: UserId,
1238 accept: bool,
1239 ) -> Result<()> {
1240 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1241 (responder_id, requester_id, false)
1242 } else {
1243 (requester_id, responder_id, true)
1244 };
1245 let result = if accept {
1246 let query = "
1247 UPDATE contacts
1248 SET accepted = 't', should_notify = 't'
1249 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1250 ";
1251 sqlx::query(query)
1252 .bind(id_a.0)
1253 .bind(id_b.0)
1254 .bind(a_to_b)
1255 .execute(&self.pool)
1256 .await?
1257 } else {
1258 let query = "
1259 DELETE FROM contacts
1260 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1261 ";
1262 sqlx::query(query)
1263 .bind(id_a.0)
1264 .bind(id_b.0)
1265 .bind(a_to_b)
1266 .execute(&self.pool)
1267 .await?
1268 };
1269 if result.rows_affected() == 1 {
1270 Ok(())
1271 } else {
1272 Err(anyhow!("no such contact request"))?
1273 }
1274 }
1275
1276 // access tokens
1277
1278 async fn create_access_token_hash(
1279 &self,
1280 user_id: UserId,
1281 access_token_hash: &str,
1282 max_access_token_count: usize,
1283 ) -> Result<()> {
1284 let insert_query = "
1285 INSERT INTO access_tokens (user_id, hash)
1286 VALUES ($1, $2);
1287 ";
1288 let cleanup_query = "
1289 DELETE FROM access_tokens
1290 WHERE id IN (
1291 SELECT id from access_tokens
1292 WHERE user_id = $1
1293 ORDER BY id DESC
1294 OFFSET $3
1295 )
1296 ";
1297
1298 let mut tx = self.pool.begin().await?;
1299 sqlx::query(insert_query)
1300 .bind(user_id.0)
1301 .bind(access_token_hash)
1302 .execute(&mut tx)
1303 .await?;
1304 sqlx::query(cleanup_query)
1305 .bind(user_id.0)
1306 .bind(access_token_hash)
1307 .bind(max_access_token_count as i32)
1308 .execute(&mut tx)
1309 .await?;
1310 Ok(tx.commit().await?)
1311 }
1312
1313 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1314 let query = "
1315 SELECT hash
1316 FROM access_tokens
1317 WHERE user_id = $1
1318 ORDER BY id DESC
1319 ";
1320 Ok(sqlx::query_scalar(query)
1321 .bind(user_id.0)
1322 .fetch_all(&self.pool)
1323 .await?)
1324 }
1325
1326 // orgs
1327
1328 #[allow(unused)] // Help rust-analyzer
1329 #[cfg(any(test, feature = "seed-support"))]
1330 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1331 let query = "
1332 SELECT *
1333 FROM orgs
1334 WHERE slug = $1
1335 ";
1336 Ok(sqlx::query_as(query)
1337 .bind(slug)
1338 .fetch_optional(&self.pool)
1339 .await?)
1340 }
1341
1342 #[cfg(any(test, feature = "seed-support"))]
1343 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1344 let query = "
1345 INSERT INTO orgs (name, slug)
1346 VALUES ($1, $2)
1347 RETURNING id
1348 ";
1349 Ok(sqlx::query_scalar(query)
1350 .bind(name)
1351 .bind(slug)
1352 .fetch_one(&self.pool)
1353 .await
1354 .map(OrgId)?)
1355 }
1356
1357 #[cfg(any(test, feature = "seed-support"))]
1358 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1359 let query = "
1360 INSERT INTO org_memberships (org_id, user_id, admin)
1361 VALUES ($1, $2, $3)
1362 ON CONFLICT DO NOTHING
1363 ";
1364 Ok(sqlx::query(query)
1365 .bind(org_id.0)
1366 .bind(user_id.0)
1367 .bind(is_admin)
1368 .execute(&self.pool)
1369 .await
1370 .map(drop)?)
1371 }
1372
1373 // channels
1374
1375 #[cfg(any(test, feature = "seed-support"))]
1376 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1377 let query = "
1378 INSERT INTO channels (owner_id, owner_is_user, name)
1379 VALUES ($1, false, $2)
1380 RETURNING id
1381 ";
1382 Ok(sqlx::query_scalar(query)
1383 .bind(org_id.0)
1384 .bind(name)
1385 .fetch_one(&self.pool)
1386 .await
1387 .map(ChannelId)?)
1388 }
1389
1390 #[allow(unused)] // Help rust-analyzer
1391 #[cfg(any(test, feature = "seed-support"))]
1392 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1393 let query = "
1394 SELECT *
1395 FROM channels
1396 WHERE
1397 channels.owner_is_user = false AND
1398 channels.owner_id = $1
1399 ";
1400 Ok(sqlx::query_as(query)
1401 .bind(org_id.0)
1402 .fetch_all(&self.pool)
1403 .await?)
1404 }
1405
1406 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1407 let query = "
1408 SELECT
1409 channels.*
1410 FROM
1411 channel_memberships, channels
1412 WHERE
1413 channel_memberships.user_id = $1 AND
1414 channel_memberships.channel_id = channels.id
1415 ";
1416 Ok(sqlx::query_as(query)
1417 .bind(user_id.0)
1418 .fetch_all(&self.pool)
1419 .await?)
1420 }
1421
1422 async fn can_user_access_channel(
1423 &self,
1424 user_id: UserId,
1425 channel_id: ChannelId,
1426 ) -> Result<bool> {
1427 let query = "
1428 SELECT id
1429 FROM channel_memberships
1430 WHERE user_id = $1 AND channel_id = $2
1431 LIMIT 1
1432 ";
1433 Ok(sqlx::query_scalar::<_, i32>(query)
1434 .bind(user_id.0)
1435 .bind(channel_id.0)
1436 .fetch_optional(&self.pool)
1437 .await
1438 .map(|e| e.is_some())?)
1439 }
1440
1441 #[cfg(any(test, feature = "seed-support"))]
1442 async fn add_channel_member(
1443 &self,
1444 channel_id: ChannelId,
1445 user_id: UserId,
1446 is_admin: bool,
1447 ) -> Result<()> {
1448 let query = "
1449 INSERT INTO channel_memberships (channel_id, user_id, admin)
1450 VALUES ($1, $2, $3)
1451 ON CONFLICT DO NOTHING
1452 ";
1453 Ok(sqlx::query(query)
1454 .bind(channel_id.0)
1455 .bind(user_id.0)
1456 .bind(is_admin)
1457 .execute(&self.pool)
1458 .await
1459 .map(drop)?)
1460 }
1461
1462 // messages
1463
1464 async fn create_channel_message(
1465 &self,
1466 channel_id: ChannelId,
1467 sender_id: UserId,
1468 body: &str,
1469 timestamp: OffsetDateTime,
1470 nonce: u128,
1471 ) -> Result<MessageId> {
1472 let query = "
1473 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1474 VALUES ($1, $2, $3, $4, $5)
1475 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1476 RETURNING id
1477 ";
1478 Ok(sqlx::query_scalar(query)
1479 .bind(channel_id.0)
1480 .bind(sender_id.0)
1481 .bind(body)
1482 .bind(timestamp)
1483 .bind(Uuid::from_u128(nonce))
1484 .fetch_one(&self.pool)
1485 .await
1486 .map(MessageId)?)
1487 }
1488
1489 async fn get_channel_messages(
1490 &self,
1491 channel_id: ChannelId,
1492 count: usize,
1493 before_id: Option<MessageId>,
1494 ) -> Result<Vec<ChannelMessage>> {
1495 let query = r#"
1496 SELECT * FROM (
1497 SELECT
1498 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1499 FROM
1500 channel_messages
1501 WHERE
1502 channel_id = $1 AND
1503 id < $2
1504 ORDER BY id DESC
1505 LIMIT $3
1506 ) as recent_messages
1507 ORDER BY id ASC
1508 "#;
1509 Ok(sqlx::query_as(query)
1510 .bind(channel_id.0)
1511 .bind(before_id.unwrap_or(MessageId::MAX))
1512 .bind(count as i64)
1513 .fetch_all(&self.pool)
1514 .await?)
1515 }
1516
1517 #[cfg(test)]
1518 async fn teardown(&self, url: &str) {
1519 use util::ResultExt;
1520
1521 let query = "
1522 SELECT pg_terminate_backend(pg_stat_activity.pid)
1523 FROM pg_stat_activity
1524 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1525 ";
1526 sqlx::query(query).execute(&self.pool).await.log_err();
1527 self.pool.close().await;
1528 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1529 .await
1530 .log_err();
1531 }
1532
1533 #[cfg(test)]
1534 fn as_fake(&self) -> Option<&FakeDb> {
1535 None
1536 }
1537}
1538
/// Declares a newtype wrapper `$name(pub i32)` for a database id.
///
/// The generated type is `Copy`, totally ordered and hashable, and is encoded
/// by both sqlx and serde exactly like the inner `i32` (`transparent`). It
/// displays as the bare number and gets `MAX`, `from_proto` and `to_proto`
/// helpers.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id; used e.g. as an open upper bound
            // when paging messages.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // NOTE(review): `as` truncates values above `i32::MAX` instead of
            // rejecting them.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // NOTE(review): a negative id would sign-extend to a huge `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        // Formats as the bare inner integer.
        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1581
// Identifier of a user account.
id_type!(UserId);
/// A user account as loaded from the database (via `FromRow`).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// Numeric GitHub id. It may be absent and can be backfilled by
    /// `get_user_by_github_account`.
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    pub invite_code: Option<String>,
    pub invite_count: i32,
    /// Set via `set_user_connected_once`.
    pub connected_once: bool,
}
1594
// Identifier of a registered project.
id_type!(ProjectId);
/// A project registered by a host user.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    /// Set once `unregister_project` has been called for this project.
    pub unregistered: bool,
}
1602
/// Per-user activity summary, as returned by `get_top_users_activity_summary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// One entry per project the user was active in.
    pub project_activity: Vec<ProjectActivitySummary>,
}
1609
/// Activity within a single project, as part of a [`UserActivitySummary`].
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    pub id: ProjectId,
    /// Total time spent in the project.
    pub duration: Duration,
    /// Largest number of collaborators observed in the project.
    pub max_collaborators: usize,
}
1616
/// A contiguous span of activity in one project, as returned by
/// `get_user_activity_timeline`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    pub project_id: ProjectId,
    /// Serialized as an ISO 8601 timestamp.
    #[serde(with = "time::serde::iso8601")]
    pub start: OffsetDateTime,
    /// Serialized as an ISO 8601 timestamp.
    #[serde(with = "time::serde::iso8601")]
    pub end: OffsetDateTime,
    // NOTE(review): presumably file-extension -> file count for the project;
    // confirm against the query that fills it.
    pub extensions: HashMap<String, usize>,
}
1626
// Identifier of an organization.
id_type!(OrgId);
/// An organization row; looked up by `slug` in `find_org_by_slug`.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
1634
// Identifier of a chat channel.
id_type!(ChannelId);
/// A chat channel.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// Raw owner id: an org id when `owner_is_user` is false (see
    /// `create_org_channel`), otherwise presumably a user id.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1643
// Identifier of a channel message.
id_type!(MessageId);
/// A message posted to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// Send time; `get_channel_messages` reads it back `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied nonce. Inserting a message with an existing nonce
    /// returns the existing message's id instead of creating a duplicate.
    pub nonce: Uuid,
}
1654
/// One entry in a user's contact list, seen from that user's perspective.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// A mutually accepted contact.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A request this user sent that `user_id` has not yet accepted.
    Outgoing {
        user_id: UserId,
    },
    /// A request `user_id` sent to this user that is still pending.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1669
1670impl Contact {
1671 pub fn user_id(&self) -> UserId {
1672 match self {
1673 Contact::Accepted { user_id, .. } => *user_id,
1674 Contact::Outgoing { user_id } => *user_id,
1675 Contact::Incoming { user_id, .. } => *user_id,
1676 }
1677 }
1678}
1679
/// A pending contact request addressed to the current user.
///
/// The derived ordering compares `requester_id` first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1685
/// A waitlist signup, deserialized from an incoming request and passed to
/// `create_signup`.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    // Platforms the person selected.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    // NOTE(review): presumably free-form survey answers; confirm.
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
1696
/// Aggregate waitlist counts, as returned by `get_waitlist_summary`.
///
/// Each field is `#[sqlx(default)]`, so it falls back to `0` when the row
/// lacks the corresponding column.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
}
1708
/// An invite addressed to an email address.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    // NOTE(review): presumably produced by `random_email_confirmation_code`.
    pub email_confirmation_code: String,
}
1714
/// GitHub identity and initial invite allowance for a user being created.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    // NOTE(review): the in-memory `FakeDb::create_user` ignores this and
    // stores 0.
    pub invite_count: i32,
}
1721
/// Outcome of `create_user_from_invite`.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    // NOTE(review): presumably the owner of the invite code used, if any.
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
1728
1729fn random_invite_code() -> String {
1730 nanoid::nanoid!(16)
1731}
1732
1733fn random_email_confirmation_code() -> String {
1734 nanoid::nanoid!(64)
1735}
1736
1737#[cfg(test)]
1738pub use test::*;
1739
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::anyhow;
    use collections::BTreeMap;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::{path::Path, sync::Arc};
    use util::post_inc;

    /// In-memory implementation of [`Db`] for tests that don't need Postgres.
    ///
    /// Each table is a mutex-guarded collection, and ids come from the
    /// `next_*` counters via `post_inc`. Implemented methods first await
    /// `Background::simulate_random_delay` (presumably to surface ordering
    /// assumptions in tests). Methods the tests don't need are
    /// `unimplemented!()`.
    pub struct FakeDb {
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        /// (project, worktree id, extension) -> count from
        /// `update_worktree_extensions`.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        /// (org, user) -> is_admin.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        /// (channel, user) -> is_admin.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        /// Directed contact requests in insertion order.
        pub contacts: Mutex<Vec<FakeContact>>,
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }

    /// A contact request from `requester_id` to `responder_id`.
    ///
    /// `accepted` becomes true once the responder accepts (or the responder
    /// sends a request back). `should_notify` tracks whether a notification
    /// is still pending.
    #[derive(Debug)]
    pub struct FakeContact {
        pub requester_id: UserId,
        pub responder_id: UserId,
        pub accepted: bool,
        pub should_notify: bool,
    }

    impl FakeDb {
        /// Creates an empty fake database that injects delays via `background`.
        ///
        /// User ids start at 0; every other id counter starts at 1.
        pub fn new(background: Arc<Background>) -> Self {
            Self {
                background,
                users: Default::default(),
                next_user_id: Mutex::new(0),
                projects: Default::default(),
                worktree_extensions: Default::default(),
                next_project_id: Mutex::new(1),
                orgs: Default::default(),
                next_org_id: Mutex::new(1),
                org_memberships: Default::default(),
                channels: Default::default(),
                next_channel_id: Mutex::new(1),
                channel_memberships: Default::default(),
                channel_messages: Default::default(),
                next_channel_message_id: Mutex::new(1),
                contacts: Default::default(),
            }
        }
    }

    #[async_trait]
    impl Db for FakeDb {
        /// Returns the id of an existing user with the same `github_login`
        /// (without modifying it), or inserts a new user.
        ///
        /// NOTE(review): `params.invite_count` is ignored; new users get 0.
        async fn create_user(
            &self,
            email_address: &str,
            admin: bool,
            params: NewUserParams,
        ) -> Result<UserId> {
            self.background.simulate_random_delay().await;

            let mut users = self.users.lock();
            if let Some(user) = users
                .values()
                .find(|user| user.github_login == params.github_login)
            {
                Ok(user.id)
            } else {
                let id = post_inc(&mut *self.next_user_id.lock());
                let user_id = UserId(id);
                users.insert(
                    user_id,
                    User {
                        id: user_id,
                        github_login: params.github_login,
                        github_user_id: Some(params.github_user_id),
                        email_address: Some(email_address.to_string()),
                        admin,
                        invite_code: None,
                        invite_count: 0,
                        connected_once: false,
                    },
                );
                Ok(user_id)
            }
        }

        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }

        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }

        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
            self.background.simulate_random_delay().await;
            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
        }

        /// Returns the users that exist among `ids`, in the order of `ids`.
        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
            self.background.simulate_random_delay().await;
            let users = self.users.lock();
            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
        }

        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }

        /// Finds a user by GitHub account.
        ///
        /// With a `github_user_id`, a user whose stored id matches is
        /// preferred, and its login is refreshed to `github_login`. Only if
        /// no id matches does a login match count; in that case the id is
        /// backfilled. Without an id, only the login is matched.
        async fn get_user_by_github_account(
            &self,
            github_login: &str,
            github_user_id: Option<i32>,
        ) -> Result<Option<User>> {
            self.background.simulate_random_delay().await;
            let mut users = self.users.lock();
            if let Some(github_user_id) = github_user_id {
                // The numeric id is authoritative (the login is overwritten
                // from it), so search by id across *all* users before falling
                // back to the login. A single combined pass would let an
                // earlier login match shadow a later exact id match.
                if let Some(user) = users
                    .values_mut()
                    .find(|user| user.github_user_id == Some(github_user_id))
                {
                    user.github_login = github_login.into();
                    return Ok(Some(user.clone()));
                }
                if let Some(user) = users
                    .values_mut()
                    .find(|user| user.github_login == github_login)
                {
                    user.github_user_id = Some(github_user_id);
                    return Ok(Some(user.clone()));
                }
                Ok(None)
            } else {
                Ok(users
                    .values()
                    .find(|user| user.github_login == github_login)
                    .cloned())
            }
        }

        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }

        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
            self.background.simulate_random_delay().await;
            let mut users = self.users.lock();
            let user = users
                .get_mut(&id)
                .ok_or_else(|| anyhow!("user not found"))?;
            user.connected_once = connected_once;
            Ok(())
        }

        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        // signups

        async fn create_signup(&self, _signup: Signup) -> Result<()> {
            unimplemented!()
        }

        async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
            unimplemented!()
        }

        async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<Invite>> {
            unimplemented!()
        }

        async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
            unimplemented!()
        }

        async fn create_user_from_invite(
            &self,
            _invite: &Invite,
            _user: NewUserParams,
        ) -> Result<NewUserResult> {
            unimplemented!()
        }

        // invite codes

        async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }

        /// The fake never hands out invite codes.
        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            self.background.simulate_random_delay().await;
            Ok(None)
        }

        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }

        async fn create_invite_from_code(
            &self,
            _code: &str,
            _email_address: &str,
            _device_id: Option<&str>,
        ) -> Result<Invite> {
            unimplemented!()
        }

        // projects

        /// Registers a new project hosted by an existing user.
        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
            self.background.simulate_random_delay().await;
            if !self.users.lock().contains_key(&host_user_id) {
                Err(anyhow!("no such user"))?;
            }

            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
            self.projects.lock().insert(
                project_id,
                Project {
                    id: project_id,
                    host_user_id,
                    unregistered: false,
                },
            );
            Ok(project_id)
        }

        /// Marks the project as unregistered; the record itself is kept.
        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
            self.background.simulate_random_delay().await;
            self.projects
                .lock()
                .get_mut(&project_id)
                .ok_or_else(|| anyhow!("no such project"))?
                .unregistered = true;
            Ok(())
        }

        /// Upserts the per-extension counts for one worktree of a project.
        async fn update_worktree_extensions(
            &self,
            project_id: ProjectId,
            worktree_id: u64,
            extensions: HashMap<String, u32>,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.projects.lock().contains_key(&project_id) {
                Err(anyhow!("no such project"))?;
            }

            // Lock once for the whole batch rather than once per entry.
            self.worktree_extensions.lock().extend(
                extensions
                    .into_iter()
                    .map(|(extension, count)| ((project_id, worktree_id, extension), count)),
            );

            Ok(())
        }

        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }

        async fn record_user_activity(
            &self,
            _time_period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }

        async fn get_active_user_count(
            &self,
            _time_period: Range<OffsetDateTime>,
            _min_duration: Duration,
            _only_collaborative: bool,
        ) -> Result<usize> {
            unimplemented!()
        }

        async fn get_top_users_activity_summary(
            &self,
            _time_period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }

        async fn get_user_activity_timeline(
            &self,
            _time_period: Range<OffsetDateTime>,
            _user_id: UserId,
        ) -> Result<Vec<UserActivityPeriod>> {
            unimplemented!()
        }

        // contacts

        /// Returns `id`'s contacts sorted by user id.
        ///
        /// The list always includes `id` itself as an accepted contact.
        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
            self.background.simulate_random_delay().await;
            let mut contacts = vec![Contact::Accepted {
                user_id: id,
                should_notify: false,
            }];

            for contact in self.contacts.lock().iter() {
                if contact.requester_id == id {
                    if contact.accepted {
                        contacts.push(Contact::Accepted {
                            user_id: contact.responder_id,
                            should_notify: contact.should_notify,
                        });
                    } else {
                        contacts.push(Contact::Outgoing {
                            user_id: contact.responder_id,
                        });
                    }
                } else if contact.responder_id == id {
                    if contact.accepted {
                        contacts.push(Contact::Accepted {
                            user_id: contact.requester_id,
                            should_notify: false,
                        });
                    } else {
                        contacts.push(Contact::Incoming {
                            user_id: contact.requester_id,
                            should_notify: contact.should_notify,
                        });
                    }
                }
            }

            contacts.sort_unstable_by_key(|contact| contact.user_id());
            Ok(contacts)
        }

        /// Whether an accepted contact exists between the two users, in
        /// either direction.
        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
            self.background.simulate_random_delay().await;
            Ok(self.contacts.lock().iter().any(|contact| {
                contact.accepted
                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
            }))
        }

        /// Records a request, or accepts a pending request in the opposite
        /// direction.
        async fn send_contact_request(
            &self,
            requester_id: UserId,
            responder_id: UserId,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            let mut contacts = self.contacts.lock();
            for contact in contacts.iter_mut() {
                if contact.requester_id == requester_id && contact.responder_id == responder_id {
                    if contact.accepted {
                        Err(anyhow!("contact already exists"))?;
                    } else {
                        Err(anyhow!("contact already requested"))?;
                    }
                }
                if contact.responder_id == requester_id && contact.requester_id == responder_id {
                    if contact.accepted {
                        Err(anyhow!("contact already exists"))?;
                    } else {
                        contact.accepted = true;
                        contact.should_notify = false;
                        return Ok(());
                    }
                }
            }
            contacts.push(FakeContact {
                requester_id,
                responder_id,
                accepted: false,
                should_notify: true,
            });
            Ok(())
        }

        /// Removes the contact between the two users, whichever of them sent
        /// the original request, mirroring the Postgres implementation which
        /// deletes by unordered pair.
        ///
        /// # Errors
        /// Fails with "no such contact" when nothing was removed.
        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
            self.background.simulate_random_delay().await;
            let mut contacts = self.contacts.lock();
            let len_before = contacts.len();
            contacts.retain(|contact| {
                let same_direction =
                    contact.requester_id == requester_id && contact.responder_id == responder_id;
                let reverse_direction =
                    contact.requester_id == responder_id && contact.responder_id == requester_id;
                !(same_direction || reverse_direction)
            });
            if contacts.len() == len_before {
                Err(anyhow!("no such contact"))?;
            }
            Ok(())
        }

        /// Clears `should_notify` on the first contact that is a notification
        /// for `user_id`: either a pending request from `contact_user_id`, or
        /// an accepted request `user_id` sent to `contact_user_id`.
        async fn dismiss_contact_notification(
            &self,
            user_id: UserId,
            contact_user_id: UserId,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            let mut contacts = self.contacts.lock();
            for contact in contacts.iter_mut() {
                if contact.requester_id == contact_user_id
                    && contact.responder_id == user_id
                    && !contact.accepted
                {
                    contact.should_notify = false;
                    return Ok(());
                }
                if contact.requester_id == user_id
                    && contact.responder_id == contact_user_id
                    && contact.accepted
                {
                    contact.should_notify = false;
                    return Ok(());
                }
            }
            Err(anyhow!("no such notification"))?
        }

        /// Accepts (and flags for notification) or deletes the request that
        /// `requester_id` sent to `responder_id`.
        ///
        /// # Errors
        /// "no such contact request" if there is no such request, and
        /// "contact already confirmed" if it was already accepted.
        async fn respond_to_contact_request(
            &self,
            responder_id: UserId,
            requester_id: UserId,
            accept: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            let mut contacts = self.contacts.lock();
            let ix = contacts
                .iter()
                .position(|contact| {
                    contact.requester_id == requester_id && contact.responder_id == responder_id
                })
                .ok_or_else(|| anyhow!("no such contact request"))?;
            if contacts[ix].accepted {
                Err(anyhow!("contact already confirmed"))?;
            }
            if accept {
                let contact = &mut contacts[ix];
                contact.accepted = true;
                contact.should_notify = true;
            } else {
                contacts.remove(ix);
            }
            Ok(())
        }

        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }

        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }

        /// Creates an org.
        ///
        /// # Errors
        /// Fails with "org already exists" if the slug is already taken.
        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
            self.background.simulate_random_delay().await;
            let mut orgs = self.orgs.lock();
            if orgs.values().any(|org| org.slug == slug) {
                Err(anyhow!("org already exists"))?
            } else {
                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
                orgs.insert(
                    org_id,
                    Org {
                        id: org_id,
                        name: name.to_string(),
                        slug: slug.to_string(),
                    },
                );
                Ok(org_id)
            }
        }

        /// Adds a membership. An existing membership keeps its admin flag.
        async fn add_org_member(
            &self,
            org_id: OrgId,
            user_id: UserId,
            is_admin: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.orgs.lock().contains_key(&org_id) {
                Err(anyhow!("org does not exist"))?;
            }
            if !self.users.lock().contains_key(&user_id) {
                Err(anyhow!("user does not exist"))?;
            }

            self.org_memberships
                .lock()
                .entry((org_id, user_id))
                .or_insert(is_admin);
            Ok(())
        }

        /// Creates a channel owned by an existing org.
        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
            self.background.simulate_random_delay().await;
            if !self.orgs.lock().contains_key(&org_id) {
                Err(anyhow!("org does not exist"))?;
            }

            let mut channels = self.channels.lock();
            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
            channels.insert(
                channel_id,
                Channel {
                    id: channel_id,
                    name: name.to_string(),
                    owner_id: org_id.0,
                    owner_is_user: false,
                },
            );
            Ok(channel_id)
        }

        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
            self.background.simulate_random_delay().await;
            Ok(self
                .channels
                .lock()
                .values()
                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
                .cloned()
                .collect())
        }

        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
            self.background.simulate_random_delay().await;
            let channels = self.channels.lock();
            let memberships = self.channel_memberships.lock();
            Ok(channels
                .values()
                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
                .cloned()
                .collect())
        }

        async fn can_user_access_channel(
            &self,
            user_id: UserId,
            channel_id: ChannelId,
        ) -> Result<bool> {
            self.background.simulate_random_delay().await;
            Ok(self
                .channel_memberships
                .lock()
                .contains_key(&(channel_id, user_id)))
        }

        /// Adds a membership. An existing membership keeps its admin flag.
        async fn add_channel_member(
            &self,
            channel_id: ChannelId,
            user_id: UserId,
            is_admin: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.channels.lock().contains_key(&channel_id) {
                Err(anyhow!("channel does not exist"))?;
            }
            if !self.users.lock().contains_key(&user_id) {
                Err(anyhow!("user does not exist"))?;
            }

            self.channel_memberships
                .lock()
                .entry((channel_id, user_id))
                .or_insert(is_admin);
            Ok(())
        }

        /// Stores a message, or returns the id of an existing message with
        /// the same `nonce`.
        async fn create_channel_message(
            &self,
            channel_id: ChannelId,
            sender_id: UserId,
            body: &str,
            timestamp: OffsetDateTime,
            nonce: u128,
        ) -> Result<MessageId> {
            self.background.simulate_random_delay().await;
            if !self.channels.lock().contains_key(&channel_id) {
                Err(anyhow!("channel does not exist"))?;
            }
            if !self.users.lock().contains_key(&sender_id) {
                Err(anyhow!("user does not exist"))?;
            }

            let mut messages = self.channel_messages.lock();
            if let Some(message) = messages
                .values()
                .find(|message| message.nonce.as_u128() == nonce)
            {
                Ok(message.id)
            } else {
                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
                messages.insert(
                    message_id,
                    ChannelMessage {
                        id: message_id,
                        channel_id,
                        sender_id,
                        body: body.to_string(),
                        sent_at: timestamp,
                        nonce: Uuid::from_u128(nonce),
                    },
                );
                Ok(message_id)
            }
        }

        /// Returns the newest `count` messages below `before_id`, in
        /// ascending id order.
        async fn get_channel_messages(
            &self,
            channel_id: ChannelId,
            count: usize,
            before_id: Option<MessageId>,
        ) -> Result<Vec<ChannelMessage>> {
            self.background.simulate_random_delay().await;
            let mut messages = self
                .channel_messages
                .lock()
                .values()
                .rev()
                .filter(|message| {
                    message.channel_id == channel_id
                        && message.id < before_id.unwrap_or(MessageId::MAX)
                })
                .take(count)
                .cloned()
                .collect::<Vec<_>>();
            messages.sort_unstable_by_key(|message| message.id);
            Ok(messages)
        }

        async fn teardown(&self, _: &str) {}

        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
    }

    /// Owns a test database (Postgres or fake) and tears it down on drop.
    pub struct TestDb {
        pub db: Option<Arc<dyn Db>>,
        /// Connection URL of the Postgres database; empty for the fake.
        pub url: String,
    }

    impl TestDb {
        /// Creates a uniquely named database on the local Postgres server and
        /// runs this crate's `migrations` against it.
        ///
        /// Creation is serialized through a process-wide lock, which is held
        /// across the awaits below (hence the clippy allowance).
        #[allow(clippy::await_holding_lock)]
        pub async fn postgres() -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            Postgres::create_database(&url)
                .await
                .expect("failed to create test db");
            let db = PostgresDb::new(&url, 5).await.unwrap();
            let migrator = Migrator::new(migrations_path).await.unwrap();
            migrator.run(&db.pool).await.unwrap();
            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Wraps a fresh in-memory [`FakeDb`].
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                url: Default::default(),
            }
        }

        /// Returns the database. Panics only if called after `drop` took it.
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        /// Runs the database's `teardown` synchronously.
        fn drop(&mut self) {
            if let Some(db) = self.db.take() {
                futures::executor::block_on(db.teardown(&self.url));
            }
        }
    }
}