1use crate::{Error, Result};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use axum::http::StatusCode;
5use collections::HashMap;
6use futures::StreamExt;
7use serde::{Deserialize, Serialize};
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9use sqlx::{types::Uuid, FromRow, QueryBuilder};
10use std::{cmp, ops::Range, time::Duration};
11use time::{OffsetDateTime, PrimitiveDateTime};
12
13#[async_trait]
14pub trait Db: Send + Sync {
15 async fn create_user(
16 &self,
17 email_address: &str,
18 admin: bool,
19 params: NewUserParams,
20 ) -> Result<UserId>;
21 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
22 async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
23 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
24 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
25 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
26 async fn get_user_by_github_account(
27 &self,
28 github_login: &str,
29 github_user_id: Option<i32>,
30 ) -> Result<Option<User>>;
31 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
32 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
33 async fn destroy_user(&self, id: UserId) -> Result<()>;
34
35 async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
36 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
37 async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
38 async fn create_invite_from_code(&self, code: &str, email_address: &str) -> Result<Invite>;
39
40 async fn create_signup(&self, signup: Signup) -> Result<()>;
41 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
42 async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
43 async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
44 async fn create_user_from_invite(
45 &self,
46 invite: &Invite,
47 user: NewUserParams,
48 ) -> Result<(UserId, Option<UserId>, String)>;
49
50 /// Registers a new project for the given user.
51 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;
52
53 /// Unregisters a project for the given project id.
54 async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;
55
56 /// Update file counts by extension for the given project and worktree.
57 async fn update_worktree_extensions(
58 &self,
59 project_id: ProjectId,
60 worktree_id: u64,
61 extensions: HashMap<String, u32>,
62 ) -> Result<()>;
63
64 /// Get the file counts on the given project keyed by their worktree and extension.
65 async fn get_project_extensions(
66 &self,
67 project_id: ProjectId,
68 ) -> Result<HashMap<u64, HashMap<String, usize>>>;
69
70 /// Record which users have been active in which projects during
71 /// a given period of time.
72 async fn record_user_activity(
73 &self,
74 time_period: Range<OffsetDateTime>,
75 active_projects: &[(UserId, ProjectId)],
76 ) -> Result<()>;
77
78 /// Get the number of users who have been active in the given
79 /// time period for at least the given time duration.
80 async fn get_active_user_count(
81 &self,
82 time_period: Range<OffsetDateTime>,
83 min_duration: Duration,
84 only_collaborative: bool,
85 ) -> Result<usize>;
86
87 /// Get the users that have been most active during the given time period,
88 /// along with the amount of time they have been active in each project.
89 async fn get_top_users_activity_summary(
90 &self,
91 time_period: Range<OffsetDateTime>,
92 max_user_count: usize,
93 ) -> Result<Vec<UserActivitySummary>>;
94
95 /// Get the project activity for the given user and time period.
96 async fn get_user_activity_timeline(
97 &self,
98 time_period: Range<OffsetDateTime>,
99 user_id: UserId,
100 ) -> Result<Vec<UserActivityPeriod>>;
101
102 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
103 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
104 async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
105 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
106 async fn dismiss_contact_notification(
107 &self,
108 responder_id: UserId,
109 requester_id: UserId,
110 ) -> Result<()>;
111 async fn respond_to_contact_request(
112 &self,
113 responder_id: UserId,
114 requester_id: UserId,
115 accept: bool,
116 ) -> Result<()>;
117
118 async fn create_access_token_hash(
119 &self,
120 user_id: UserId,
121 access_token_hash: &str,
122 max_access_token_count: usize,
123 ) -> Result<()>;
124 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
125
126 #[cfg(any(test, feature = "seed-support"))]
127 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
128 #[cfg(any(test, feature = "seed-support"))]
129 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
130 #[cfg(any(test, feature = "seed-support"))]
131 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
132 #[cfg(any(test, feature = "seed-support"))]
133 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
134 #[cfg(any(test, feature = "seed-support"))]
135
136 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
137 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
138 async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
139 -> Result<bool>;
140
141 #[cfg(any(test, feature = "seed-support"))]
142 async fn add_channel_member(
143 &self,
144 channel_id: ChannelId,
145 user_id: UserId,
146 is_admin: bool,
147 ) -> Result<()>;
148 async fn create_channel_message(
149 &self,
150 channel_id: ChannelId,
151 sender_id: UserId,
152 body: &str,
153 timestamp: OffsetDateTime,
154 nonce: u128,
155 ) -> Result<MessageId>;
156 async fn get_channel_messages(
157 &self,
158 channel_id: ChannelId,
159 count: usize,
160 before_id: Option<MessageId>,
161 ) -> Result<Vec<ChannelMessage>>;
162
163 #[cfg(test)]
164 async fn teardown(&self, url: &str);
165
166 #[cfg(test)]
167 fn as_fake(&self) -> Option<&FakeDb>;
168}
169
170pub struct PostgresDb {
171 pool: sqlx::PgPool,
172}
173
174impl PostgresDb {
175 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
176 let pool = DbOptions::new()
177 .max_connections(max_connections)
178 .connect(url)
179 .await
180 .context("failed to connect to postgres database")?;
181 Ok(Self { pool })
182 }
183
184 pub fn fuzzy_like_string(string: &str) -> String {
185 let mut result = String::with_capacity(string.len() * 2 + 1);
186 for c in string.chars() {
187 if c.is_alphanumeric() {
188 result.push('%');
189 result.push(c);
190 }
191 }
192 result.push('%');
193 result
194 }
195}
196
197#[async_trait]
198impl Db for PostgresDb {
199 // users
200
201 async fn create_user(
202 &self,
203 email_address: &str,
204 admin: bool,
205 params: NewUserParams,
206 ) -> Result<UserId> {
207 let query = "
208 INSERT INTO users (email_address, github_login, github_user_id, admin)
209 VALUES ($1, $2, $3, $4)
210 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
211 RETURNING id
212 ";
213 Ok(sqlx::query_scalar(query)
214 .bind(email_address)
215 .bind(params.github_login)
216 .bind(params.github_user_id)
217 .bind(admin)
218 .fetch_one(&self.pool)
219 .await
220 .map(UserId)?)
221 }
222
223 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
224 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
225 Ok(sqlx::query_as(query)
226 .bind(limit as i32)
227 .bind((page * limit) as i32)
228 .fetch_all(&self.pool)
229 .await?)
230 }
231
232 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
233 let like_string = Self::fuzzy_like_string(name_query);
234 let query = "
235 SELECT users.*
236 FROM users
237 WHERE github_login ILIKE $1
238 ORDER BY github_login <-> $2
239 LIMIT $3
240 ";
241 Ok(sqlx::query_as(query)
242 .bind(like_string)
243 .bind(name_query)
244 .bind(limit as i32)
245 .fetch_all(&self.pool)
246 .await?)
247 }
248
249 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
250 let users = self.get_users_by_ids(vec![id]).await?;
251 Ok(users.into_iter().next())
252 }
253
254 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
255 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
256 let query = "
257 SELECT users.*
258 FROM users
259 WHERE users.id = ANY ($1)
260 ";
261 Ok(sqlx::query_as(query)
262 .bind(&ids)
263 .fetch_all(&self.pool)
264 .await?)
265 }
266
267 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
268 let query = format!(
269 "
270 SELECT users.*
271 FROM users
272 WHERE invite_count = 0
273 AND inviter_id IS{} NULL
274 ",
275 if invited_by_another_user { " NOT" } else { "" }
276 );
277
278 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
279 }
280
281 async fn get_user_by_github_account(
282 &self,
283 github_login: &str,
284 github_user_id: Option<i32>,
285 ) -> Result<Option<User>> {
286 if let Some(github_user_id) = github_user_id {
287 let mut user = sqlx::query_as::<_, User>(
288 "
289 UPDATE users
290 SET github_login = $1
291 WHERE github_user_id = $2
292 RETURNING *
293 ",
294 )
295 .bind(github_login)
296 .bind(github_user_id)
297 .fetch_optional(&self.pool)
298 .await?;
299
300 if user.is_none() {
301 user = sqlx::query_as::<_, User>(
302 "
303 UPDATE users
304 SET github_user_id = $1
305 WHERE github_login = $2
306 RETURNING *
307 ",
308 )
309 .bind(github_user_id)
310 .bind(github_login)
311 .fetch_optional(&self.pool)
312 .await?;
313 }
314
315 Ok(user)
316 } else {
317 Ok(sqlx::query_as(
318 "
319 SELECT * FROM users
320 WHERE github_login = $1
321 LIMIT 1
322 ",
323 )
324 .bind(github_login)
325 .fetch_optional(&self.pool)
326 .await?)
327 }
328 }
329
330 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
331 let query = "UPDATE users SET admin = $1 WHERE id = $2";
332 Ok(sqlx::query(query)
333 .bind(is_admin)
334 .bind(id.0)
335 .execute(&self.pool)
336 .await
337 .map(drop)?)
338 }
339
340 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
341 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
342 Ok(sqlx::query(query)
343 .bind(connected_once)
344 .bind(id.0)
345 .execute(&self.pool)
346 .await
347 .map(drop)?)
348 }
349
350 async fn destroy_user(&self, id: UserId) -> Result<()> {
351 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
352 sqlx::query(query)
353 .bind(id.0)
354 .execute(&self.pool)
355 .await
356 .map(drop)?;
357 let query = "DELETE FROM users WHERE id = $1;";
358 Ok(sqlx::query(query)
359 .bind(id.0)
360 .execute(&self.pool)
361 .await
362 .map(drop)?)
363 }
364
365 // signups
366
367 async fn create_signup(&self, signup: Signup) -> Result<()> {
368 sqlx::query(
369 "
370 INSERT INTO signups
371 (
372 email_address,
373 email_confirmation_code,
374 email_confirmation_sent,
375 platform_linux,
376 platform_mac,
377 platform_windows,
378 platform_unknown,
379 editor_features,
380 programming_languages,
381 device_id
382 )
383 VALUES
384 ($1, $2, 'f', $3, $4, $5, 'f', $6, $7, $8)
385 RETURNING id
386 ",
387 )
388 .bind(&signup.email_address)
389 .bind(&random_email_confirmation_code())
390 .bind(&signup.platform_linux)
391 .bind(&signup.platform_mac)
392 .bind(&signup.platform_windows)
393 .bind(&signup.editor_features)
394 .bind(&signup.programming_languages)
395 .bind(&signup.device_id)
396 .execute(&self.pool)
397 .await?;
398 Ok(())
399 }
400
401 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
402 Ok(sqlx::query_as(
403 "
404 SELECT
405 COUNT(*) as count,
406 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
407 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
408 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count
409 FROM (
410 SELECT *
411 FROM signups
412 WHERE
413 NOT email_confirmation_sent
414 ) AS unsent
415 ",
416 )
417 .fetch_one(&self.pool)
418 .await?)
419 }
420
421 async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
422 Ok(sqlx::query_as(
423 "
424 SELECT
425 email_address, email_confirmation_code
426 FROM signups
427 WHERE
428 NOT email_confirmation_sent AND
429 platform_mac
430 LIMIT $1
431 ",
432 )
433 .bind(count as i32)
434 .fetch_all(&self.pool)
435 .await?)
436 }
437
438 async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
439 sqlx::query(
440 "
441 UPDATE signups
442 SET email_confirmation_sent = 't'
443 WHERE email_address = ANY ($1)
444 ",
445 )
446 .bind(
447 &invites
448 .iter()
449 .map(|s| s.email_address.as_str())
450 .collect::<Vec<_>>(),
451 )
452 .execute(&self.pool)
453 .await?;
454 Ok(())
455 }
456
457 async fn create_user_from_invite(
458 &self,
459 invite: &Invite,
460 user: NewUserParams,
461 ) -> Result<(UserId, Option<UserId>, String)> {
462 let mut tx = self.pool.begin().await?;
463
464 let (signup_id, existing_user_id, inviting_user_id, device_id): (
465 i32,
466 Option<UserId>,
467 Option<UserId>,
468 String,
469 ) = sqlx::query_as(
470 "
471 SELECT id, user_id, inviting_user_id, device_id
472 FROM signups
473 WHERE
474 email_address = $1 AND
475 email_confirmation_code = $2
476 ",
477 )
478 .bind(&invite.email_address)
479 .bind(&invite.email_confirmation_code)
480 .fetch_optional(&mut tx)
481 .await?
482 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
483
484 if existing_user_id.is_some() {
485 Err(Error::Http(
486 StatusCode::UNPROCESSABLE_ENTITY,
487 "invitation already redeemed".to_string(),
488 ))?;
489 }
490
491 let user_id: UserId = sqlx::query_scalar(
492 "
493 INSERT INTO users
494 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
495 VALUES
496 ($1, $2, $3, 'f', $4, $5)
497 RETURNING id
498 ",
499 )
500 .bind(&invite.email_address)
501 .bind(&user.github_login)
502 .bind(&user.github_user_id)
503 .bind(&user.invite_count)
504 .bind(random_invite_code())
505 .fetch_one(&mut tx)
506 .await?;
507
508 sqlx::query(
509 "
510 UPDATE signups
511 SET user_id = $1
512 WHERE id = $2
513 ",
514 )
515 .bind(&user_id)
516 .bind(&signup_id)
517 .execute(&mut tx)
518 .await?;
519
520 if let Some(inviting_user_id) = inviting_user_id {
521 let id: Option<UserId> = sqlx::query_scalar(
522 "
523 UPDATE users
524 SET invite_count = invite_count - 1
525 WHERE id = $1 AND invite_count > 0
526 RETURNING id
527 ",
528 )
529 .bind(&inviting_user_id)
530 .fetch_optional(&mut tx)
531 .await?;
532
533 if id.is_none() {
534 Err(Error::Http(
535 StatusCode::UNAUTHORIZED,
536 "no invites remaining".to_string(),
537 ))?;
538 }
539
540 sqlx::query(
541 "
542 INSERT INTO contacts
543 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
544 VALUES
545 ($1, $2, 't', 't', 't')
546 ",
547 )
548 .bind(inviting_user_id)
549 .bind(user_id)
550 .execute(&mut tx)
551 .await?;
552 }
553
554 tx.commit().await?;
555 Ok((user_id, inviting_user_id, device_id))
556 }
557
558 // invite codes
559
560 async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
561 let mut tx = self.pool.begin().await?;
562 if count > 0 {
563 sqlx::query(
564 "
565 UPDATE users
566 SET invite_code = $1
567 WHERE id = $2 AND invite_code IS NULL
568 ",
569 )
570 .bind(random_invite_code())
571 .bind(id)
572 .execute(&mut tx)
573 .await?;
574 }
575
576 sqlx::query(
577 "
578 UPDATE users
579 SET invite_count = $1
580 WHERE id = $2
581 ",
582 )
583 .bind(count as i32)
584 .bind(id)
585 .execute(&mut tx)
586 .await?;
587 tx.commit().await?;
588 Ok(())
589 }
590
591 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
592 let result: Option<(String, i32)> = sqlx::query_as(
593 "
594 SELECT invite_code, invite_count
595 FROM users
596 WHERE id = $1 AND invite_code IS NOT NULL
597 ",
598 )
599 .bind(id)
600 .fetch_optional(&self.pool)
601 .await?;
602 if let Some((code, count)) = result {
603 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
604 } else {
605 Ok(None)
606 }
607 }
608
609 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
610 sqlx::query_as(
611 "
612 SELECT *
613 FROM users
614 WHERE invite_code = $1
615 ",
616 )
617 .bind(code)
618 .fetch_optional(&self.pool)
619 .await?
620 .ok_or_else(|| {
621 Error::Http(
622 StatusCode::NOT_FOUND,
623 "that invite code does not exist".to_string(),
624 )
625 })
626 }
627
628 async fn create_invite_from_code(&self, code: &str, email_address: &str) -> Result<Invite> {
629 let mut tx = self.pool.begin().await?;
630
631 let existing_user: Option<UserId> = sqlx::query_scalar(
632 "
633 SELECT id
634 FROM users
635 WHERE email_address = $1
636 ",
637 )
638 .bind(email_address)
639 .fetch_optional(&mut tx)
640 .await?;
641 if existing_user.is_some() {
642 Err(anyhow!("email address is already in use"))?;
643 }
644
645 let row: Option<(UserId, i32)> = sqlx::query_as(
646 "
647 SELECT id, invite_count
648 FROM users
649 WHERE invite_code = $1
650 ",
651 )
652 .bind(code)
653 .fetch_optional(&mut tx)
654 .await?;
655
656 let (inviter_id, invite_count) = match row {
657 Some(row) => row,
658 None => Err(Error::Http(
659 StatusCode::NOT_FOUND,
660 "invite code not found".to_string(),
661 ))?,
662 };
663
664 if invite_count == 0 {
665 Err(Error::Http(
666 StatusCode::UNAUTHORIZED,
667 "no invites remaining".to_string(),
668 ))?;
669 }
670
671 let email_confirmation_code: String = sqlx::query_scalar(
672 "
673 INSERT INTO signups
674 (
675 email_address,
676 email_confirmation_code,
677 email_confirmation_sent,
678 inviting_user_id,
679 platform_linux,
680 platform_mac,
681 platform_windows,
682 platform_unknown
683 )
684 VALUES
685 ($1, $2, 'f', $3, 'f', 'f', 'f', 't')
686 ON CONFLICT (email_address)
687 DO UPDATE SET
688 inviting_user_id = excluded.inviting_user_id
689 RETURNING email_confirmation_code
690 ",
691 )
692 .bind(&email_address)
693 .bind(&random_email_confirmation_code())
694 .bind(&inviter_id)
695 .fetch_one(&mut tx)
696 .await?;
697
698 tx.commit().await?;
699
700 Ok(Invite {
701 email_address: email_address.into(),
702 email_confirmation_code,
703 })
704 }
705
706 // projects
707
708 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
709 Ok(sqlx::query_scalar(
710 "
711 INSERT INTO projects(host_user_id)
712 VALUES ($1)
713 RETURNING id
714 ",
715 )
716 .bind(host_user_id)
717 .fetch_one(&self.pool)
718 .await
719 .map(ProjectId)?)
720 }
721
722 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
723 sqlx::query(
724 "
725 UPDATE projects
726 SET unregistered = 't'
727 WHERE id = $1
728 ",
729 )
730 .bind(project_id)
731 .execute(&self.pool)
732 .await?;
733 Ok(())
734 }
735
736 async fn update_worktree_extensions(
737 &self,
738 project_id: ProjectId,
739 worktree_id: u64,
740 extensions: HashMap<String, u32>,
741 ) -> Result<()> {
742 if extensions.is_empty() {
743 return Ok(());
744 }
745
746 let mut query = QueryBuilder::new(
747 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
748 );
749 query.push_values(extensions, |mut query, (extension, count)| {
750 query
751 .push_bind(project_id)
752 .push_bind(worktree_id as i32)
753 .push_bind(extension)
754 .push_bind(count as i32);
755 });
756 query.push(
757 "
758 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
759 count = excluded.count
760 ",
761 );
762 query.build().execute(&self.pool).await?;
763
764 Ok(())
765 }
766
767 async fn get_project_extensions(
768 &self,
769 project_id: ProjectId,
770 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
771 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
772 struct WorktreeExtension {
773 worktree_id: i32,
774 extension: String,
775 count: i32,
776 }
777
778 let query = "
779 SELECT worktree_id, extension, count
780 FROM worktree_extensions
781 WHERE project_id = $1
782 ";
783 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
784 .bind(&project_id)
785 .fetch_all(&self.pool)
786 .await?;
787
788 let mut extension_counts = HashMap::default();
789 for count in counts {
790 extension_counts
791 .entry(count.worktree_id as u64)
792 .or_insert_with(HashMap::default)
793 .insert(count.extension, count.count as usize);
794 }
795 Ok(extension_counts)
796 }
797
798 async fn record_user_activity(
799 &self,
800 time_period: Range<OffsetDateTime>,
801 projects: &[(UserId, ProjectId)],
802 ) -> Result<()> {
803 let query = "
804 INSERT INTO project_activity_periods
805 (ended_at, duration_millis, user_id, project_id)
806 VALUES
807 ($1, $2, $3, $4);
808 ";
809
810 let mut tx = self.pool.begin().await?;
811 let duration_millis =
812 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
813 for (user_id, project_id) in projects {
814 sqlx::query(query)
815 .bind(time_period.end)
816 .bind(duration_millis)
817 .bind(user_id)
818 .bind(project_id)
819 .execute(&mut tx)
820 .await?;
821 }
822 tx.commit().await?;
823
824 Ok(())
825 }
826
827 async fn get_active_user_count(
828 &self,
829 time_period: Range<OffsetDateTime>,
830 min_duration: Duration,
831 only_collaborative: bool,
832 ) -> Result<usize> {
833 let mut with_clause = String::new();
834 with_clause.push_str("WITH\n");
835 with_clause.push_str(
836 "
837 project_durations AS (
838 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
839 FROM project_activity_periods
840 WHERE $1 < ended_at AND ended_at <= $2
841 GROUP BY user_id, project_id
842 ),
843 ",
844 );
845 with_clause.push_str(
846 "
847 project_collaborators as (
848 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
849 FROM project_durations
850 GROUP BY project_id
851 ),
852 ",
853 );
854
855 if only_collaborative {
856 with_clause.push_str(
857 "
858 user_durations AS (
859 SELECT user_id, SUM(project_duration) as total_duration
860 FROM project_durations, project_collaborators
861 WHERE
862 project_durations.project_id = project_collaborators.project_id AND
863 max_collaborators > 1
864 GROUP BY user_id
865 ORDER BY total_duration DESC
866 LIMIT $3
867 )
868 ",
869 );
870 } else {
871 with_clause.push_str(
872 "
873 user_durations AS (
874 SELECT user_id, SUM(project_duration) as total_duration
875 FROM project_durations
876 GROUP BY user_id
877 ORDER BY total_duration DESC
878 LIMIT $3
879 )
880 ",
881 );
882 }
883
884 let query = format!(
885 "
886 {with_clause}
887 SELECT count(user_durations.user_id)
888 FROM user_durations
889 WHERE user_durations.total_duration >= $3
890 "
891 );
892
893 let count: i64 = sqlx::query_scalar(&query)
894 .bind(time_period.start)
895 .bind(time_period.end)
896 .bind(min_duration.as_millis() as i64)
897 .fetch_one(&self.pool)
898 .await?;
899 Ok(count as usize)
900 }
901
902 async fn get_top_users_activity_summary(
903 &self,
904 time_period: Range<OffsetDateTime>,
905 max_user_count: usize,
906 ) -> Result<Vec<UserActivitySummary>> {
907 let query = "
908 WITH
909 project_durations AS (
910 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
911 FROM project_activity_periods
912 WHERE $1 < ended_at AND ended_at <= $2
913 GROUP BY user_id, project_id
914 ),
915 user_durations AS (
916 SELECT user_id, SUM(project_duration) as total_duration
917 FROM project_durations
918 GROUP BY user_id
919 ORDER BY total_duration DESC
920 LIMIT $3
921 ),
922 project_collaborators as (
923 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
924 FROM project_durations
925 GROUP BY project_id
926 )
927 SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
928 FROM user_durations, project_durations, project_collaborators, users
929 WHERE
930 user_durations.user_id = project_durations.user_id AND
931 user_durations.user_id = users.id AND
932 project_durations.project_id = project_collaborators.project_id
933 ORDER BY total_duration DESC, user_id ASC, project_id ASC
934 ";
935
936 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
937 .bind(time_period.start)
938 .bind(time_period.end)
939 .bind(max_user_count as i32)
940 .fetch(&self.pool);
941
942 let mut result = Vec::<UserActivitySummary>::new();
943 while let Some(row) = rows.next().await {
944 let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
945 let project_id = project_id;
946 let duration = Duration::from_millis(duration_millis as u64);
947 let project_activity = ProjectActivitySummary {
948 id: project_id,
949 duration,
950 max_collaborators: project_collaborators as usize,
951 };
952 if let Some(last_summary) = result.last_mut() {
953 if last_summary.id == user_id {
954 last_summary.project_activity.push(project_activity);
955 continue;
956 }
957 }
958 result.push(UserActivitySummary {
959 id: user_id,
960 project_activity: vec![project_activity],
961 github_login,
962 });
963 }
964
965 Ok(result)
966 }
967
968 async fn get_user_activity_timeline(
969 &self,
970 time_period: Range<OffsetDateTime>,
971 user_id: UserId,
972 ) -> Result<Vec<UserActivityPeriod>> {
973 const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
974
975 let query = "
976 SELECT
977 project_activity_periods.ended_at,
978 project_activity_periods.duration_millis,
979 project_activity_periods.project_id,
980 worktree_extensions.extension,
981 worktree_extensions.count
982 FROM project_activity_periods
983 LEFT OUTER JOIN
984 worktree_extensions
985 ON
986 project_activity_periods.project_id = worktree_extensions.project_id
987 WHERE
988 project_activity_periods.user_id = $1 AND
989 $2 < project_activity_periods.ended_at AND
990 project_activity_periods.ended_at <= $3
991 ORDER BY project_activity_periods.id ASC
992 ";
993
994 let mut rows = sqlx::query_as::<
995 _,
996 (
997 PrimitiveDateTime,
998 i32,
999 ProjectId,
1000 Option<String>,
1001 Option<i32>,
1002 ),
1003 >(query)
1004 .bind(user_id)
1005 .bind(time_period.start)
1006 .bind(time_period.end)
1007 .fetch(&self.pool);
1008
1009 let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
1010 while let Some(row) = rows.next().await {
1011 let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
1012 let ended_at = ended_at.assume_utc();
1013 let duration = Duration::from_millis(duration_millis as u64);
1014 let started_at = ended_at - duration;
1015 let project_time_periods = time_periods.entry(project_id).or_default();
1016
1017 if let Some(prev_duration) = project_time_periods.last_mut() {
1018 if started_at <= prev_duration.end + COALESCE_THRESHOLD
1019 && ended_at >= prev_duration.start
1020 {
1021 prev_duration.end = cmp::max(prev_duration.end, ended_at);
1022 } else {
1023 project_time_periods.push(UserActivityPeriod {
1024 project_id,
1025 start: started_at,
1026 end: ended_at,
1027 extensions: Default::default(),
1028 });
1029 }
1030 } else {
1031 project_time_periods.push(UserActivityPeriod {
1032 project_id,
1033 start: started_at,
1034 end: ended_at,
1035 extensions: Default::default(),
1036 });
1037 }
1038
1039 if let Some((extension, extension_count)) = extension.zip(extension_count) {
1040 project_time_periods
1041 .last_mut()
1042 .unwrap()
1043 .extensions
1044 .insert(extension, extension_count as usize);
1045 }
1046 }
1047
1048 let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
1049 durations.sort_unstable_by_key(|duration| duration.start);
1050 Ok(durations)
1051 }
1052
1053 // contacts
1054
1055 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1056 let query = "
1057 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1058 FROM contacts
1059 WHERE user_id_a = $1 OR user_id_b = $1;
1060 ";
1061
1062 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1063 .bind(user_id)
1064 .fetch(&self.pool);
1065
1066 let mut contacts = vec![Contact::Accepted {
1067 user_id,
1068 should_notify: false,
1069 }];
1070 while let Some(row) = rows.next().await {
1071 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1072
1073 if user_id_a == user_id {
1074 if accepted {
1075 contacts.push(Contact::Accepted {
1076 user_id: user_id_b,
1077 should_notify: should_notify && a_to_b,
1078 });
1079 } else if a_to_b {
1080 contacts.push(Contact::Outgoing { user_id: user_id_b })
1081 } else {
1082 contacts.push(Contact::Incoming {
1083 user_id: user_id_b,
1084 should_notify,
1085 });
1086 }
1087 } else if accepted {
1088 contacts.push(Contact::Accepted {
1089 user_id: user_id_a,
1090 should_notify: should_notify && !a_to_b,
1091 });
1092 } else if a_to_b {
1093 contacts.push(Contact::Incoming {
1094 user_id: user_id_a,
1095 should_notify,
1096 });
1097 } else {
1098 contacts.push(Contact::Outgoing { user_id: user_id_a });
1099 }
1100 }
1101
1102 contacts.sort_unstable_by_key(|contact| contact.user_id());
1103
1104 Ok(contacts)
1105 }
1106
1107 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1108 let (id_a, id_b) = if user_id_1 < user_id_2 {
1109 (user_id_1, user_id_2)
1110 } else {
1111 (user_id_2, user_id_1)
1112 };
1113
1114 let query = "
1115 SELECT 1 FROM contacts
1116 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
1117 LIMIT 1
1118 ";
1119 Ok(sqlx::query_scalar::<_, i32>(query)
1120 .bind(id_a.0)
1121 .bind(id_b.0)
1122 .fetch_optional(&self.pool)
1123 .await?
1124 .is_some())
1125 }
1126
1127 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1128 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1129 (sender_id, receiver_id, true)
1130 } else {
1131 (receiver_id, sender_id, false)
1132 };
1133 let query = "
1134 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1135 VALUES ($1, $2, $3, 'f', 't')
1136 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1137 SET
1138 accepted = 't',
1139 should_notify = 'f'
1140 WHERE
1141 NOT contacts.accepted AND
1142 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1143 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1144 ";
1145 let result = sqlx::query(query)
1146 .bind(id_a.0)
1147 .bind(id_b.0)
1148 .bind(a_to_b)
1149 .execute(&self.pool)
1150 .await?;
1151
1152 if result.rows_affected() == 1 {
1153 Ok(())
1154 } else {
1155 Err(anyhow!("contact already requested"))?
1156 }
1157 }
1158
1159 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1160 let (id_a, id_b) = if responder_id < requester_id {
1161 (responder_id, requester_id)
1162 } else {
1163 (requester_id, responder_id)
1164 };
1165 let query = "
1166 DELETE FROM contacts
1167 WHERE user_id_a = $1 AND user_id_b = $2;
1168 ";
1169 let result = sqlx::query(query)
1170 .bind(id_a.0)
1171 .bind(id_b.0)
1172 .execute(&self.pool)
1173 .await?;
1174
1175 if result.rows_affected() == 1 {
1176 Ok(())
1177 } else {
1178 Err(anyhow!("no such contact"))?
1179 }
1180 }
1181
1182 async fn dismiss_contact_notification(
1183 &self,
1184 user_id: UserId,
1185 contact_user_id: UserId,
1186 ) -> Result<()> {
1187 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1188 (user_id, contact_user_id, true)
1189 } else {
1190 (contact_user_id, user_id, false)
1191 };
1192
1193 let query = "
1194 UPDATE contacts
1195 SET should_notify = 'f'
1196 WHERE
1197 user_id_a = $1 AND user_id_b = $2 AND
1198 (
1199 (a_to_b = $3 AND accepted) OR
1200 (a_to_b != $3 AND NOT accepted)
1201 );
1202 ";
1203
1204 let result = sqlx::query(query)
1205 .bind(id_a.0)
1206 .bind(id_b.0)
1207 .bind(a_to_b)
1208 .execute(&self.pool)
1209 .await?;
1210
1211 if result.rows_affected() == 0 {
1212 Err(anyhow!("no such contact request"))?;
1213 }
1214
1215 Ok(())
1216 }
1217
1218 async fn respond_to_contact_request(
1219 &self,
1220 responder_id: UserId,
1221 requester_id: UserId,
1222 accept: bool,
1223 ) -> Result<()> {
1224 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1225 (responder_id, requester_id, false)
1226 } else {
1227 (requester_id, responder_id, true)
1228 };
1229 let result = if accept {
1230 let query = "
1231 UPDATE contacts
1232 SET accepted = 't', should_notify = 't'
1233 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1234 ";
1235 sqlx::query(query)
1236 .bind(id_a.0)
1237 .bind(id_b.0)
1238 .bind(a_to_b)
1239 .execute(&self.pool)
1240 .await?
1241 } else {
1242 let query = "
1243 DELETE FROM contacts
1244 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1245 ";
1246 sqlx::query(query)
1247 .bind(id_a.0)
1248 .bind(id_b.0)
1249 .bind(a_to_b)
1250 .execute(&self.pool)
1251 .await?
1252 };
1253 if result.rows_affected() == 1 {
1254 Ok(())
1255 } else {
1256 Err(anyhow!("no such contact request"))?
1257 }
1258 }
1259
1260 // access tokens
1261
1262 async fn create_access_token_hash(
1263 &self,
1264 user_id: UserId,
1265 access_token_hash: &str,
1266 max_access_token_count: usize,
1267 ) -> Result<()> {
1268 let insert_query = "
1269 INSERT INTO access_tokens (user_id, hash)
1270 VALUES ($1, $2);
1271 ";
1272 let cleanup_query = "
1273 DELETE FROM access_tokens
1274 WHERE id IN (
1275 SELECT id from access_tokens
1276 WHERE user_id = $1
1277 ORDER BY id DESC
1278 OFFSET $3
1279 )
1280 ";
1281
1282 let mut tx = self.pool.begin().await?;
1283 sqlx::query(insert_query)
1284 .bind(user_id.0)
1285 .bind(access_token_hash)
1286 .execute(&mut tx)
1287 .await?;
1288 sqlx::query(cleanup_query)
1289 .bind(user_id.0)
1290 .bind(access_token_hash)
1291 .bind(max_access_token_count as i32)
1292 .execute(&mut tx)
1293 .await?;
1294 Ok(tx.commit().await?)
1295 }
1296
1297 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1298 let query = "
1299 SELECT hash
1300 FROM access_tokens
1301 WHERE user_id = $1
1302 ORDER BY id DESC
1303 ";
1304 Ok(sqlx::query_scalar(query)
1305 .bind(user_id.0)
1306 .fetch_all(&self.pool)
1307 .await?)
1308 }
1309
1310 // orgs
1311
1312 #[allow(unused)] // Help rust-analyzer
1313 #[cfg(any(test, feature = "seed-support"))]
1314 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1315 let query = "
1316 SELECT *
1317 FROM orgs
1318 WHERE slug = $1
1319 ";
1320 Ok(sqlx::query_as(query)
1321 .bind(slug)
1322 .fetch_optional(&self.pool)
1323 .await?)
1324 }
1325
1326 #[cfg(any(test, feature = "seed-support"))]
1327 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1328 let query = "
1329 INSERT INTO orgs (name, slug)
1330 VALUES ($1, $2)
1331 RETURNING id
1332 ";
1333 Ok(sqlx::query_scalar(query)
1334 .bind(name)
1335 .bind(slug)
1336 .fetch_one(&self.pool)
1337 .await
1338 .map(OrgId)?)
1339 }
1340
1341 #[cfg(any(test, feature = "seed-support"))]
1342 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1343 let query = "
1344 INSERT INTO org_memberships (org_id, user_id, admin)
1345 VALUES ($1, $2, $3)
1346 ON CONFLICT DO NOTHING
1347 ";
1348 Ok(sqlx::query(query)
1349 .bind(org_id.0)
1350 .bind(user_id.0)
1351 .bind(is_admin)
1352 .execute(&self.pool)
1353 .await
1354 .map(drop)?)
1355 }
1356
1357 // channels
1358
1359 #[cfg(any(test, feature = "seed-support"))]
1360 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1361 let query = "
1362 INSERT INTO channels (owner_id, owner_is_user, name)
1363 VALUES ($1, false, $2)
1364 RETURNING id
1365 ";
1366 Ok(sqlx::query_scalar(query)
1367 .bind(org_id.0)
1368 .bind(name)
1369 .fetch_one(&self.pool)
1370 .await
1371 .map(ChannelId)?)
1372 }
1373
1374 #[allow(unused)] // Help rust-analyzer
1375 #[cfg(any(test, feature = "seed-support"))]
1376 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1377 let query = "
1378 SELECT *
1379 FROM channels
1380 WHERE
1381 channels.owner_is_user = false AND
1382 channels.owner_id = $1
1383 ";
1384 Ok(sqlx::query_as(query)
1385 .bind(org_id.0)
1386 .fetch_all(&self.pool)
1387 .await?)
1388 }
1389
1390 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1391 let query = "
1392 SELECT
1393 channels.*
1394 FROM
1395 channel_memberships, channels
1396 WHERE
1397 channel_memberships.user_id = $1 AND
1398 channel_memberships.channel_id = channels.id
1399 ";
1400 Ok(sqlx::query_as(query)
1401 .bind(user_id.0)
1402 .fetch_all(&self.pool)
1403 .await?)
1404 }
1405
1406 async fn can_user_access_channel(
1407 &self,
1408 user_id: UserId,
1409 channel_id: ChannelId,
1410 ) -> Result<bool> {
1411 let query = "
1412 SELECT id
1413 FROM channel_memberships
1414 WHERE user_id = $1 AND channel_id = $2
1415 LIMIT 1
1416 ";
1417 Ok(sqlx::query_scalar::<_, i32>(query)
1418 .bind(user_id.0)
1419 .bind(channel_id.0)
1420 .fetch_optional(&self.pool)
1421 .await
1422 .map(|e| e.is_some())?)
1423 }
1424
1425 #[cfg(any(test, feature = "seed-support"))]
1426 async fn add_channel_member(
1427 &self,
1428 channel_id: ChannelId,
1429 user_id: UserId,
1430 is_admin: bool,
1431 ) -> Result<()> {
1432 let query = "
1433 INSERT INTO channel_memberships (channel_id, user_id, admin)
1434 VALUES ($1, $2, $3)
1435 ON CONFLICT DO NOTHING
1436 ";
1437 Ok(sqlx::query(query)
1438 .bind(channel_id.0)
1439 .bind(user_id.0)
1440 .bind(is_admin)
1441 .execute(&self.pool)
1442 .await
1443 .map(drop)?)
1444 }
1445
1446 // messages
1447
1448 async fn create_channel_message(
1449 &self,
1450 channel_id: ChannelId,
1451 sender_id: UserId,
1452 body: &str,
1453 timestamp: OffsetDateTime,
1454 nonce: u128,
1455 ) -> Result<MessageId> {
1456 let query = "
1457 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1458 VALUES ($1, $2, $3, $4, $5)
1459 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1460 RETURNING id
1461 ";
1462 Ok(sqlx::query_scalar(query)
1463 .bind(channel_id.0)
1464 .bind(sender_id.0)
1465 .bind(body)
1466 .bind(timestamp)
1467 .bind(Uuid::from_u128(nonce))
1468 .fetch_one(&self.pool)
1469 .await
1470 .map(MessageId)?)
1471 }
1472
1473 async fn get_channel_messages(
1474 &self,
1475 channel_id: ChannelId,
1476 count: usize,
1477 before_id: Option<MessageId>,
1478 ) -> Result<Vec<ChannelMessage>> {
1479 let query = r#"
1480 SELECT * FROM (
1481 SELECT
1482 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1483 FROM
1484 channel_messages
1485 WHERE
1486 channel_id = $1 AND
1487 id < $2
1488 ORDER BY id DESC
1489 LIMIT $3
1490 ) as recent_messages
1491 ORDER BY id ASC
1492 "#;
1493 Ok(sqlx::query_as(query)
1494 .bind(channel_id.0)
1495 .bind(before_id.unwrap_or(MessageId::MAX))
1496 .bind(count as i64)
1497 .fetch_all(&self.pool)
1498 .await?)
1499 }
1500
1501 #[cfg(test)]
1502 async fn teardown(&self, url: &str) {
1503 use util::ResultExt;
1504
1505 let query = "
1506 SELECT pg_terminate_backend(pg_stat_activity.pid)
1507 FROM pg_stat_activity
1508 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1509 ";
1510 sqlx::query(query).execute(&self.pool).await.log_err();
1511 self.pool.close().await;
1512 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1513 .await
1514 .log_err();
1515 }
1516
1517 #[cfg(test)]
1518 fn as_fake(&self) -> Option<&FakeDb> {
1519 None
1520 }
1521}
1522
// Declares a database id newtype `$name(pub i32)`.
//
// The generated type is transparent for both sqlx and serde (it encodes and
// decodes exactly like the inner `i32`), is ordered and hashable, and gets
// helpers for converting to and from the `u64` used by the protocol layer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from its protocol representation. Uses an `as`
            /// cast, so only the low 32 bits of `value` are kept.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Converts the id to its protocol representation. Uses an `as`
            /// cast, so a negative id sign-extends to a large `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            /// Formats exactly like the inner `i32`, honoring width/fill flags.
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1565
id_type!(UserId);
/// A user account record, decodable from a database row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    // May be `None`; the fake db fills it in when a user is later matched by
    // login in `get_user_by_github_account`.
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    pub invite_code: Option<String>,
    pub invite_count: i32,
    // Written by `Db::set_user_connected_once`.
    pub connected_once: bool,
}

