1use crate::{Error, Result};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use axum::http::StatusCode;
5use collections::HashMap;
6use futures::StreamExt;
7use serde::{Deserialize, Serialize};
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9use sqlx::{types::Uuid, FromRow, QueryBuilder};
10use std::{cmp, ops::Range, time::Duration};
11use time::{OffsetDateTime, PrimitiveDateTime};
12
13#[async_trait]
14pub trait Db: Send + Sync {
15 async fn create_user(
16 &self,
17 email_address: &str,
18 admin: bool,
19 params: NewUserParams,
20 ) -> Result<UserId>;
21 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
22 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
23 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
24 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
25 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
26 async fn get_user_by_github_account(
27 &self,
28 github_login: &str,
29 github_user_id: Option<i32>,
30 ) -> Result<Option<User>>;
31 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
32 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
33 async fn destroy_user(&self, id: UserId) -> Result<()>;
34
35 async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
36 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
37 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
38 async fn create_invite_from_code(&self, code: &str, email_address: &str) -> Result<Invite>;
39
40 async fn create_signup(&self, signup: Signup) -> Result<i32>;
41 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
42 async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
43 async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
44 async fn create_user_from_invite(
45 &self,
46 invite: &Invite,
47 user: NewUserParams,
48 ) -> Result<(UserId, Option<UserId>)>;
49
50 /// Registers a new project for the given user.
51 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
52
53 /// Unregisters a project for the given project id.
54 async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
55
56 /// Update file counts by extension for the given project and worktree.
57 async fn update_worktree_extensions(
58 &self,
59 project_id: ProjectId,
60 worktree_id: u64,
61 extensions: HashMap<String, u32>,
62 ) -> Result<()>;
63
64 /// Get the file counts on the given project keyed by their worktree and extension.
65 async fn get_project_extensions(
66 &self,
67 project_id: ProjectId,
68 ) -> Result<HashMap<u64, HashMap<String, usize>>>;
69
70 /// Record which users have been active in which projects during
71 /// a given period of time.
72 async fn record_user_activity(
73 &self,
74 time_period: Range<OffsetDateTime>,
75 active_projects: &[(UserId, ProjectId)],
76 ) -> Result<()>;
77
78 /// Get the number of users who have been active in the given
79 /// time period for at least the given time duration.
80 async fn get_active_user_count(
81 &self,
82 time_period: Range<OffsetDateTime>,
83 min_duration: Duration,
84 only_collaborative: bool,
85 ) -> Result<usize>;
86
87 /// Get the users that have been most active during the given time period,
88 /// along with the amount of time they have been active in each project.
89 async fn get_top_users_activity_summary(
90 &self,
91 time_period: Range<OffsetDateTime>,
92 max_user_count: usize,
93 ) -> Result<Vec<UserActivitySummary>>;
94
95 /// Get the project activity for the given user and time period.
96 async fn get_user_activity_timeline(
97 &self,
98 time_period: Range<OffsetDateTime>,
99 user_id: UserId,
100 ) -> Result<Vec<UserActivityPeriod>>;
101
102 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
103 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
104 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
105 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
106 async fn dismiss_contact_notification(
107 &self,
108 responder_id: UserId,
109 requester_id: UserId,
110 ) -> Result<()>;
111 async fn respond_to_contact_request(
112 &self,
113 responder_id: UserId,
114 requester_id: UserId,
115 accept: bool,
116 ) -> Result<()>;
117
118 async fn create_access_token_hash(
119 &self,
120 user_id: UserId,
121 access_token_hash: &str,
122 max_access_token_count: usize,
123 ) -> Result<()>;
124 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
125
126 #[cfg(any(test, feature = "seed-support"))]
127 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
128 #[cfg(any(test, feature = "seed-support"))]
129 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
130 #[cfg(any(test, feature = "seed-support"))]
131 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
132 #[cfg(any(test, feature = "seed-support"))]
133 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
134 #[cfg(any(test, feature = "seed-support"))]
135
136 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
137 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
138 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
139 -> Result<bool>;
140
141 #[cfg(any(test, feature = "seed-support"))]
142 async fn add_channel_member(
143 &self,
144 channel_id: ChannelId,
145 user_id: UserId,
146 is_admin: bool,
147 ) -> Result<()>;
148 async fn create_channel_message(
149 &self,
150 channel_id: ChannelId,
151 sender_id: UserId,
152 body: &str,
153 timestamp: OffsetDateTime,
154 nonce: u128,
155 ) -> Result<MessageId>;
156 async fn get_channel_messages(
157 &self,
158 channel_id: ChannelId,
159 count: usize,
160 before_id: Option<MessageId>,
161 ) -> Result<Vec<ChannelMessage>>;
162
163 #[cfg(test)]
164 async fn teardown(&self, url: &str);
165
166 #[cfg(test)]
167 fn as_fake(&self) -> Option<&FakeDb>;
168}
169
170pub struct PostgresDb {
171 pool: sqlx::PgPool,
172}
173
174impl PostgresDb {
175 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
176 let pool = DbOptions::new()
177 .max_connections(max_connections)
178 .connect(url)
179 .await
180 .context("failed to connect to postgres database")?;
181 Ok(Self { pool })
182 }
183
184 pub fn fuzzy_like_string(string: &str) -> String {
185 let mut result = String::with_capacity(string.len() * 2 + 1);
186 for c in string.chars() {
187 if c.is_alphanumeric() {
188 result.push('%');
189 result.push(c);
190 }
191 }
192 result.push('%');
193 result
194 }
195}
196
197#[async_trait]
198impl Db for PostgresDb {
199 // users
200
201 async fn create_user(
202 &self,
203 email_address: &str,
204 admin: bool,
205 params: NewUserParams,
206 ) -> Result<UserId> {
207 let query = "
208 INSERT INTO users (email_address, github_login, github_user_id, admin)
209 VALUES ($1, $2, $3, $4)
210 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
211 RETURNING id
212 ";
213 Ok(sqlx::query_scalar(query)
214 .bind(email_address)
215 .bind(params.github_login)
216 .bind(params.github_user_id)
217 .bind(admin)
218 .fetch_one(&self.pool)
219 .await
220 .map(UserId)?)
221 }
222
223 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
224 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
225 Ok(sqlx::query_as(query)
226 .bind(limit as i32)
227 .bind((page * limit) as i32)
228 .fetch_all(&self.pool)
229 .await?)
230 }
231
232 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
233 let like_string = Self::fuzzy_like_string(name_query);
234 let query = "
235 SELECT users.*
236 FROM users
237 WHERE github_login ILIKE $1
238 ORDER BY github_login <-> $2
239 LIMIT $3
240 ";
241 Ok(sqlx::query_as(query)
242 .bind(like_string)
243 .bind(name_query)
244 .bind(limit as i32)
245 .fetch_all(&self.pool)
246 .await?)
247 }
248
249 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
250 let users = self.get_users_by_ids(vec![id]).await?;
251 Ok(users.into_iter().next())
252 }
253
254 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
255 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
256 let query = "
257 SELECT users.*
258 FROM users
259 WHERE users.id = ANY ($1)
260 ";
261 Ok(sqlx::query_as(query)
262 .bind(&ids)
263 .fetch_all(&self.pool)
264 .await?)
265 }
266
267 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
268 let query = format!(
269 "
270 SELECT users.*
271 FROM users
272 WHERE invite_count = 0
273 AND inviter_id IS{} NULL
274 ",
275 if invited_by_another_user { " NOT" } else { "" }
276 );
277
278 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
279 }
280
281 async fn get_user_by_github_account(
282 &self,
283 github_login: &str,
284 github_user_id: Option<i32>,
285 ) -> Result<Option<User>> {
286 if let Some(github_user_id) = github_user_id {
287 let mut user = sqlx::query_as::<_, User>(
288 "
289 UPDATE users
290 SET github_login = $1
291 WHERE github_user_id = $2
292 RETURNING *
293 ",
294 )
295 .bind(github_login)
296 .bind(github_user_id)
297 .fetch_optional(&self.pool)
298 .await?;
299
300 if user.is_none() {
301 user = sqlx::query_as::<_, User>(
302 "
303 UPDATE users
304 SET github_user_id = $1
305 WHERE github_login = $2
306 RETURNING *
307 ",
308 )
309 .bind(github_user_id)
310 .bind(github_login)
311 .fetch_optional(&self.pool)
312 .await?;
313 }
314
315 Ok(user)
316 } else {
317 Ok(sqlx::query_as(
318 "
319 SELECT * FROM users
320 WHERE github_login = $1
321 LIMIT 1
322 ",
323 )
324 .bind(github_login)
325 .fetch_optional(&self.pool)
326 .await?)
327 }
328 }
329
330 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
331 let query = "UPDATE users SET admin = $1 WHERE id = $2";
332 Ok(sqlx::query(query)
333 .bind(is_admin)
334 .bind(id.0)
335 .execute(&self.pool)
336 .await
337 .map(drop)?)
338 }
339
340 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
341 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
342 Ok(sqlx::query(query)
343 .bind(connected_once)
344 .bind(id.0)
345 .execute(&self.pool)
346 .await
347 .map(drop)?)
348 }
349
350 async fn destroy_user(&self, id: UserId) -> Result<()> {
351 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
352 sqlx::query(query)
353 .bind(id.0)
354 .execute(&self.pool)
355 .await
356 .map(drop)?;
357 let query = "DELETE FROM users WHERE id = $1;";
358 Ok(sqlx::query(query)
359 .bind(id.0)
360 .execute(&self.pool)
361 .await
362 .map(drop)?)
363 }
364
365 // signups
366
367 async fn create_signup(&self, signup: Signup) -> Result<i32> {
368 Ok(sqlx::query_scalar(
369 "
370 INSERT INTO signups
371 (
372 email_address,
373 email_confirmation_code,
374 email_confirmation_sent,
375 platform_linux,
376 platform_mac,
377 platform_windows,
378 platform_unknown,
379 editor_features,
380 programming_languages
381 )
382 VALUES
383 ($1, $2, 'f', $3, $4, $5, 'f', $6, $7)
384 RETURNING id
385 ",
386 )
387 .bind(&signup.email_address)
388 .bind(&random_email_confirmation_code())
389 .bind(&signup.platform_linux)
390 .bind(&signup.platform_mac)
391 .bind(&signup.platform_windows)
392 .bind(&signup.editor_features)
393 .bind(&signup.programming_languages)
394 .fetch_one(&self.pool)
395 .await?)
396 }
397
398 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
399 Ok(sqlx::query_as(
400 "
401 SELECT
402 COUNT(*) as count,
403 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
404 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
405 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count
406 FROM (
407 SELECT *
408 FROM signups
409 WHERE
410 NOT email_confirmation_sent
411 ) AS unsent
412 ",
413 )
414 .fetch_one(&self.pool)
415 .await?)
416 }
417
418 async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
419 Ok(sqlx::query_as(
420 "
421 SELECT
422 email_address, email_confirmation_code
423 FROM signups
424 WHERE
425 NOT email_confirmation_sent AND
426 platform_mac
427 LIMIT $1
428 ",
429 )
430 .bind(count as i32)
431 .fetch_all(&self.pool)
432 .await?)
433 }
434
435 async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
436 sqlx::query(
437 "
438 UPDATE signups
439 SET email_confirmation_sent = 't'
440 WHERE email_address = ANY ($1)
441 ",
442 )
443 .bind(
444 &invites
445 .iter()
446 .map(|s| s.email_address.as_str())
447 .collect::<Vec<_>>(),
448 )
449 .execute(&self.pool)
450 .await?;
451 Ok(())
452 }
453
/// Redeems an email invite by creating a user for it, all inside one transaction.
///
/// Steps:
/// 1. Find the signup row matching `invite`'s email address and confirmation code.
/// 2. Insert a non-admin user that inherits the signup's `metrics_id`.
/// 3. Link the signup to that user.
/// 4. If the signup names an inviting user, consume one of their invites and
///    insert an accepted contact between the inviter and the new user.
///
/// Returns the new user's id and the inviting user's id, if any.
///
/// # Errors
/// - `Error::Http(NOT_FOUND)` if no signup matches the invite.
/// - `Error::Http(UNPROCESSABLE_ENTITY)` if the signup is already linked to a user.
/// - `Error::Http(UNAUTHORIZED)` if the inviter has no invites remaining.
///
/// On any error, `tx` is dropped without `commit`. sqlx rolls back an
/// uncommitted transaction on drop, so nothing from the earlier steps persists.
async fn create_user_from_invite(
    &self,
    invite: &Invite,
    user: NewUserParams,
) -> Result<(UserId, Option<UserId>)> {
    let mut tx = self.pool.begin().await?;

    // NOTE(review): the signup row is read without FOR UPDATE, so two concurrent
    // redemptions could both see `user_id IS NULL` — confirm a constraint catches this.
    let (signup_id, metrics_id, existing_user_id, inviting_user_id): (
        i32,
        i32,
        Option<UserId>,
        Option<UserId>,
    ) = sqlx::query_as(
        "
        SELECT id, metrics_id, user_id, inviting_user_id
        FROM signups
        WHERE
            email_address = $1 AND
            email_confirmation_code = $2
        ",
    )
    .bind(&invite.email_address)
    .bind(&invite.email_confirmation_code)
    .fetch_optional(&mut tx)
    .await?
    .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;

    // A signup that already points at a user has been redeemed before.
    if existing_user_id.is_some() {
        Err(Error::Http(
            StatusCode::UNPROCESSABLE_ENTITY,
            "invitation already redeemed".to_string(),
        ))?;
    }

    // The new user gets its own invite code and reuses the signup's metrics_id.
    let user_id: UserId = sqlx::query_scalar(
        "
        INSERT INTO users
        (email_address, github_login, github_user_id, admin, invite_count, invite_code, metrics_id)
        VALUES
        ($1, $2, $3, 'f', $4, $5, $6)
        RETURNING id
        ",
    )
    .bind(&invite.email_address)
    .bind(&user.github_login)
    .bind(&user.github_user_id)
    .bind(&user.invite_count)
    .bind(random_invite_code())
    .bind(metrics_id)
    .fetch_one(&mut tx)
    .await?;

    // Mark the signup as redeemed by linking it to the new user.
    sqlx::query(
        "
        UPDATE signups
        SET user_id = $1
        WHERE id = $2
        ",
    )
    .bind(&user_id)
    .bind(&signup_id)
    .execute(&mut tx)
    .await?;

    if let Some(inviting_user_id) = inviting_user_id {
        // Consume one invite; the `invite_count > 0` guard makes RETURNING
        // produce no row when the inviter has none left.
        let id: Option<UserId> = sqlx::query_scalar(
            "
            UPDATE users
            SET invite_count = invite_count - 1
            WHERE id = $1 AND invite_count > 0
            RETURNING id
            ",
        )
        .bind(&inviting_user_id)
        .fetch_optional(&mut tx)
        .await?;

        if id.is_none() {
            Err(Error::Http(
                StatusCode::UNAUTHORIZED,
                "no invites remaining".to_string(),
            ))?;
        }

        // Inviter and invitee become accepted contacts, with a notification pending.
        // NOTE(review): the inviter is always stored as user_id_a. Other contact
        // queries order ids so user_id_a < user_id_b; this holds only if the new
        // user's id is larger than the inviter's (likely for a serial id) — confirm.
        sqlx::query(
            "
            INSERT INTO contacts
            (user_id_a, user_id_b, a_to_b, should_notify, accepted)
            VALUES
            ($1, $2, 't', 't', 't')
            ",
        )
        .bind(inviting_user_id)
        .bind(user_id)
        .execute(&mut tx)
        .await?;
    }

    tx.commit().await?;
    Ok((user_id, inviting_user_id))
}
555
556 // invite codes
557
558 async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
559 let mut tx = self.pool.begin().await?;
560 if count > 0 {
561 sqlx::query(
562 "
563 UPDATE users
564 SET invite_code = $1
565 WHERE id = $2 AND invite_code IS NULL
566 ",
567 )
568 .bind(random_invite_code())
569 .bind(id)
570 .execute(&mut tx)
571 .await?;
572 }
573
574 sqlx::query(
575 "
576 UPDATE users
577 SET invite_count = $1
578 WHERE id = $2
579 ",
580 )
581 .bind(count as i32)
582 .bind(id)
583 .execute(&mut tx)
584 .await?;
585 tx.commit().await?;
586 Ok(())
587 }
588
589 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
590 let result: Option<(String, i32)> = sqlx::query_as(
591 "
592 SELECT invite_code, invite_count
593 FROM users
594 WHERE id = $1 AND invite_code IS NOT NULL
595 ",
596 )
597 .bind(id)
598 .fetch_optional(&self.pool)
599 .await?;
600 if let Some((code, count)) = result {
601 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
602 } else {
603 Ok(None)
604 }
605 }
606
607 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
608 sqlx::query_as(
609 "
610 SELECT *
611 FROM users
612 WHERE invite_code = $1
613 ",
614 )
615 .bind(code)
616 .fetch_optional(&self.pool)
617 .await?
618 .ok_or_else(|| {
619 Error::Http(
620 StatusCode::NOT_FOUND,
621 "that invite code does not exist".to_string(),
622 )
623 })
624 }
625
626 async fn create_invite_from_code(&self, code: &str, email_address: &str) -> Result<Invite> {
627 let mut tx = self.pool.begin().await?;
628
629 let existing_user: Option<UserId> = sqlx::query_scalar(
630 "
631 SELECT id
632 FROM users
633 WHERE email_address = $1
634 ",
635 )
636 .bind(email_address)
637 .fetch_optional(&mut tx)
638 .await?;
639 if existing_user.is_some() {
640 Err(anyhow!("email address is already in use"))?;
641 }
642
643 let row: Option<(UserId, i32)> = sqlx::query_as(
644 "
645 SELECT id, invite_count
646 FROM users
647 WHERE invite_code = $1
648 ",
649 )
650 .bind(code)
651 .fetch_optional(&mut tx)
652 .await?;
653
654 let (inviter_id, invite_count) = match row {
655 Some(row) => row,
656 None => Err(Error::Http(
657 StatusCode::NOT_FOUND,
658 "invite code not found".to_string(),
659 ))?,
660 };
661
662 if invite_count == 0 {
663 Err(Error::Http(
664 StatusCode::UNAUTHORIZED,
665 "no invites remaining".to_string(),
666 ))?;
667 }
668
669 let email_confirmation_code: String = sqlx::query_scalar(
670 "
671 INSERT INTO signups
672 (
673 email_address,
674 email_confirmation_code,
675 email_confirmation_sent,
676 inviting_user_id,
677 platform_linux,
678 platform_mac,
679 platform_windows,
680 platform_unknown
681 )
682 VALUES
683 ($1, $2, 'f', $3, 'f', 'f', 'f', 't')
684 ON CONFLICT (email_address)
685 DO UPDATE SET
686 inviting_user_id = excluded.inviting_user_id
687 RETURNING email_confirmation_code
688 ",
689 )
690 .bind(&email_address)
691 .bind(&random_email_confirmation_code())
692 .bind(&inviter_id)
693 .fetch_one(&mut tx)
694 .await?;
695
696 tx.commit().await?;
697
698 Ok(Invite {
699 email_address: email_address.into(),
700 email_confirmation_code,
701 })
702 }
703
704 // projects
705
706 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
707 Ok(sqlx::query_scalar(
708 "
709 INSERT INTO projects(host_user_id)
710 VALUES ($1)
711 RETURNING id
712 ",
713 )
714 .bind(host_user_id)
715 .fetch_one(&self.pool)
716 .await
717 .map(ProjectId)?)
718 }
719
720 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
721 sqlx::query(
722 "
723 UPDATE projects
724 SET unregistered = 't'
725 WHERE id = $1
726 ",
727 )
728 .bind(project_id)
729 .execute(&self.pool)
730 .await?;
731 Ok(())
732 }
733
734 async fn update_worktree_extensions(
735 &self,
736 project_id: ProjectId,
737 worktree_id: u64,
738 extensions: HashMap<String, u32>,
739 ) -> Result<()> {
740 if extensions.is_empty() {
741 return Ok(());
742 }
743
744 let mut query = QueryBuilder::new(
745 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
746 );
747 query.push_values(extensions, |mut query, (extension, count)| {
748 query
749 .push_bind(project_id)
750 .push_bind(worktree_id as i32)
751 .push_bind(extension)
752 .push_bind(count as i32);
753 });
754 query.push(
755 "
756 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
757 count = excluded.count
758 ",
759 );
760 query.build().execute(&self.pool).await?;
761
762 Ok(())
763 }
764
765 async fn get_project_extensions(
766 &self,
767 project_id: ProjectId,
768 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
769 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
770 struct WorktreeExtension {
771 worktree_id: i32,
772 extension: String,
773 count: i32,
774 }
775
776 let query = "
777 SELECT worktree_id, extension, count
778 FROM worktree_extensions
779 WHERE project_id = $1
780 ";
781 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
782 .bind(&project_id)
783 .fetch_all(&self.pool)
784 .await?;
785
786 let mut extension_counts = HashMap::default();
787 for count in counts {
788 extension_counts
789 .entry(count.worktree_id as u64)
790 .or_insert_with(HashMap::default)
791 .insert(count.extension, count.count as usize);
792 }
793 Ok(extension_counts)
794 }
795
796 async fn record_user_activity(
797 &self,
798 time_period: Range<OffsetDateTime>,
799 projects: &[(UserId, ProjectId)],
800 ) -> Result<()> {
801 let query = "
802 INSERT INTO project_activity_periods
803 (ended_at, duration_millis, user_id, project_id)
804 VALUES
805 ($1, $2, $3, $4);
806 ";
807
808 let mut tx = self.pool.begin().await?;
809 let duration_millis =
810 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
811 for (user_id, project_id) in projects {
812 sqlx::query(query)
813 .bind(time_period.end)
814 .bind(duration_millis)
815 .bind(user_id)
816 .bind(project_id)
817 .execute(&mut tx)
818 .await?;
819 }
820 tx.commit().await?;
821
822 Ok(())
823 }
824
825 async fn get_active_user_count(
826 &self,
827 time_period: Range<OffsetDateTime>,
828 min_duration: Duration,
829 only_collaborative: bool,
830 ) -> Result<usize> {
831 let mut with_clause = String::new();
832 with_clause.push_str("WITH\n");
833 with_clause.push_str(
834 "
835 project_durations AS (
836 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
837 FROM project_activity_periods
838 WHERE $1 < ended_at AND ended_at <= $2
839 GROUP BY user_id, project_id
840 ),
841 ",
842 );
843 with_clause.push_str(
844 "
845 project_collaborators as (
846 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
847 FROM project_durations
848 GROUP BY project_id
849 ),
850 ",
851 );
852
853 if only_collaborative {
854 with_clause.push_str(
855 "
856 user_durations AS (
857 SELECT user_id, SUM(project_duration) as total_duration
858 FROM project_durations, project_collaborators
859 WHERE
860 project_durations.project_id = project_collaborators.project_id AND
861 max_collaborators > 1
862 GROUP BY user_id
863 ORDER BY total_duration DESC
864 LIMIT $3
865 )
866 ",
867 );
868 } else {
869 with_clause.push_str(
870 "
871 user_durations AS (
872 SELECT user_id, SUM(project_duration) as total_duration
873 FROM project_durations
874 GROUP BY user_id
875 ORDER BY total_duration DESC
876 LIMIT $3
877 )
878 ",
879 );
880 }
881
882 let query = format!(
883 "
884 {with_clause}
885 SELECT count(user_durations.user_id)
886 FROM user_durations
887 WHERE user_durations.total_duration >= $3
888 "
889 );
890
891 let count: i64 = sqlx::query_scalar(&query)
892 .bind(time_period.start)
893 .bind(time_period.end)
894 .bind(min_duration.as_millis() as i64)
895 .fetch_one(&self.pool)
896 .await?;
897 Ok(count as usize)
898 }
899
900 async fn get_top_users_activity_summary(
901 &self,
902 time_period: Range<OffsetDateTime>,
903 max_user_count: usize,
904 ) -> Result<Vec<UserActivitySummary>> {
905 let query = "
906 WITH
907 project_durations AS (
908 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
909 FROM project_activity_periods
910 WHERE $1 < ended_at AND ended_at <= $2
911 GROUP BY user_id, project_id
912 ),
913 user_durations AS (
914 SELECT user_id, SUM(project_duration) as total_duration
915 FROM project_durations
916 GROUP BY user_id
917 ORDER BY total_duration DESC
918 LIMIT $3
919 ),
920 project_collaborators as (
921 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
922 FROM project_durations
923 GROUP BY project_id
924 )
925 SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
926 FROM user_durations, project_durations, project_collaborators, users
927 WHERE
928 user_durations.user_id = project_durations.user_id AND
929 user_durations.user_id = users.id AND
930 project_durations.project_id = project_collaborators.project_id
931 ORDER BY total_duration DESC, user_id ASC, project_id ASC
932 ";
933
934 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
935 .bind(time_period.start)
936 .bind(time_period.end)
937 .bind(max_user_count as i32)
938 .fetch(&self.pool);
939
940 let mut result = Vec::<UserActivitySummary>::new();
941 while let Some(row) = rows.next().await {
942 let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
943 let project_id = project_id;
944 let duration = Duration::from_millis(duration_millis as u64);
945 let project_activity = ProjectActivitySummary {
946 id: project_id,
947 duration,
948 max_collaborators: project_collaborators as usize,
949 };
950 if let Some(last_summary) = result.last_mut() {
951 if last_summary.id == user_id {
952 last_summary.project_activity.push(project_activity);
953 continue;
954 }
955 }
956 result.push(UserActivitySummary {
957 id: user_id,
958 project_activity: vec![project_activity],
959 github_login,
960 });
961 }
962
963 Ok(result)
964 }
965
/// Returns the user's activity within `time_period` as periods sorted by start time.
///
/// Each recorded row is turned into the interval `[ended_at - duration, ended_at]`.
/// Within one project, rows are visited in row-id order. A row's interval is merged
/// into the project's previous period when it starts no later than
/// `COALESCE_THRESHOLD` after that period's end and ends no earlier than its start;
/// otherwise it opens a new period. Each period's `extensions` map is filled from
/// the project's `worktree_extensions` rows.
async fn get_user_activity_timeline(
    &self,
    time_period: Range<OffsetDateTime>,
    user_id: UserId,
) -> Result<Vec<UserActivityPeriod>> {
    // Largest gap between two intervals of the same project that still merges them.
    const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);

    // The LEFT JOIN yields one row per (activity period, worktree extension), or a
    // single row with NULL extension columns when the project has no extensions.
    // Ordering by the activity row id keeps a period's joined rows adjacent.
    // NOTE(review): the join is on project_id only. An extension present in several
    // worktrees produces several rows, and the last count wins in the map below —
    // confirm whether counts should be summed per extension.
    let query = "
        SELECT
            project_activity_periods.ended_at,
            project_activity_periods.duration_millis,
            project_activity_periods.project_id,
            worktree_extensions.extension,
            worktree_extensions.count
        FROM project_activity_periods
        LEFT OUTER JOIN
            worktree_extensions
        ON
            project_activity_periods.project_id = worktree_extensions.project_id
        WHERE
            project_activity_periods.user_id = $1 AND
            $2 < project_activity_periods.ended_at AND
            project_activity_periods.ended_at <= $3
        ORDER BY project_activity_periods.id ASC
    ";

    let mut rows = sqlx::query_as::<
        _,
        (
            PrimitiveDateTime,
            i32,
            ProjectId,
            Option<String>,
            Option<i32>,
        ),
    >(query)
    .bind(user_id)
    .bind(time_period.start)
    .bind(time_period.end)
    .fetch(&self.pool);

    // Periods are built separately per project, then flattened and sorted at the end.
    let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
    while let Some(row) = rows.next().await {
        let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
        // The column has no offset (PrimitiveDateTime); it is interpreted as UTC.
        let ended_at = ended_at.assume_utc();
        let duration = Duration::from_millis(duration_millis as u64);
        let started_at = ended_at - duration;
        let project_time_periods = time_periods.entry(project_id).or_default();

        if let Some(prev_duration) = project_time_periods.last_mut() {
            // Overlapping or nearly adjacent: extend the previous period. Extra
            // joined rows for the same activity row also land here and are no-ops.
            if started_at <= prev_duration.end + COALESCE_THRESHOLD
                && ended_at >= prev_duration.start
            {
                prev_duration.end = cmp::max(prev_duration.end, ended_at);
            } else {
                project_time_periods.push(UserActivityPeriod {
                    project_id,
                    start: started_at,
                    end: ended_at,
                    extensions: Default::default(),
                });
            }
        } else {
            project_time_periods.push(UserActivityPeriod {
                project_id,
                start: started_at,
                end: ended_at,
                extensions: Default::default(),
            });
        }

        // Both branches above leave at least one period in the vector, so the
        // unwrap cannot fail. Rows without an extension (NULLs from the join) are skipped.
        if let Some((extension, extension_count)) = extension.zip(extension_count) {
            project_time_periods
                .last_mut()
                .unwrap()
                .extensions
                .insert(extension, extension_count as usize);
        }
    }

    let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
    durations.sort_unstable_by_key(|duration| duration.start);
    Ok(durations)
}
1050
1051 // contacts
1052
1053 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1054 let query = "
1055 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1056 FROM contacts
1057 WHERE user_id_a = $1 OR user_id_b = $1;
1058 ";
1059
1060 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1061 .bind(user_id)
1062 .fetch(&self.pool);
1063
1064 let mut contacts = vec![Contact::Accepted {
1065 user_id,
1066 should_notify: false,
1067 }];
1068 while let Some(row) = rows.next().await {
1069 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1070
1071 if user_id_a == user_id {
1072 if accepted {
1073 contacts.push(Contact::Accepted {
1074 user_id: user_id_b,
1075 should_notify: should_notify && a_to_b,
1076 });
1077 } else if a_to_b {
1078 contacts.push(Contact::Outgoing { user_id: user_id_b })
1079 } else {
1080 contacts.push(Contact::Incoming {
1081 user_id: user_id_b,
1082 should_notify,
1083 });
1084 }
1085 } else if accepted {
1086 contacts.push(Contact::Accepted {
1087 user_id: user_id_a,
1088 should_notify: should_notify && !a_to_b,
1089 });
1090 } else if a_to_b {
1091 contacts.push(Contact::Incoming {
1092 user_id: user_id_a,
1093 should_notify,
1094 });
1095 } else {
1096 contacts.push(Contact::Outgoing { user_id: user_id_a });
1097 }
1098 }
1099
1100 contacts.sort_unstable_by_key(|contact| contact.user_id());
1101
1102 Ok(contacts)
1103 }
1104
1105 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1106 let (id_a, id_b) = if user_id_1 < user_id_2 {
1107 (user_id_1, user_id_2)
1108 } else {
1109 (user_id_2, user_id_1)
1110 };
1111
1112 let query = "
1113 SELECT 1 FROM contacts
1114 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
1115 LIMIT 1
1116 ";
1117 Ok(sqlx::query_scalar::<_, i32>(query)
1118 .bind(id_a.0)
1119 .bind(id_b.0)
1120 .fetch_optional(&self.pool)
1121 .await?
1122 .is_some())
1123 }
1124
1125 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1126 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1127 (sender_id, receiver_id, true)
1128 } else {
1129 (receiver_id, sender_id, false)
1130 };
1131 let query = "
1132 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1133 VALUES ($1, $2, $3, 'f', 't')
1134 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1135 SET
1136 accepted = 't',
1137 should_notify = 'f'
1138 WHERE
1139 NOT contacts.accepted AND
1140 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1141 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1142 ";
1143 let result = sqlx::query(query)
1144 .bind(id_a.0)
1145 .bind(id_b.0)
1146 .bind(a_to_b)
1147 .execute(&self.pool)
1148 .await?;
1149
1150 if result.rows_affected() == 1 {
1151 Ok(())
1152 } else {
1153 Err(anyhow!("contact already requested"))?
1154 }
1155 }
1156
1157 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1158 let (id_a, id_b) = if responder_id < requester_id {
1159 (responder_id, requester_id)
1160 } else {
1161 (requester_id, responder_id)
1162 };
1163 let query = "
1164 DELETE FROM contacts
1165 WHERE user_id_a = $1 AND user_id_b = $2;
1166 ";
1167 let result = sqlx::query(query)
1168 .bind(id_a.0)
1169 .bind(id_b.0)
1170 .execute(&self.pool)
1171 .await?;
1172
1173 if result.rows_affected() == 1 {
1174 Ok(())
1175 } else {
1176 Err(anyhow!("no such contact"))?
1177 }
1178 }
1179
1180 async fn dismiss_contact_notification(
1181 &self,
1182 user_id: UserId,
1183 contact_user_id: UserId,
1184 ) -> Result<()> {
1185 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1186 (user_id, contact_user_id, true)
1187 } else {
1188 (contact_user_id, user_id, false)
1189 };
1190
1191 let query = "
1192 UPDATE contacts
1193 SET should_notify = 'f'
1194 WHERE
1195 user_id_a = $1 AND user_id_b = $2 AND
1196 (
1197 (a_to_b = $3 AND accepted) OR
1198 (a_to_b != $3 AND NOT accepted)
1199 );
1200 ";
1201
1202 let result = sqlx::query(query)
1203 .bind(id_a.0)
1204 .bind(id_b.0)
1205 .bind(a_to_b)
1206 .execute(&self.pool)
1207 .await?;
1208
1209 if result.rows_affected() == 0 {
1210 Err(anyhow!("no such contact request"))?;
1211 }
1212
1213 Ok(())
1214 }
1215
1216 async fn respond_to_contact_request(
1217 &self,
1218 responder_id: UserId,
1219 requester_id: UserId,
1220 accept: bool,
1221 ) -> Result<()> {
1222 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1223 (responder_id, requester_id, false)
1224 } else {
1225 (requester_id, responder_id, true)
1226 };
1227 let result = if accept {
1228 let query = "
1229 UPDATE contacts
1230 SET accepted = 't', should_notify = 't'
1231 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1232 ";
1233 sqlx::query(query)
1234 .bind(id_a.0)
1235 .bind(id_b.0)
1236 .bind(a_to_b)
1237 .execute(&self.pool)
1238 .await?
1239 } else {
1240 let query = "
1241 DELETE FROM contacts
1242 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1243 ";
1244 sqlx::query(query)
1245 .bind(id_a.0)
1246 .bind(id_b.0)
1247 .bind(a_to_b)
1248 .execute(&self.pool)
1249 .await?
1250 };
1251 if result.rows_affected() == 1 {
1252 Ok(())
1253 } else {
1254 Err(anyhow!("no such contact request"))?
1255 }
1256 }
1257
1258 // access tokens
1259
1260 async fn create_access_token_hash(
1261 &self,
1262 user_id: UserId,
1263 access_token_hash: &str,
1264 max_access_token_count: usize,
1265 ) -> Result<()> {
1266 let insert_query = "
1267 INSERT INTO access_tokens (user_id, hash)
1268 VALUES ($1, $2);
1269 ";
1270 let cleanup_query = "
1271 DELETE FROM access_tokens
1272 WHERE id IN (
1273 SELECT id from access_tokens
1274 WHERE user_id = $1
1275 ORDER BY id DESC
1276 OFFSET $3
1277 )
1278 ";
1279
1280 let mut tx = self.pool.begin().await?;
1281 sqlx::query(insert_query)
1282 .bind(user_id.0)
1283 .bind(access_token_hash)
1284 .execute(&mut tx)
1285 .await?;
1286 sqlx::query(cleanup_query)
1287 .bind(user_id.0)
1288 .bind(access_token_hash)
1289 .bind(max_access_token_count as i32)
1290 .execute(&mut tx)
1291 .await?;
1292 Ok(tx.commit().await?)
1293 }
1294
1295 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1296 let query = "
1297 SELECT hash
1298 FROM access_tokens
1299 WHERE user_id = $1
1300 ORDER BY id DESC
1301 ";
1302 Ok(sqlx::query_scalar(query)
1303 .bind(user_id.0)
1304 .fetch_all(&self.pool)
1305 .await?)
1306 }
1307
1308 // orgs
1309
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
/// Looks up an org by its slug; `Ok(None)` when no org has that slug.
async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
    let org: Option<Org> = sqlx::query_as(
        "
        SELECT *
        FROM orgs
        WHERE slug = $1
        ",
    )
    .bind(slug)
    .fetch_optional(&self.pool)
    .await?;
    Ok(org)
}
1323
#[cfg(any(test, feature = "seed-support"))]
/// Inserts a new org and returns its generated id.
async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
    let id: i32 = sqlx::query_scalar(
        "
        INSERT INTO orgs (name, slug)
        VALUES ($1, $2)
        RETURNING id
        ",
    )
    .bind(name)
    .bind(slug)
    .fetch_one(&self.pool)
    .await?;
    Ok(OrgId(id))
}
1338
#[cfg(any(test, feature = "seed-support"))]
/// Adds `user_id` to `org_id`, optionally as an admin.
///
/// Conflicting inserts are skipped (`ON CONFLICT DO NOTHING`). Presumably the
/// conflict is a unique key on (org_id, user_id), which is not visible here.
/// An existing membership's `admin` flag is therefore left unchanged.
async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
    sqlx::query(
        "
        INSERT INTO org_memberships (org_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
        ",
    )
    .bind(org_id.0)
    .bind(user_id.0)
    .bind(is_admin)
    .execute(&self.pool)
    .await?;
    Ok(())
}
1354
1355 // channels
1356
#[cfg(any(test, feature = "seed-support"))]
/// Creates a channel owned by `org_id` (stored with `owner_is_user = false`)
/// and returns its id.
async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
    let channel_id: i32 = sqlx::query_scalar(
        "
        INSERT INTO channels (owner_id, owner_is_user, name)
        VALUES ($1, false, $2)
        RETURNING id
        ",
    )
    .bind(org_id.0)
    .bind(name)
    .fetch_one(&self.pool)
    .await?;
    Ok(ChannelId(channel_id))
}
1371
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
/// Lists the channels owned by `org_id`.
async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
    let channels: Vec<Channel> = sqlx::query_as(
        "
        SELECT *
        FROM channels
        WHERE
            channels.owner_is_user = false AND
            channels.owner_id = $1
        ",
    )
    .bind(org_id.0)
    .fetch_all(&self.pool)
    .await?;
    Ok(channels)
}
1387
1388 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1389 let query = "
1390 SELECT
1391 channels.*
1392 FROM
1393 channel_memberships, channels
1394 WHERE
1395 channel_memberships.user_id = $1 AND
1396 channel_memberships.channel_id = channels.id
1397 ";
1398 Ok(sqlx::query_as(query)
1399 .bind(user_id.0)
1400 .fetch_all(&self.pool)
1401 .await?)
1402 }
1403
1404 async fn can_user_access_channel(
1405 &self,
1406 user_id: UserId,
1407 channel_id: ChannelId,
1408 ) -> Result<bool> {
1409 let query = "
1410 SELECT id
1411 FROM channel_memberships
1412 WHERE user_id = $1 AND channel_id = $2
1413 LIMIT 1
1414 ";
1415 Ok(sqlx::query_scalar::<_, i32>(query)
1416 .bind(user_id.0)
1417 .bind(channel_id.0)
1418 .fetch_optional(&self.pool)
1419 .await
1420 .map(|e| e.is_some())?)
1421 }
1422
#[cfg(any(test, feature = "seed-support"))]
/// Adds `user_id` to `channel_id`, optionally as an admin.
///
/// Conflicting inserts are skipped (`ON CONFLICT DO NOTHING`), so an existing
/// membership is left as-is.
async fn add_channel_member(
    &self,
    channel_id: ChannelId,
    user_id: UserId,
    is_admin: bool,
) -> Result<()> {
    sqlx::query(
        "
        INSERT INTO channel_memberships (channel_id, user_id, admin)
        VALUES ($1, $2, $3)
        ON CONFLICT DO NOTHING
        ",
    )
    .bind(channel_id.0)
    .bind(user_id.0)
    .bind(is_admin)
    .execute(&self.pool)
    .await?;
    Ok(())
}
1443
1444 // messages
1445
1446 async fn create_channel_message(
1447 &self,
1448 channel_id: ChannelId,
1449 sender_id: UserId,
1450 body: &str,
1451 timestamp: OffsetDateTime,
1452 nonce: u128,
1453 ) -> Result<MessageId> {
1454 let query = "
1455 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1456 VALUES ($1, $2, $3, $4, $5)
1457 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1458 RETURNING id
1459 ";
1460 Ok(sqlx::query_scalar(query)
1461 .bind(channel_id.0)
1462 .bind(sender_id.0)
1463 .bind(body)
1464 .bind(timestamp)
1465 .bind(Uuid::from_u128(nonce))
1466 .fetch_one(&self.pool)
1467 .await
1468 .map(MessageId)?)
1469 }
1470
1471 async fn get_channel_messages(
1472 &self,
1473 channel_id: ChannelId,
1474 count: usize,
1475 before_id: Option<MessageId>,
1476 ) -> Result<Vec<ChannelMessage>> {
1477 let query = r#"
1478 SELECT * FROM (
1479 SELECT
1480 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1481 FROM
1482 channel_messages
1483 WHERE
1484 channel_id = $1 AND
1485 id < $2
1486 ORDER BY id DESC
1487 LIMIT $3
1488 ) as recent_messages
1489 ORDER BY id ASC
1490 "#;
1491 Ok(sqlx::query_as(query)
1492 .bind(channel_id.0)
1493 .bind(before_id.unwrap_or(MessageId::MAX))
1494 .bind(count as i64)
1495 .fetch_all(&self.pool)
1496 .await?)
1497 }
1498
#[cfg(test)]
/// Drops the test database at `url`.
///
/// Other backends are disconnected first so the drop can proceed. Failures
/// are logged, not propagated.
async fn teardown(&self, url: &str) {
    use util::ResultExt;

    let terminate_other_backends = "
        SELECT pg_terminate_backend(pg_stat_activity.pid)
        FROM pg_stat_activity
        WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
    ";
    sqlx::query(terminate_other_backends)
        .execute(&self.pool)
        .await
        .log_err();
    self.pool.close().await;

    let dropped = <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url).await;
    dropped.log_err();
}
1514
#[cfg(test)]
/// The Postgres-backed database is never the in-memory fake.
fn as_fake(&self) -> Option<&FakeDb> {
    Option::None
}
1519}
1520
/// Declares a strongly-typed `i32` id newtype named `$name`.
///
/// The generated type is transparent for both sqlx and serde, so it is stored
/// and serialized exactly like the bare integer. It also gets proto
/// conversions and a `Display` impl that forwards to the integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            Hash,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Serialize,
            Deserialize,
            sqlx::Type,
        )]
        #[serde(transparent)]
        #[sqlx(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest representable id, e.g. an exclusive upper bound when paging.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from its `u64` proto representation.
            ///
            /// NOTE(review): the `as` cast keeps only the low 32 bits, so values
            /// above `i32::MAX` silently wrap.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Converts the id to its `u64` proto representation.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Forward to the integer so width/alignment flags are honored.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1563