id_type!(ProjectId);
/// A project registered by its host user (see `Db::register_project`).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    // Set to `true` by the fake db's `unregister_project`.
    // NOTE(review): presumably the Postgres implementation does the same — confirm.
    pub unregistered: bool,
}
1586
/// Per-user activity summary, as returned by
/// `Db::get_top_users_activity_summary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    pub project_activity: Vec<ProjectActivitySummary>,
}

/// Activity totals for a single project within a [`UserActivitySummary`].
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    pub id: ProjectId,
    pub duration: Duration,
    // NOTE(review): presumably the peak number of simultaneous collaborators
    // during the queried period — confirm against the query that fills it.
    pub max_collaborators: usize,
}

/// One activity period on a project, as returned by
/// `Db::get_user_activity_timeline`. `start` and `end` serialize as ISO 8601.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    pub project_id: ProjectId,
    #[serde(with = "time::serde::iso8601")]
    pub start: OffsetDateTime,
    #[serde(with = "time::serde::iso8601")]
    pub end: OffsetDateTime,
    // NOTE(review): presumably file extension -> file count for the project's
    // worktrees (cf. `update_worktree_extensions`) — confirm.
    pub extensions: HashMap<String, usize>,
}
1610
id_type!(OrgId);
/// An organization record.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    // Lookup key for `find_org_by_slug`; the fake db refuses to create a
    // second org with the same slug.
    pub slug: String,
}