id_type!(UserId);
/// A user account row, loadable through sqlx's `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// May be `None` for users known only by login; the fake
    /// `get_user_by_github_account` fills it in on a login match.
    pub github_user_id: Option<i32>,
    pub metrics_id: i32,
    pub email_address: Option<String>,
    pub admin: bool,
    pub invite_code: Option<String>,
    pub invite_count: i32,
    /// Updated via `set_user_connected_once`.
    pub connected_once: bool,
}
1577
id_type!(ProjectId);
/// A project registered by its host user.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    /// Set when the project is unregistered; the record itself is kept
    /// (as `FakeDb::unregister_project` does).
    pub unregistered: bool,
}
1585
/// Per-user activity rollup: the user plus a summary of each project they
/// were active in.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    pub project_activity: Vec<ProjectActivitySummary>,
}
1592
/// Activity summary for one project within a [`UserActivitySummary`].
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    pub id: ProjectId,
    /// NOTE(review): presumably total active time in the queried period — confirm.
    pub duration: Duration,
    /// NOTE(review): presumably peak number of concurrent collaborators — confirm.
    pub max_collaborators: usize,
}
1599
/// A contiguous span of activity in one project; timestamps serialize as ISO 8601.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    pub project_id: ProjectId,
    #[serde(with = "time::serde::iso8601")]
    pub start: OffsetDateTime,
    #[serde(with = "time::serde::iso8601")]
    pub end: OffsetDateTime,
    /// Presumably extension name -> count, as recorded by
    /// `update_worktree_extensions` — verify against the query that fills it.
    pub extensions: HashMap<String, usize>,
}
1609
id_type!(OrgId);
/// An organization row; `slug` is what `find_org_by_slug` looks up.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
1617
id_type!(ChannelId);
/// A chat channel owned either by an org or by a user.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// The raw owner id. When `owner_is_user` is false this is an `OrgId`'s
    /// inner value (see `create_org_channel`); otherwise presumably a user id.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1626