id_type!(ChannelId);
/// A chat channel.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    // Interpreted according to `owner_is_user`: holds an `OrgId` value when it
    // is `false` (see `create_org_channel`); presumably a `UserId` when `true`.
    pub owner_id: i32,
    pub owner_is_user: bool,
}

id_type!(MessageId);
/// A message posted to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    // `get_channel_messages` reads this column `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    // Caller-supplied de-duplication key: creating a message whose nonce
    // already exists returns the existing message's id.
    pub nonce: Uuid,
}
1638
/// A contact relationship, seen from the point of view of the user whose
/// contact list is being built.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The request between the two users has been accepted.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// This user sent a request to `user_id` that has not been accepted yet.
    Outgoing {
        user_id: UserId,
    },
    /// `user_id` sent a request to this user that has not been accepted yet.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1653
1654impl Contact {
1655 pub fn user_id(&self) -> UserId {
1656 match self {
1657 Contact::Accepted { user_id, .. } => *user_id,
1658 Contact::Outgoing { user_id } => *user_id,
1659 Contact::Incoming { user_id, .. } => *user_id,
1660 }
1661 }
1662}
1663
/// A pending contact request, identified by the user who sent it.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}

/// Waitlist sign-up data passed to `Db::create_signup`. Deserialize-only.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: String,
}

/// Aggregate waitlist counts returned by `Db::get_waitlist_summary`.
///
/// Every field is marked `#[sqlx(default)]`, so it falls back to `0` when the
/// decoded row does not contain that column.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
}

/// An invite for an email address together with its confirmation code.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}

/// Parameters for creating a user account (see `Db::create_user`).
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    // NOTE(review): the fake db ignores this field and stores 0.
    pub invite_count: i32,
}
1705
1706fn random_invite_code() -> String {
1707 nanoid::nanoid!(16)
1708}
1709
1710fn random_email_confirmation_code() -> String {
1711 nanoid::nanoid!(64)
1712}
1713
1714#[cfg(test)]
1715pub use test::*;
1716
1717#[cfg(test)]
1718mod test {
1719 use super::*;
1720 use anyhow::anyhow;
1721 use collections::BTreeMap;
1722 use gpui::executor::Background;
1723 use lazy_static::lazy_static;
1724 use parking_lot::Mutex;
1725 use rand::prelude::*;
1726 use sqlx::{
1727 migrate::{MigrateDatabase, Migrator},
1728 Postgres,
1729 };
1730 use std::{path::Path, sync::Arc};
1731 use util::post_inc;
1732
1733 pub struct FakeDb {
1734 background: Arc<Background>,
1735 pub users: Mutex<BTreeMap<UserId, User>>,
1736 pub projects: Mutex<BTreeMap<ProjectId, Project>>,
1737 pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
1738 pub orgs: Mutex<BTreeMap<OrgId, Org>>,
1739 pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
1740 pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
1741 pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
1742 pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
1743 pub contacts: Mutex<Vec<FakeContact>>,
1744 next_channel_message_id: Mutex<i32>,
1745 next_user_id: Mutex<i32>,
1746 next_org_id: Mutex<i32>,
1747 next_channel_id: Mutex<i32>,
1748 next_project_id: Mutex<i32>,
1749 }
1750
1751 #[derive(Debug)]
1752 pub struct FakeContact {
1753 pub requester_id: UserId,
1754 pub responder_id: UserId,
1755 pub accepted: bool,
1756 pub should_notify: bool,
1757 }
1758
1759 impl FakeDb {
1760 pub fn new(background: Arc<Background>) -> Self {
1761 Self {
1762 background,
1763 users: Default::default(),
1764 next_user_id: Mutex::new(0),
1765 projects: Default::default(),
1766 worktree_extensions: Default::default(),
1767 next_project_id: Mutex::new(1),
1768 orgs: Default::default(),
1769 next_org_id: Mutex::new(1),
1770 org_memberships: Default::default(),
1771 channels: Default::default(),
1772 next_channel_id: Mutex::new(1),
1773 channel_memberships: Default::default(),
1774 channel_messages: Default::default(),
1775 next_channel_message_id: Mutex::new(1),
1776 contacts: Default::default(),
1777 }
1778 }
1779 }
1780
1781 #[async_trait]
1782 impl Db for FakeDb {
1783 async fn create_user(
1784 &self,
1785 email_address: &str,
1786 admin: bool,
1787 params: NewUserParams,
1788 ) -> Result<UserId> {
1789 self.background.simulate_random_delay().await;
1790
1791 let mut users = self.users.lock();
1792 if let Some(user) = users
1793 .values()
1794 .find(|user| user.github_login == params.github_login)
1795 {
1796 Ok(user.id)
1797 } else {
1798 let id = post_inc(&mut *self.next_user_id.lock());
1799 let user_id = UserId(id);
1800 users.insert(
1801 user_id,
1802 User {
1803 id: user_id,
1804 github_login: params.github_login,
1805 github_user_id: Some(params.github_user_id),
1806 email_address: Some(email_address.to_string()),
1807 admin,
1808 invite_code: None,
1809 invite_count: 0,
1810 connected_once: false,
1811 },
1812 );
1813 Ok(user_id)
1814 }
1815 }
1816
1817 async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
1818 unimplemented!()
1819 }
1820
1821 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1822 unimplemented!()
1823 }
1824
1825 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1826 self.background.simulate_random_delay().await;
1827 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1828 }
1829
1830 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1831 self.background.simulate_random_delay().await;
1832 let users = self.users.lock();
1833 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1834 }
1835
1836 async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
1837 unimplemented!()
1838 }
1839
1840 async fn get_user_by_github_account(
1841 &self,
1842 github_login: &str,
1843 github_user_id: Option<i32>,
1844 ) -> Result<Option<User>> {
1845 self.background.simulate_random_delay().await;
1846 if let Some(github_user_id) = github_user_id {
1847 for user in self.users.lock().values_mut() {
1848 if user.github_user_id == Some(github_user_id) {
1849 user.github_login = github_login.into();
1850 return Ok(Some(user.clone()));
1851 }
1852 if user.github_login == github_login {
1853 user.github_user_id = Some(github_user_id);
1854 return Ok(Some(user.clone()));
1855 }
1856 }
1857 Ok(None)
1858 } else {
1859 Ok(self
1860 .users
1861 .lock()
1862 .values()
1863 .find(|user| user.github_login == github_login)
1864 .cloned())
1865 }
1866 }
1867
1868 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1869 unimplemented!()
1870 }
1871
1872 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1873 self.background.simulate_random_delay().await;
1874 let mut users = self.users.lock();
1875 let mut user = users
1876 .get_mut(&id)
1877 .ok_or_else(|| anyhow!("user not found"))?;
1878 user.connected_once = connected_once;
1879 Ok(())
1880 }
1881
1882 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1883 unimplemented!()
1884 }
1885
1886 // signups
1887
1888 async fn create_signup(&self, _signup: Signup) -> Result<()> {
1889 unimplemented!()
1890 }
1891
1892 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
1893 unimplemented!()
1894 }
1895
1896 async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<Invite>> {
1897 unimplemented!()
1898 }
1899
1900 async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
1901 unimplemented!()
1902 }
1903
1904 async fn create_user_from_invite(
1905 &self,
1906 _invite: &Invite,
1907 _user: NewUserParams,
1908 ) -> Result<(UserId, Option<UserId>, String)> {
1909 unimplemented!()
1910 }
1911
1912 // invite codes
1913
1914 async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
1915 unimplemented!()
1916 }
1917
1918 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
1919 self.background.simulate_random_delay().await;
1920 Ok(None)
1921 }
1922
1923 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
1924 unimplemented!()
1925 }
1926
1927 async fn create_invite_from_code(
1928 &self,
1929 _code: &str,
1930 _email_address: &str,
1931 ) -> Result<Invite> {
1932 unimplemented!()
1933 }
1934
1935 // projects
1936
1937 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
1938 self.background.simulate_random_delay().await;
1939 if !self.users.lock().contains_key(&host_user_id) {
1940 Err(anyhow!("no such user"))?;
1941 }
1942
1943 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
1944 self.projects.lock().insert(
1945 project_id,
1946 Project {
1947 id: project_id,
1948 host_user_id,
1949 unregistered: false,
1950 },
1951 );
1952 Ok(project_id)
1953 }
1954
1955 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
1956 self.background.simulate_random_delay().await;
1957 self.projects
1958 .lock()
1959 .get_mut(&project_id)
1960 .ok_or_else(|| anyhow!("no such project"))?
1961 .unregistered = true;
1962 Ok(())
1963 }
1964
1965 async fn update_worktree_extensions(
1966 &self,
1967 project_id: ProjectId,
1968 worktree_id: u64,
1969 extensions: HashMap<String, u32>,
1970 ) -> Result<()> {
1971 self.background.simulate_random_delay().await;
1972 if !self.projects.lock().contains_key(&project_id) {
1973 Err(anyhow!("no such project"))?;
1974 }
1975
1976 for (extension, count) in extensions {
1977 self.worktree_extensions
1978 .lock()
1979 .insert((project_id, worktree_id, extension), count);
1980 }
1981
1982 Ok(())
1983 }
1984
1985 async fn get_project_extensions(
1986 &self,
1987 _project_id: ProjectId,
1988 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
1989 unimplemented!()
1990 }
1991
1992 async fn record_user_activity(
1993 &self,
1994 _time_period: Range<OffsetDateTime>,
1995 _active_projects: &[(UserId, ProjectId)],
1996 ) -> Result<()> {
1997 unimplemented!()
1998 }
1999
2000 async fn get_active_user_count(
2001 &self,
2002 _time_period: Range<OffsetDateTime>,
2003 _min_duration: Duration,
2004 _only_collaborative: bool,
2005 ) -> Result<usize> {
2006 unimplemented!()
2007 }
2008
2009 async fn get_top_users_activity_summary(
2010 &self,
2011 _time_period: Range<OffsetDateTime>,
2012 _limit: usize,
2013 ) -> Result<Vec<UserActivitySummary>> {
2014 unimplemented!()
2015 }
2016
2017 async fn get_user_activity_timeline(
2018 &self,
2019 _time_period: Range<OffsetDateTime>,
2020 _user_id: UserId,
2021 ) -> Result<Vec<UserActivityPeriod>> {
2022 unimplemented!()
2023 }
2024
2025 // contacts
2026
2027 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
2028 self.background.simulate_random_delay().await;
2029 let mut contacts = vec![Contact::Accepted {
2030 user_id: id,
2031 should_notify: false,
2032 }];
2033
2034 for contact in self.contacts.lock().iter() {
2035 if contact.requester_id == id {
2036 if contact.accepted {
2037 contacts.push(Contact::Accepted {
2038 user_id: contact.responder_id,
2039 should_notify: contact.should_notify,
2040 });
2041 } else {
2042 contacts.push(Contact::Outgoing {
2043 user_id: contact.responder_id,
2044 });
2045 }
2046 } else if contact.responder_id == id {
2047 if contact.accepted {
2048 contacts.push(Contact::Accepted {
2049 user_id: contact.requester_id,
2050 should_notify: false,
2051 });
2052 } else {
2053 contacts.push(Contact::Incoming {
2054 user_id: contact.requester_id,
2055 should_notify: contact.should_notify,
2056 });
2057 }
2058 }
2059 }
2060
2061 contacts.sort_unstable_by_key(|contact| contact.user_id());
2062 Ok(contacts)
2063 }
2064
2065 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2066 self.background.simulate_random_delay().await;
2067 Ok(self.contacts.lock().iter().any(|contact| {
2068 contact.accepted
2069 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2070 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2071 }))
2072 }
2073
2074 async fn send_contact_request(
2075 &self,
2076 requester_id: UserId,
2077 responder_id: UserId,
2078 ) -> Result<()> {
2079 self.background.simulate_random_delay().await;
2080 let mut contacts = self.contacts.lock();
2081 for contact in contacts.iter_mut() {
2082 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2083 if contact.accepted {
2084 Err(anyhow!("contact already exists"))?;
2085 } else {
2086 Err(anyhow!("contact already requested"))?;
2087 }
2088 }
2089 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2090 if contact.accepted {
2091 Err(anyhow!("contact already exists"))?;
2092 } else {
2093 contact.accepted = true;
2094 contact.should_notify = false;
2095 return Ok(());
2096 }
2097 }
2098 }
2099 contacts.push(FakeContact {
2100 requester_id,
2101 responder_id,
2102 accepted: false,
2103 should_notify: true,
2104 });
2105 Ok(())
2106 }
2107
2108 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2109 self.background.simulate_random_delay().await;
2110 self.contacts.lock().retain(|contact| {
2111 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2112 });
2113 Ok(())
2114 }
2115
2116 async fn dismiss_contact_notification(
2117 &self,
2118 user_id: UserId,
2119 contact_user_id: UserId,
2120 ) -> Result<()> {
2121 self.background.simulate_random_delay().await;
2122 let mut contacts = self.contacts.lock();
2123 for contact in contacts.iter_mut() {
2124 if contact.requester_id == contact_user_id
2125 && contact.responder_id == user_id
2126 && !contact.accepted
2127 {
2128 contact.should_notify = false;
2129 return Ok(());
2130 }
2131 if contact.requester_id == user_id
2132 && contact.responder_id == contact_user_id
2133 && contact.accepted
2134 {
2135 contact.should_notify = false;
2136 return Ok(());
2137 }
2138 }
2139 Err(anyhow!("no such notification"))?
2140 }
2141
2142 async fn respond_to_contact_request(
2143 &self,
2144 responder_id: UserId,
2145 requester_id: UserId,
2146 accept: bool,
2147 ) -> Result<()> {
2148 self.background.simulate_random_delay().await;
2149 let mut contacts = self.contacts.lock();
2150 for (ix, contact) in contacts.iter_mut().enumerate() {
2151 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2152 if contact.accepted {
2153 Err(anyhow!("contact already confirmed"))?;
2154 }
2155 if accept {
2156 contact.accepted = true;
2157 contact.should_notify = true;
2158 } else {
2159 contacts.remove(ix);
2160 }
2161 return Ok(());
2162 }
2163 }
2164 Err(anyhow!("no such contact request"))?
2165 }
2166
2167 async fn create_access_token_hash(
2168 &self,
2169 _user_id: UserId,
2170 _access_token_hash: &str,
2171 _max_access_token_count: usize,
2172 ) -> Result<()> {
2173 unimplemented!()
2174 }
2175
2176 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
2177 unimplemented!()
2178 }
2179
2180 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2181 unimplemented!()
2182 }
2183
2184 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2185 self.background.simulate_random_delay().await;
2186 let mut orgs = self.orgs.lock();
2187 if orgs.values().any(|org| org.slug == slug) {
2188 Err(anyhow!("org already exists"))?
2189 } else {
2190 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2191 orgs.insert(
2192 org_id,
2193 Org {
2194 id: org_id,
2195 name: name.to_string(),
2196 slug: slug.to_string(),
2197 },
2198 );
2199 Ok(org_id)
2200 }
2201 }
2202
2203 async fn add_org_member(
2204 &self,
2205 org_id: OrgId,
2206 user_id: UserId,
2207 is_admin: bool,
2208 ) -> Result<()> {
2209 self.background.simulate_random_delay().await;
2210 if !self.orgs.lock().contains_key(&org_id) {
2211 Err(anyhow!("org does not exist"))?;
2212 }
2213 if !self.users.lock().contains_key(&user_id) {
2214 Err(anyhow!("user does not exist"))?;
2215 }
2216
2217 self.org_memberships
2218 .lock()
2219 .entry((org_id, user_id))
2220 .or_insert(is_admin);
2221 Ok(())
2222 }
2223
2224 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2225 self.background.simulate_random_delay().await;
2226 if !self.orgs.lock().contains_key(&org_id) {
2227 Err(anyhow!("org does not exist"))?;
2228 }
2229
2230 let mut channels = self.channels.lock();
2231 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2232 channels.insert(
2233 channel_id,
2234 Channel {
2235 id: channel_id,
2236 name: name.to_string(),
2237 owner_id: org_id.0,
2238 owner_is_user: false,
2239 },
2240 );
2241 Ok(channel_id)
2242 }
2243
2244 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2245 self.background.simulate_random_delay().await;
2246 Ok(self
2247 .channels
2248 .lock()
2249 .values()
2250 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2251 .cloned()
2252 .collect())
2253 }
2254
2255 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2256 self.background.simulate_random_delay().await;
2257 let channels = self.channels.lock();
2258 let memberships = self.channel_memberships.lock();
2259 Ok(channels
2260 .values()
2261 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2262 .cloned()
2263 .collect())
2264 }
2265
2266 async fn can_user_access_channel(
2267 &self,
2268 user_id: UserId,
2269 channel_id: ChannelId,
2270 ) -> Result<bool> {
2271 self.background.simulate_random_delay().await;
2272 Ok(self
2273 .channel_memberships
2274 .lock()
2275 .contains_key(&(channel_id, user_id)))
2276 }
2277
2278 async fn add_channel_member(
2279 &self,
2280 channel_id: ChannelId,
2281 user_id: UserId,
2282 is_admin: bool,
2283 ) -> Result<()> {
2284 self.background.simulate_random_delay().await;
2285 if !self.channels.lock().contains_key(&channel_id) {
2286 Err(anyhow!("channel does not exist"))?;
2287 }
2288 if !self.users.lock().contains_key(&user_id) {
2289 Err(anyhow!("user does not exist"))?;
2290 }
2291
2292 self.channel_memberships
2293 .lock()
2294 .entry((channel_id, user_id))
2295 .or_insert(is_admin);
2296 Ok(())
2297 }
2298
2299 async fn create_channel_message(
2300 &self,
2301 channel_id: ChannelId,
2302 sender_id: UserId,
2303 body: &str,
2304 timestamp: OffsetDateTime,
2305 nonce: u128,
2306 ) -> Result<MessageId> {
2307 self.background.simulate_random_delay().await;
2308 if !self.channels.lock().contains_key(&channel_id) {
2309 Err(anyhow!("channel does not exist"))?;
2310 }
2311 if !self.users.lock().contains_key(&sender_id) {
2312 Err(anyhow!("user does not exist"))?;
2313 }
2314
2315 let mut messages = self.channel_messages.lock();
2316 if let Some(message) = messages
2317 .values()
2318 .find(|message| message.nonce.as_u128() == nonce)
2319 {
2320 Ok(message.id)
2321 } else {
2322 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2323 messages.insert(
2324 message_id,
2325 ChannelMessage {
2326 id: message_id,
2327 channel_id,
2328 sender_id,
2329 body: body.to_string(),
2330 sent_at: timestamp,
2331 nonce: Uuid::from_u128(nonce),
2332 },
2333 );
2334 Ok(message_id)
2335 }
2336 }
2337
2338 async fn get_channel_messages(
2339 &self,
2340 channel_id: ChannelId,
2341 count: usize,
2342 before_id: Option<MessageId>,
2343 ) -> Result<Vec<ChannelMessage>> {
2344 self.background.simulate_random_delay().await;
2345 let mut messages = self
2346 .channel_messages
2347 .lock()
2348 .values()
2349 .rev()
2350 .filter(|message| {
2351 message.channel_id == channel_id
2352 && message.id < before_id.unwrap_or(MessageId::MAX)
2353 })
2354 .take(count)
2355 .cloned()
2356 .collect::<Vec<_>>();
2357 messages.sort_unstable_by_key(|message| message.id);
2358 Ok(messages)
2359 }
2360
2361 async fn teardown(&self, _: &str) {}
2362
2363 #[cfg(test)]
2364 fn as_fake(&self) -> Option<&FakeDb> {
2365 Some(self)
2366 }
2367 }
2368
2369 pub struct TestDb {
2370 pub db: Option<Arc<dyn Db>>,
2371 pub url: String,
2372 }
2373
2374 impl TestDb {
2375 #[allow(clippy::await_holding_lock)]
2376 pub async fn postgres() -> Self {
2377 lazy_static! {
2378 static ref LOCK: Mutex<()> = Mutex::new(());
2379 }
2380
2381 let _guard = LOCK.lock();
2382 let mut rng = StdRng::from_entropy();
2383 let name = format!("zed-test-{}", rng.gen::<u128>());
2384 let url = format!("postgres://postgres@localhost/{}", name);
2385 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2386 Postgres::create_database(&url)
2387 .await
2388 .expect("failed to create test db");
2389 let db = PostgresDb::new(&url, 5).await.unwrap();
2390 let migrator = Migrator::new(migrations_path).await.unwrap();
2391 migrator.run(&db.pool).await.unwrap();
2392 Self {
2393 db: Some(Arc::new(db)),
2394 url,
2395 }
2396 }
2397
2398 pub fn fake(background: Arc<Background>) -> Self {
2399 Self {
2400 db: Some(Arc::new(FakeDb::new(background))),
2401 url: Default::default(),
2402 }
2403 }
2404
2405 pub fn db(&self) -> &Arc<dyn Db> {
2406 self.db.as_ref().unwrap()
2407 }
2408 }
2409
2410 impl Drop for TestDb {
2411 fn drop(&mut self) {
2412 if let Some(db) = self.db.take() {
2413 futures::executor::block_on(db.teardown(&self.url));
2414 }
2415 }
2416 }
2417}