id_type!(MessageId);
/// A message posted to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    pub sent_at: OffsetDateTime,
    /// Sender-supplied nonce; `create_channel_message` deduplicates on it.
    pub nonce: Uuid,
}
1637
/// One entry in a user's contact list, from that user's point of view.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// A contact both sides have accepted.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A request this user sent that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// A request this user received that is still pending.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1652
1653impl Contact {
1654 pub fn user_id(&self) -> UserId {
1655 match self {
1656 Contact::Accepted { user_id, .. } => *user_id,
1657 Contact::Outgoing { user_id } => *user_id,
1658 Contact::Incoming { user_id, .. } => *user_id,
1659 }
1660 }
1661}
1662
/// A pending contact request received from `requester_id`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1668
/// A waitlist signup, deserialized from the incoming request and passed to
/// `create_signup`.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
}
1678
/// Aggregate waitlist counts.
///
/// Every field is `#[sqlx(default)]`, so a column missing from the result row
/// falls back to `Default` (0) instead of failing the decode.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
}
1690
/// An invite addressed to `email_address`.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    /// Presumably produced by `random_email_confirmation_code` — confirm at the insert site.
    pub email_confirmation_code: String,
}
1696
/// GitHub identity and invite allowance for a user being created.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    /// NOTE(review): `FakeDb::create_user` ignores this and stores 0.
    pub invite_count: i32,
}
1703
1704fn random_invite_code() -> String {
1705 nanoid::nanoid!(16)
1706}
1707
1708fn random_email_confirmation_code() -> String {
1709 nanoid::nanoid!(64)
1710}
1711
1712#[cfg(test)]
1713pub use test::*;
1714
#[cfg(test)]
mod test {
    use super::*;
    use anyhow::anyhow;
    use collections::BTreeMap;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::{
        migrate::{MigrateDatabase, Migrator},
        Postgres,
    };
    use std::{path::Path, sync::Arc};
    use util::post_inc;

    /// In-memory implementation of [`Db`] for tests.
    ///
    /// Each "table" is a `Mutex`-guarded map (a `Vec` for contacts), and the
    /// `next_*` counters hand out ids via `post_inc`. Implemented methods first
    /// await `simulate_random_delay` on the test executor before touching state.
    pub struct FakeDb {
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        pub contacts: Mutex<Vec<FakeContact>>,
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }

    /// A contact in [`FakeDb`], stored in request direction (requester ->
    /// responder) rather than normalized by id order.
    #[derive(Debug)]
    pub struct FakeContact {
        pub requester_id: UserId,
        pub responder_id: UserId,
        pub accepted: bool,
        pub should_notify: bool,
    }

    impl FakeDb {
        /// Creates an empty fake database driven by `background`.
        pub fn new(background: Arc<Background>) -> Self {
            Self {
                background,
                users: Default::default(),
                // NOTE(review): user ids start at 0 while every other id starts at 1.
                next_user_id: Mutex::new(0),
                projects: Default::default(),
                worktree_extensions: Default::default(),
                next_project_id: Mutex::new(1),
                orgs: Default::default(),
                next_org_id: Mutex::new(1),
                org_memberships: Default::default(),
                channels: Default::default(),
                next_channel_id: Mutex::new(1),
                channel_memberships: Default::default(),
                channel_messages: Default::default(),
                next_channel_message_id: Mutex::new(1),
                contacts: Default::default(),
            }
        }
    }

    #[async_trait]
    impl Db for FakeDb {
        async fn create_user(
            &self,
            email_address: &str,
            admin: bool,
            params: NewUserParams,
        ) -> Result<UserId> {
            self.background.simulate_random_delay().await;

            let mut users = self.users.lock();
            // A user with the same GitHub login is reused rather than duplicated.
            if let Some(user) = users
                .values()
                .find(|user| user.github_login == params.github_login)
            {
                Ok(user.id)
            } else {
                let id = post_inc(&mut *self.next_user_id.lock());
                let user_id = UserId(id);
                users.insert(
                    user_id,
                    User {
                        id: user_id,
                        github_login: params.github_login,
                        github_user_id: Some(params.github_user_id),
                        email_address: Some(email_address.to_string()),
                        metrics_id: id + 100,
                        admin,
                        invite_code: None,
                        invite_count: 0,
                        connected_once: false,
                    },
                );
                Ok(user_id)
            }
        }

        async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
            unimplemented!()
        }

        async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
            unimplemented!()
        }

        async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
            self.background.simulate_random_delay().await;
            Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
        }

        async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
            self.background.simulate_random_delay().await;
            let users = self.users.lock();
            // Unknown ids are skipped; output follows the order of `ids`.
            Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
        }

        async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
            unimplemented!()
        }

        async fn get_user_by_github_account(
            &self,
            github_login: &str,
            github_user_id: Option<i32>,
        ) -> Result<Option<User>> {
            self.background.simulate_random_delay().await;
            let mut users = self.users.lock();
            if let Some(github_user_id) = github_user_id {
                // Match on the stable GitHub id across *all* users first, and
                // only then fall back to the login. Checking both in one pass
                // let a login match on an earlier user shadow the id match on
                // a later one.
                if let Some(user) = users
                    .values_mut()
                    .find(|user| user.github_user_id == Some(github_user_id))
                {
                    user.github_login = github_login.into();
                    return Ok(Some(user.clone()));
                }
                if let Some(user) = users
                    .values_mut()
                    .find(|user| user.github_login == github_login)
                {
                    user.github_user_id = Some(github_user_id);
                    return Ok(Some(user.clone()));
                }
                Ok(None)
            } else {
                Ok(users
                    .values()
                    .find(|user| user.github_login == github_login)
                    .cloned())
            }
        }

        async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
            unimplemented!()
        }

        async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
            self.background.simulate_random_delay().await;
            let mut users = self.users.lock();
            let user = users
                .get_mut(&id)
                .ok_or_else(|| anyhow!("user not found"))?;
            user.connected_once = connected_once;
            Ok(())
        }

        async fn destroy_user(&self, _id: UserId) -> Result<()> {
            unimplemented!()
        }

        // signups

        async fn create_signup(&self, _signup: Signup) -> Result<i32> {
            unimplemented!()
        }

        async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
            unimplemented!()
        }

        async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<Invite>> {
            unimplemented!()
        }

        async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
            unimplemented!()
        }

        async fn create_user_from_invite(
            &self,
            _invite: &Invite,
            _user: NewUserParams,
        ) -> Result<(UserId, Option<UserId>)> {
            unimplemented!()
        }

        // invite codes

        async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
            unimplemented!()
        }

        async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
            self.background.simulate_random_delay().await;
            Ok(None)
        }

        async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
            unimplemented!()
        }

        async fn create_invite_from_code(
            &self,
            _code: &str,
            _email_address: &str,
        ) -> Result<Invite> {
            unimplemented!()
        }

        // projects

        async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
            self.background.simulate_random_delay().await;
            if !self.users.lock().contains_key(&host_user_id) {
                Err(anyhow!("no such user"))?;
            }

            let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
            self.projects.lock().insert(
                project_id,
                Project {
                    id: project_id,
                    host_user_id,
                    unregistered: false,
                },
            );
            Ok(project_id)
        }

        async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
            self.background.simulate_random_delay().await;
            // The project is flagged, not removed.
            self.projects
                .lock()
                .get_mut(&project_id)
                .ok_or_else(|| anyhow!("no such project"))?
                .unregistered = true;
            Ok(())
        }

        async fn update_worktree_extensions(
            &self,
            project_id: ProjectId,
            worktree_id: u64,
            extensions: HashMap<String, u32>,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.projects.lock().contains_key(&project_id) {
                Err(anyhow!("no such project"))?;
            }

            // Take the lock once for the whole batch; later counts overwrite earlier ones.
            self.worktree_extensions.lock().extend(
                extensions
                    .into_iter()
                    .map(|(extension, count)| ((project_id, worktree_id, extension), count)),
            );

            Ok(())
        }

        async fn get_project_extensions(
            &self,
            _project_id: ProjectId,
        ) -> Result<HashMap<u64, HashMap<String, usize>>> {
            unimplemented!()
        }

        async fn record_user_activity(
            &self,
            _time_period: Range<OffsetDateTime>,
            _active_projects: &[(UserId, ProjectId)],
        ) -> Result<()> {
            unimplemented!()
        }

        async fn get_active_user_count(
            &self,
            _time_period: Range<OffsetDateTime>,
            _min_duration: Duration,
            _only_collaborative: bool,
        ) -> Result<usize> {
            unimplemented!()
        }

        async fn get_top_users_activity_summary(
            &self,
            _time_period: Range<OffsetDateTime>,
            _limit: usize,
        ) -> Result<Vec<UserActivitySummary>> {
            unimplemented!()
        }

        async fn get_user_activity_timeline(
            &self,
            _time_period: Range<OffsetDateTime>,
            _user_id: UserId,
        ) -> Result<Vec<UserActivityPeriod>> {
            unimplemented!()
        }

        // contacts

        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
            self.background.simulate_random_delay().await;
            // Every user is listed as their own accepted contact.
            let mut contacts = vec![Contact::Accepted {
                user_id: id,
                should_notify: false,
            }];

            for contact in self.contacts.lock().iter() {
                if contact.requester_id == id {
                    if contact.accepted {
                        contacts.push(Contact::Accepted {
                            user_id: contact.responder_id,
                            should_notify: contact.should_notify,
                        });
                    } else {
                        contacts.push(Contact::Outgoing {
                            user_id: contact.responder_id,
                        });
                    }
                } else if contact.responder_id == id {
                    if contact.accepted {
                        contacts.push(Contact::Accepted {
                            user_id: contact.requester_id,
                            should_notify: false,
                        });
                    } else {
                        contacts.push(Contact::Incoming {
                            user_id: contact.requester_id,
                            should_notify: contact.should_notify,
                        });
                    }
                }
            }

            contacts.sort_unstable_by_key(|contact| contact.user_id());
            Ok(contacts)
        }

        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
            self.background.simulate_random_delay().await;
            Ok(self.contacts.lock().iter().any(|contact| {
                contact.accepted
                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
            }))
        }

        async fn send_contact_request(
            &self,
            requester_id: UserId,
            responder_id: UserId,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            let mut contacts = self.contacts.lock();
            for contact in contacts.iter_mut() {
                if contact.requester_id == requester_id && contact.responder_id == responder_id {
                    if contact.accepted {
                        Err(anyhow!("contact already exists"))?;
                    } else {
                        Err(anyhow!("contact already requested"))?;
                    }
                }
                if contact.responder_id == requester_id && contact.requester_id == responder_id {
                    if contact.accepted {
                        Err(anyhow!("contact already exists"))?;
                    } else {
                        // A pending request in the opposite direction is
                        // accepted instead of creating a second row.
                        contact.accepted = true;
                        contact.should_notify = false;
                        return Ok(());
                    }
                }
            }
            contacts.push(FakeContact {
                requester_id,
                responder_id,
                accepted: false,
                should_notify: true,
            });
            Ok(())
        }

        async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
            self.background.simulate_random_delay().await;
            let mut contacts = self.contacts.lock();
            let len_before = contacts.len();
            // Match the Postgres implementation: the pair is unordered, so the
            // contact is removed whichever participant originally sent the
            // request, and it is an error when there was nothing to remove.
            contacts.retain(|contact| {
                !((contact.requester_id == requester_id && contact.responder_id == responder_id)
                    || (contact.requester_id == responder_id
                        && contact.responder_id == requester_id))
            });
            if contacts.len() == len_before {
                Err(anyhow!("no such contact"))?;
            }
            Ok(())
        }

        async fn dismiss_contact_notification(
            &self,
            user_id: UserId,
            contact_user_id: UserId,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            let mut contacts = self.contacts.lock();
            for contact in contacts.iter_mut() {
                // A pending request this user received.
                if contact.requester_id == contact_user_id
                    && contact.responder_id == user_id
                    && !contact.accepted
                {
                    contact.should_notify = false;
                    return Ok(());
                }
                // A request this user sent that has been accepted.
                if contact.requester_id == user_id
                    && contact.responder_id == contact_user_id
                    && contact.accepted
                {
                    contact.should_notify = false;
                    return Ok(());
                }
            }
            Err(anyhow!("no such notification"))?
        }

        async fn respond_to_contact_request(
            &self,
            responder_id: UserId,
            requester_id: UserId,
            accept: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            let mut contacts = self.contacts.lock();
            // Locate the request first, then mutate, instead of removing from
            // the vector while iterating over it.
            let ix = contacts
                .iter()
                .position(|contact| {
                    contact.requester_id == requester_id && contact.responder_id == responder_id
                })
                .ok_or_else(|| anyhow!("no such contact request"))?;
            if contacts[ix].accepted {
                Err(anyhow!("contact already confirmed"))?;
            }
            if accept {
                let contact = &mut contacts[ix];
                contact.accepted = true;
                contact.should_notify = true;
            } else {
                contacts.remove(ix);
            }
            Ok(())
        }

        async fn create_access_token_hash(
            &self,
            _user_id: UserId,
            _access_token_hash: &str,
            _max_access_token_count: usize,
        ) -> Result<()> {
            unimplemented!()
        }

        async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
            unimplemented!()
        }

        async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
            unimplemented!()
        }

        async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
            self.background.simulate_random_delay().await;
            let mut orgs = self.orgs.lock();
            if orgs.values().any(|org| org.slug == slug) {
                Err(anyhow!("org already exists"))?
            } else {
                let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
                orgs.insert(
                    org_id,
                    Org {
                        id: org_id,
                        name: name.to_string(),
                        slug: slug.to_string(),
                    },
                );
                Ok(org_id)
            }
        }

        async fn add_org_member(
            &self,
            org_id: OrgId,
            user_id: UserId,
            is_admin: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.orgs.lock().contains_key(&org_id) {
                Err(anyhow!("org does not exist"))?;
            }
            if !self.users.lock().contains_key(&user_id) {
                Err(anyhow!("user does not exist"))?;
            }

            // Existing memberships are left untouched, like ON CONFLICT DO NOTHING.
            self.org_memberships
                .lock()
                .entry((org_id, user_id))
                .or_insert(is_admin);
            Ok(())
        }

        async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
            self.background.simulate_random_delay().await;
            if !self.orgs.lock().contains_key(&org_id) {
                Err(anyhow!("org does not exist"))?;
            }

            let mut channels = self.channels.lock();
            let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
            channels.insert(
                channel_id,
                Channel {
                    id: channel_id,
                    name: name.to_string(),
                    owner_id: org_id.0,
                    owner_is_user: false,
                },
            );
            Ok(channel_id)
        }

        async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
            self.background.simulate_random_delay().await;
            Ok(self
                .channels
                .lock()
                .values()
                .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
                .cloned()
                .collect())
        }

        async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
            self.background.simulate_random_delay().await;
            let channels = self.channels.lock();
            let memberships = self.channel_memberships.lock();
            Ok(channels
                .values()
                .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
                .cloned()
                .collect())
        }

        async fn can_user_access_channel(
            &self,
            user_id: UserId,
            channel_id: ChannelId,
        ) -> Result<bool> {
            self.background.simulate_random_delay().await;
            Ok(self
                .channel_memberships
                .lock()
                .contains_key(&(channel_id, user_id)))
        }

        async fn add_channel_member(
            &self,
            channel_id: ChannelId,
            user_id: UserId,
            is_admin: bool,
        ) -> Result<()> {
            self.background.simulate_random_delay().await;
            if !self.channels.lock().contains_key(&channel_id) {
                Err(anyhow!("channel does not exist"))?;
            }
            if !self.users.lock().contains_key(&user_id) {
                Err(anyhow!("user does not exist"))?;
            }

            self.channel_memberships
                .lock()
                .entry((channel_id, user_id))
                .or_insert(is_admin);
            Ok(())
        }

        async fn create_channel_message(
            &self,
            channel_id: ChannelId,
            sender_id: UserId,
            body: &str,
            timestamp: OffsetDateTime,
            nonce: u128,
        ) -> Result<MessageId> {
            self.background.simulate_random_delay().await;
            if !self.channels.lock().contains_key(&channel_id) {
                Err(anyhow!("channel does not exist"))?;
            }
            if !self.users.lock().contains_key(&sender_id) {
                Err(anyhow!("user does not exist"))?;
            }

            let mut messages = self.channel_messages.lock();
            // Deduplicate on nonce across all channels, like the Postgres
            // `ON CONFLICT (nonce)` clause.
            if let Some(message) = messages
                .values()
                .find(|message| message.nonce.as_u128() == nonce)
            {
                Ok(message.id)
            } else {
                let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
                messages.insert(
                    message_id,
                    ChannelMessage {
                        id: message_id,
                        channel_id,
                        sender_id,
                        body: body.to_string(),
                        sent_at: timestamp,
                        nonce: Uuid::from_u128(nonce),
                    },
                );
                Ok(message_id)
            }
        }

        async fn get_channel_messages(
            &self,
            channel_id: ChannelId,
            count: usize,
            before_id: Option<MessageId>,
        ) -> Result<Vec<ChannelMessage>> {
            self.background.simulate_random_delay().await;
            // Walk newest-first to take the `count` most recent, then restore
            // ascending id order.
            let mut messages = self
                .channel_messages
                .lock()
                .values()
                .rev()
                .filter(|message| {
                    message.channel_id == channel_id
                        && message.id < before_id.unwrap_or(MessageId::MAX)
                })
                .take(count)
                .cloned()
                .collect::<Vec<_>>();
            messages.sort_unstable_by_key(|message| message.id);
            Ok(messages)
        }

        async fn teardown(&self, _: &str) {}

        #[cfg(test)]
        fn as_fake(&self) -> Option<&FakeDb> {
            Some(self)
        }
    }

    /// Test harness owning either a real Postgres database or a [`FakeDb`].
    ///
    /// On drop, the database's `teardown` runs with `url`.
    pub struct TestDb {
        pub db: Option<Arc<dyn Db>>,
        pub url: String,
    }

    impl TestDb {
        /// Creates a uniquely named Postgres database on localhost and runs
        /// the crate's migrations against it.
        // The global lock is intentionally held across the awaits below so
        // that test databases are created one at a time.
        #[allow(clippy::await_holding_lock)]
        pub async fn postgres() -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let name = format!("zed-test-{}", rng.gen::<u128>());
            let url = format!("postgres://postgres@localhost/{}", name);
            let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
            Postgres::create_database(&url)
                .await
                .expect("failed to create test db");
            let db = PostgresDb::new(&url, 5).await.unwrap();
            let migrator = Migrator::new(migrations_path).await.unwrap();
            migrator.run(&db.pool).await.unwrap();
            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Wraps a fresh in-memory [`FakeDb`]; `url` is left empty.
        pub fn fake(background: Arc<Background>) -> Self {
            Self {
                db: Some(Arc::new(FakeDb::new(background))),
                url: Default::default(),
            }
        }

        /// Returns the underlying database.
        ///
        /// # Panics
        /// If called after the database has been taken (only `drop` does that).
        pub fn db(&self) -> &Arc<dyn Db> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        fn drop(&mut self) {
            if let Some(db) = self.db.take() {
                futures::executor::block_on(db.teardown(&self.url));
            }
        }
    }
}