1use crate::{Error, Result};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use axum::http::StatusCode;
5use collections::HashMap;
6use futures::StreamExt;
7use serde::{Deserialize, Serialize};
8pub use sqlx::postgres::PgPoolOptions as DbOptions;
9use sqlx::{types::Uuid, FromRow, QueryBuilder, Row};
10use std::{cmp, ops::Range, time::Duration};
11use time::{OffsetDateTime, PrimitiveDateTime};
12
/// Storage interface for the collaboration server: users, invites and signups,
/// projects and activity metrics, contacts, access tokens, orgs and channels.
///
/// Behavioral notes on individual methods describe the `PostgresDb` implementation.
#[async_trait]
pub trait Db: Send + Sync {
    /// Inserts a user and returns its id. If a user with `github_login` already exists,
    /// `PostgresDb` returns that user's id without updating its email or admin flag.
    async fn create_user(
        &self,
        github_login: &str,
        email_address: &str,
        admin: bool,
    ) -> Result<UserId>;
    /// Returns one page of users ordered by GitHub login. `page` is zero-based and
    /// `limit` is the page size.
    async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>>;
    /// Upserts users given as `(github_login, email_address, invite_count)` tuples and
    /// returns their ids.
    async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>>;
    /// Returns up to `limit` users whose GitHub login fuzzily matches `query`.
    async fn fuzzy_search_users(&self, query: &str, limit: u32) -> Result<Vec<User>>;
    async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
    async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
    /// Returns users with no remaining invites. `invited_by_another_user` selects users
    /// that do (`true`) or do not (`false`) have an inviter.
    async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>>;
    async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
    async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
    async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()>;
    /// Deletes the user and their access tokens.
    async fn destroy_user(&self, id: UserId) -> Result<()>;

    /// Sets the user's remaining invite count. A positive count also assigns an invite
    /// code if the user does not have one yet.
    async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()>;
    /// Returns the user's invite code and remaining invite count, or `None` if the user
    /// has no invite code.
    async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>>;
    /// Returns the user who owns `code`, or an HTTP 404 error if there is none.
    async fn get_user_for_invite_code(&self, code: &str) -> Result<User>;
    /// Records a signup for `email_address` invited via `code` and returns the invite.
    async fn create_invite_from_code(&self, code: &str, email_address: &str) -> Result<Invite>;

    async fn create_signup(&self, signup: Signup) -> Result<()>;
    /// Summarizes signups whose confirmation email has not been sent, by platform.
    async fn get_waitlist_summary(&self) -> Result<WaitlistSummary>;
    /// Returns up to `count` Mac-platform signups whose confirmation email is unsent.
    async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>>;
    /// Marks the signups for the given invites as having been emailed.
    async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()>;
    /// Creates a user from a confirmed invite. Returns the new user's id and the
    /// inviting user's id, if any.
    async fn create_user_from_invite(
        &self,
        invite: &Invite,
        user: NewUserParams,
    ) -> Result<(UserId, Option<UserId>)>;

    /// Registers a new project for the given user.
    async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId>;

    /// Unregisters a project for the given project id.
    async fn unregister_project(&self, project_id: ProjectId) -> Result<()>;

    /// Update file counts by extension for the given project and worktree.
    async fn update_worktree_extensions(
        &self,
        project_id: ProjectId,
        worktree_id: u64,
        extensions: HashMap<String, u32>,
    ) -> Result<()>;

    /// Get the file counts on the given project keyed by their worktree and extension.
    async fn get_project_extensions(
        &self,
        project_id: ProjectId,
    ) -> Result<HashMap<u64, HashMap<String, usize>>>;

    /// Record which users have been active in which projects during
    /// a given period of time.
    async fn record_user_activity(
        &self,
        time_period: Range<OffsetDateTime>,
        active_projects: &[(UserId, ProjectId)],
    ) -> Result<()>;

    /// Get the number of users who have been active in the given
    /// time period for at least the given time duration.
    async fn get_active_user_count(
        &self,
        time_period: Range<OffsetDateTime>,
        min_duration: Duration,
        only_collaborative: bool,
    ) -> Result<usize>;

    /// Get the users that have been most active during the given time period,
    /// along with the amount of time they have been active in each project.
    async fn get_top_users_activity_summary(
        &self,
        time_period: Range<OffsetDateTime>,
        max_user_count: usize,
    ) -> Result<Vec<UserActivitySummary>>;

    /// Get the project activity for the given user and time period.
    async fn get_user_activity_timeline(
        &self,
        time_period: Range<OffsetDateTime>,
        user_id: UserId,
    ) -> Result<Vec<UserActivityPeriod>>;

    /// Returns the user's contacts, including the user themself as an accepted contact.
    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
    /// Returns whether the two users have an accepted contact relationship.
    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
    async fn dismiss_contact_notification(
        &self,
        responder_id: UserId,
        requester_id: UserId,
    ) -> Result<()>;
    async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()>;

    /// Stores an access token hash, keeping at most `max_access_token_count` of the
    /// user's most recent tokens.
    async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()>;
    async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;

    #[cfg(any(test, feature = "seed-support"))]
    async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
    #[cfg(any(test, feature = "seed-support"))]
    async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
    #[cfg(any(test, feature = "seed-support"))]
    // NOTE(review): despite the blank line below, the `cfg` above attaches to
    // `get_org_channels`, so it only exists in test / seed-support builds — confirm intent.

    async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
    async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
    async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
        -> Result<bool>;

    #[cfg(any(test, feature = "seed-support"))]
    async fn add_channel_member(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
        is_admin: bool,
    ) -> Result<()>;
    async fn create_channel_message(
        &self,
        channel_id: ChannelId,
        sender_id: UserId,
        body: &str,
        timestamp: OffsetDateTime,
        nonce: u128,
    ) -> Result<MessageId>;
    async fn get_channel_messages(
        &self,
        channel_id: ChannelId,
        count: usize,
        before_id: Option<MessageId>,
    ) -> Result<Vec<ChannelMessage>>;

    #[cfg(test)]
    async fn teardown(&self, url: &str);

    #[cfg(test)]
    fn as_fake(&self) -> Option<&FakeDb>;
}
166
167pub struct PostgresDb {
168 pool: sqlx::PgPool,
169}
170
171impl PostgresDb {
172 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
173 let pool = DbOptions::new()
174 .max_connections(max_connections)
175 .connect(url)
176 .await
177 .context("failed to connect to postgres database")?;
178 Ok(Self { pool })
179 }
180
181 pub fn fuzzy_like_string(string: &str) -> String {
182 let mut result = String::with_capacity(string.len() * 2 + 1);
183 for c in string.chars() {
184 if c.is_alphanumeric() {
185 result.push('%');
186 result.push(c);
187 }
188 }
189 result.push('%');
190 result
191 }
192}
193
194#[async_trait]
195impl Db for PostgresDb {
196 // users
197
198 async fn create_user(
199 &self,
200 github_login: &str,
201 email_address: &str,
202 admin: bool,
203 ) -> Result<UserId> {
204 let query = "
205 INSERT INTO users (github_login, email_address, admin)
206 VALUES ($1, $2, $3)
207 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
208 RETURNING id
209 ";
210 Ok(sqlx::query_scalar(query)
211 .bind(github_login)
212 .bind(email_address)
213 .bind(admin)
214 .fetch_one(&self.pool)
215 .await
216 .map(UserId)?)
217 }
218
219 async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
220 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
221 Ok(sqlx::query_as(query)
222 .bind(limit as i32)
223 .bind((page * limit) as i32)
224 .fetch_all(&self.pool)
225 .await?)
226 }
227
228 async fn create_users(&self, users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
229 let mut query = QueryBuilder::new(
230 "INSERT INTO users (github_login, email_address, admin, invite_code, invite_count)",
231 );
232 query.push_values(
233 users,
234 |mut query, (github_login, email_address, invite_count)| {
235 query
236 .push_bind(github_login)
237 .push_bind(email_address)
238 .push_bind(false)
239 .push_bind(random_invite_code())
240 .push_bind(invite_count as i32);
241 },
242 );
243 query.push(
244 "
245 ON CONFLICT (github_login) DO UPDATE SET
246 github_login = excluded.github_login,
247 invite_count = excluded.invite_count,
248 invite_code = CASE WHEN users.invite_code IS NULL
249 THEN excluded.invite_code
250 ELSE users.invite_code
251 END
252 RETURNING id
253 ",
254 );
255
256 let rows = query.build().fetch_all(&self.pool).await?;
257 Ok(rows
258 .into_iter()
259 .filter_map(|row| row.try_get::<UserId, _>(0).ok())
260 .collect())
261 }
262
263 async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
264 let like_string = Self::fuzzy_like_string(name_query);
265 let query = "
266 SELECT users.*
267 FROM users
268 WHERE github_login ILIKE $1
269 ORDER BY github_login <-> $2
270 LIMIT $3
271 ";
272 Ok(sqlx::query_as(query)
273 .bind(like_string)
274 .bind(name_query)
275 .bind(limit as i32)
276 .fetch_all(&self.pool)
277 .await?)
278 }
279
280 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
281 let users = self.get_users_by_ids(vec![id]).await?;
282 Ok(users.into_iter().next())
283 }
284
285 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
286 let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
287 let query = "
288 SELECT users.*
289 FROM users
290 WHERE users.id = ANY ($1)
291 ";
292 Ok(sqlx::query_as(query)
293 .bind(&ids)
294 .fetch_all(&self.pool)
295 .await?)
296 }
297
298 async fn get_users_with_no_invites(&self, invited_by_another_user: bool) -> Result<Vec<User>> {
299 let query = format!(
300 "
301 SELECT users.*
302 FROM users
303 WHERE invite_count = 0
304 AND inviter_id IS{} NULL
305 ",
306 if invited_by_another_user { " NOT" } else { "" }
307 );
308
309 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
310 }
311
312 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
313 let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
314 Ok(sqlx::query_as(query)
315 .bind(github_login)
316 .fetch_optional(&self.pool)
317 .await?)
318 }
319
320 async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
321 let query = "UPDATE users SET admin = $1 WHERE id = $2";
322 Ok(sqlx::query(query)
323 .bind(is_admin)
324 .bind(id.0)
325 .execute(&self.pool)
326 .await
327 .map(drop)?)
328 }
329
330 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
331 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
332 Ok(sqlx::query(query)
333 .bind(connected_once)
334 .bind(id.0)
335 .execute(&self.pool)
336 .await
337 .map(drop)?)
338 }
339
340 async fn destroy_user(&self, id: UserId) -> Result<()> {
341 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
342 sqlx::query(query)
343 .bind(id.0)
344 .execute(&self.pool)
345 .await
346 .map(drop)?;
347 let query = "DELETE FROM users WHERE id = $1;";
348 Ok(sqlx::query(query)
349 .bind(id.0)
350 .execute(&self.pool)
351 .await
352 .map(drop)?)
353 }
354
355 // signups
356
357 async fn create_signup(&self, signup: Signup) -> Result<()> {
358 sqlx::query(
359 "
360 INSERT INTO signups
361 (
362 email_address,
363 email_confirmation_code,
364 email_confirmation_sent,
365 platform_linux,
366 platform_mac,
367 platform_windows,
368 platform_unknown,
369 editor_features,
370 programming_languages
371 )
372 VALUES
373 ($1, $2, 'f', $3, $4, $5, 'f', $6, $7)
374 ",
375 )
376 .bind(&signup.email_address)
377 .bind(&random_email_confirmation_code())
378 .bind(&signup.platform_linux)
379 .bind(&signup.platform_mac)
380 .bind(&signup.platform_windows)
381 .bind(&signup.editor_features)
382 .bind(&signup.programming_languages)
383 .execute(&self.pool)
384 .await?;
385 Ok(())
386 }
387
388 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
389 Ok(sqlx::query_as(
390 "
391 SELECT
392 COUNT(*) as count,
393 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
394 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
395 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count
396 FROM (
397 SELECT *
398 FROM signups
399 WHERE
400 NOT email_confirmation_sent
401 ) AS unsent
402 ",
403 )
404 .fetch_one(&self.pool)
405 .await?)
406 }
407
408 async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
409 Ok(sqlx::query_as(
410 "
411 SELECT
412 email_address, email_confirmation_code
413 FROM signups
414 WHERE
415 NOT email_confirmation_sent AND
416 platform_mac
417 LIMIT $1
418 ",
419 )
420 .bind(count as i32)
421 .fetch_all(&self.pool)
422 .await?)
423 }
424
425 async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
426 sqlx::query(
427 "
428 UPDATE signups
429 SET email_confirmation_sent = 't'
430 WHERE email_address = ANY ($1)
431 ",
432 )
433 .bind(
434 &invites
435 .iter()
436 .map(|s| s.email_address.as_str())
437 .collect::<Vec<_>>(),
438 )
439 .execute(&self.pool)
440 .await?;
441 Ok(())
442 }
443
    /// Creates a user from a confirmed invite, all inside one transaction.
    ///
    /// Steps:
    /// 1. Find the unclaimed signup matching the invite's email and confirmation code
    ///    (errors with "no such invite" otherwise).
    /// 2. Insert the user, reusing the signup's `metrics_id`.
    /// 3. Link the signup to the new user.
    /// 4. If another user sent the invite, consume one of their invites (HTTP 401
    ///    "no invites remaining" if they have none left) and make the two users accepted
    ///    contacts.
    ///
    /// Returns the new user's id and the inviting user's id, if any. Any error drops
    /// `tx` before commit, so none of the steps are persisted.
    async fn create_user_from_invite(
        &self,
        invite: &Invite,
        user: NewUserParams,
    ) -> Result<(UserId, Option<UserId>)> {
        let mut tx = self.pool.begin().await?;

        // `user_id is NULL` ensures a signup can only be turned into a user once.
        let (signup_id, metrics_id, inviting_user_id): (i32, i32, Option<UserId>) = sqlx::query_as(
            "
            SELECT id, metrics_id, inviting_user_id
            FROM signups
            WHERE
                email_address = $1 AND
                email_confirmation_code = $2 AND
                user_id is NULL
            ",
        )
        .bind(&invite.email_address)
        .bind(&invite.email_confirmation_code)
        .fetch_optional(&mut tx)
        .await?
        .ok_or_else(|| anyhow!("no such invite"))?;

        let user_id: UserId = sqlx::query_scalar(
            "
            INSERT INTO users
            (email_address, github_login, admin, invite_count, invite_code, metrics_id)
            VALUES
            ($1, $2, 'f', $3, $4, $5)
            RETURNING id
            ",
        )
        .bind(&invite.email_address)
        .bind(&user.github_login)
        .bind(&user.invite_count)
        .bind(random_invite_code())
        .bind(metrics_id)
        .fetch_one(&mut tx)
        .await?;

        sqlx::query(
            "
            UPDATE signups
            SET user_id = $1
            WHERE id = $2
            ",
        )
        .bind(&user_id)
        .bind(&signup_id)
        .execute(&mut tx)
        .await?;

        if let Some(inviting_user_id) = inviting_user_id {
            // Decrement only while the count is positive; no returned row means the
            // inviter has no invites left.
            let id: Option<UserId> = sqlx::query_scalar(
                "
                UPDATE users
                SET invite_count = invite_count - 1
                WHERE id = $1 AND invite_count > 0
                RETURNING id
                ",
            )
            .bind(&inviting_user_id)
            .fetch_optional(&mut tx)
            .await?;

            if id.is_none() {
                Err(Error::Http(
                    StatusCode::UNAUTHORIZED,
                    "no invites remaining".to_string(),
                ))?;
            }

            // The inviter is stored as `user_id_a` with `a_to_b` set.
            // NOTE(review): other contact queries normalize `user_id_a < user_id_b`; this
            // relies on the new user's id exceeding the inviter's — confirm ids are
            // monotonically assigned.
            sqlx::query(
                "
                INSERT INTO contacts
                    (user_id_a, user_id_b, a_to_b, should_notify, accepted)
                VALUES
                    ($1, $2, 't', 't', 't')
                ",
            )
            .bind(inviting_user_id)
            .bind(user_id)
            .execute(&mut tx)
            .await?;
        }

        tx.commit().await?;
        Ok((user_id, inviting_user_id))
    }
533
534 // invite codes
535
536 async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
537 let mut tx = self.pool.begin().await?;
538 if count > 0 {
539 sqlx::query(
540 "
541 UPDATE users
542 SET invite_code = $1
543 WHERE id = $2 AND invite_code IS NULL
544 ",
545 )
546 .bind(random_invite_code())
547 .bind(id)
548 .execute(&mut tx)
549 .await?;
550 }
551
552 sqlx::query(
553 "
554 UPDATE users
555 SET invite_count = $1
556 WHERE id = $2
557 ",
558 )
559 .bind(count as i32)
560 .bind(id)
561 .execute(&mut tx)
562 .await?;
563 tx.commit().await?;
564 Ok(())
565 }
566
567 async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
568 let result: Option<(String, i32)> = sqlx::query_as(
569 "
570 SELECT invite_code, invite_count
571 FROM users
572 WHERE id = $1 AND invite_code IS NOT NULL
573 ",
574 )
575 .bind(id)
576 .fetch_optional(&self.pool)
577 .await?;
578 if let Some((code, count)) = result {
579 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
580 } else {
581 Ok(None)
582 }
583 }
584
585 async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
586 sqlx::query_as(
587 "
588 SELECT *
589 FROM users
590 WHERE invite_code = $1
591 ",
592 )
593 .bind(code)
594 .fetch_optional(&self.pool)
595 .await?
596 .ok_or_else(|| {
597 Error::Http(
598 StatusCode::NOT_FOUND,
599 "that invite code does not exist".to_string(),
600 )
601 })
602 }
603
604 async fn create_invite_from_code(&self, code: &str, email_address: &str) -> Result<Invite> {
605 let mut tx = self.pool.begin().await?;
606
607 let existing_user: Option<UserId> = sqlx::query_scalar(
608 "
609 SELECT id
610 FROM users
611 WHERE email_address = $1
612 ",
613 )
614 .bind(email_address)
615 .fetch_optional(&mut tx)
616 .await?;
617 if existing_user.is_some() {
618 Err(anyhow!("email address is already in use"))?;
619 }
620
621 let row: Option<(UserId, i32)> = sqlx::query_as(
622 "
623 SELECT id, invite_count
624 FROM users
625 WHERE invite_code = $1
626 ",
627 )
628 .bind(code)
629 .fetch_optional(&mut tx)
630 .await?;
631
632 let (inviter_id, invite_count) = match row {
633 Some(row) => row,
634 None => Err(Error::Http(
635 StatusCode::NOT_FOUND,
636 "invite code not found".to_string(),
637 ))?,
638 };
639
640 if invite_count == 0 {
641 Err(Error::Http(
642 StatusCode::UNAUTHORIZED,
643 "no invites remaining".to_string(),
644 ))?;
645 }
646
647 let email_confirmation_code: String = sqlx::query_scalar(
648 "
649 INSERT INTO signups
650 (
651 email_address,
652 email_confirmation_code,
653 email_confirmation_sent,
654 inviting_user_id,
655 platform_linux,
656 platform_mac,
657 platform_windows,
658 platform_unknown
659 )
660 VALUES
661 ($1, $2, 'f', $3, 'f', 'f', 'f', 't')
662 ON CONFLICT (email_address)
663 DO UPDATE SET
664 inviting_user_id = excluded.inviting_user_id
665 RETURNING email_confirmation_code
666 ",
667 )
668 .bind(&email_address)
669 .bind(&random_email_confirmation_code())
670 .bind(&inviter_id)
671 .fetch_one(&mut tx)
672 .await?;
673
674 tx.commit().await?;
675
676 Ok(Invite {
677 email_address: email_address.into(),
678 email_confirmation_code,
679 })
680 }
681
682 // projects
683
684 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
685 Ok(sqlx::query_scalar(
686 "
687 INSERT INTO projects(host_user_id)
688 VALUES ($1)
689 RETURNING id
690 ",
691 )
692 .bind(host_user_id)
693 .fetch_one(&self.pool)
694 .await
695 .map(ProjectId)?)
696 }
697
698 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
699 sqlx::query(
700 "
701 UPDATE projects
702 SET unregistered = 't'
703 WHERE id = $1
704 ",
705 )
706 .bind(project_id)
707 .execute(&self.pool)
708 .await?;
709 Ok(())
710 }
711
712 async fn update_worktree_extensions(
713 &self,
714 project_id: ProjectId,
715 worktree_id: u64,
716 extensions: HashMap<String, u32>,
717 ) -> Result<()> {
718 if extensions.is_empty() {
719 return Ok(());
720 }
721
722 let mut query = QueryBuilder::new(
723 "INSERT INTO worktree_extensions (project_id, worktree_id, extension, count)",
724 );
725 query.push_values(extensions, |mut query, (extension, count)| {
726 query
727 .push_bind(project_id)
728 .push_bind(worktree_id as i32)
729 .push_bind(extension)
730 .push_bind(count as i32);
731 });
732 query.push(
733 "
734 ON CONFLICT (project_id, worktree_id, extension) DO UPDATE SET
735 count = excluded.count
736 ",
737 );
738 query.build().execute(&self.pool).await?;
739
740 Ok(())
741 }
742
743 async fn get_project_extensions(
744 &self,
745 project_id: ProjectId,
746 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
747 #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
748 struct WorktreeExtension {
749 worktree_id: i32,
750 extension: String,
751 count: i32,
752 }
753
754 let query = "
755 SELECT worktree_id, extension, count
756 FROM worktree_extensions
757 WHERE project_id = $1
758 ";
759 let counts = sqlx::query_as::<_, WorktreeExtension>(query)
760 .bind(&project_id)
761 .fetch_all(&self.pool)
762 .await?;
763
764 let mut extension_counts = HashMap::default();
765 for count in counts {
766 extension_counts
767 .entry(count.worktree_id as u64)
768 .or_insert_with(HashMap::default)
769 .insert(count.extension, count.count as usize);
770 }
771 Ok(extension_counts)
772 }
773
774 async fn record_user_activity(
775 &self,
776 time_period: Range<OffsetDateTime>,
777 projects: &[(UserId, ProjectId)],
778 ) -> Result<()> {
779 let query = "
780 INSERT INTO project_activity_periods
781 (ended_at, duration_millis, user_id, project_id)
782 VALUES
783 ($1, $2, $3, $4);
784 ";
785
786 let mut tx = self.pool.begin().await?;
787 let duration_millis =
788 ((time_period.end - time_period.start).as_seconds_f64() * 1000.0) as i32;
789 for (user_id, project_id) in projects {
790 sqlx::query(query)
791 .bind(time_period.end)
792 .bind(duration_millis)
793 .bind(user_id)
794 .bind(project_id)
795 .execute(&mut tx)
796 .await?;
797 }
798 tx.commit().await?;
799
800 Ok(())
801 }
802
803 async fn get_active_user_count(
804 &self,
805 time_period: Range<OffsetDateTime>,
806 min_duration: Duration,
807 only_collaborative: bool,
808 ) -> Result<usize> {
809 let mut with_clause = String::new();
810 with_clause.push_str("WITH\n");
811 with_clause.push_str(
812 "
813 project_durations AS (
814 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
815 FROM project_activity_periods
816 WHERE $1 < ended_at AND ended_at <= $2
817 GROUP BY user_id, project_id
818 ),
819 ",
820 );
821 with_clause.push_str(
822 "
823 project_collaborators as (
824 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
825 FROM project_durations
826 GROUP BY project_id
827 ),
828 ",
829 );
830
831 if only_collaborative {
832 with_clause.push_str(
833 "
834 user_durations AS (
835 SELECT user_id, SUM(project_duration) as total_duration
836 FROM project_durations, project_collaborators
837 WHERE
838 project_durations.project_id = project_collaborators.project_id AND
839 max_collaborators > 1
840 GROUP BY user_id
841 ORDER BY total_duration DESC
842 LIMIT $3
843 )
844 ",
845 );
846 } else {
847 with_clause.push_str(
848 "
849 user_durations AS (
850 SELECT user_id, SUM(project_duration) as total_duration
851 FROM project_durations
852 GROUP BY user_id
853 ORDER BY total_duration DESC
854 LIMIT $3
855 )
856 ",
857 );
858 }
859
860 let query = format!(
861 "
862 {with_clause}
863 SELECT count(user_durations.user_id)
864 FROM user_durations
865 WHERE user_durations.total_duration >= $3
866 "
867 );
868
869 let count: i64 = sqlx::query_scalar(&query)
870 .bind(time_period.start)
871 .bind(time_period.end)
872 .bind(min_duration.as_millis() as i64)
873 .fetch_one(&self.pool)
874 .await?;
875 Ok(count as usize)
876 }
877
878 async fn get_top_users_activity_summary(
879 &self,
880 time_period: Range<OffsetDateTime>,
881 max_user_count: usize,
882 ) -> Result<Vec<UserActivitySummary>> {
883 let query = "
884 WITH
885 project_durations AS (
886 SELECT user_id, project_id, SUM(duration_millis) AS project_duration
887 FROM project_activity_periods
888 WHERE $1 < ended_at AND ended_at <= $2
889 GROUP BY user_id, project_id
890 ),
891 user_durations AS (
892 SELECT user_id, SUM(project_duration) as total_duration
893 FROM project_durations
894 GROUP BY user_id
895 ORDER BY total_duration DESC
896 LIMIT $3
897 ),
898 project_collaborators as (
899 SELECT project_id, COUNT(DISTINCT user_id) as max_collaborators
900 FROM project_durations
901 GROUP BY project_id
902 )
903 SELECT user_durations.user_id, users.github_login, project_durations.project_id, project_duration, max_collaborators
904 FROM user_durations, project_durations, project_collaborators, users
905 WHERE
906 user_durations.user_id = project_durations.user_id AND
907 user_durations.user_id = users.id AND
908 project_durations.project_id = project_collaborators.project_id
909 ORDER BY total_duration DESC, user_id ASC, project_id ASC
910 ";
911
912 let mut rows = sqlx::query_as::<_, (UserId, String, ProjectId, i64, i64)>(query)
913 .bind(time_period.start)
914 .bind(time_period.end)
915 .bind(max_user_count as i32)
916 .fetch(&self.pool);
917
918 let mut result = Vec::<UserActivitySummary>::new();
919 while let Some(row) = rows.next().await {
920 let (user_id, github_login, project_id, duration_millis, project_collaborators) = row?;
921 let project_id = project_id;
922 let duration = Duration::from_millis(duration_millis as u64);
923 let project_activity = ProjectActivitySummary {
924 id: project_id,
925 duration,
926 max_collaborators: project_collaborators as usize,
927 };
928 if let Some(last_summary) = result.last_mut() {
929 if last_summary.id == user_id {
930 last_summary.project_activity.push(project_activity);
931 continue;
932 }
933 }
934 result.push(UserActivitySummary {
935 id: user_id,
936 project_activity: vec![project_activity],
937 github_login,
938 });
939 }
940
941 Ok(result)
942 }
943
944 async fn get_user_activity_timeline(
945 &self,
946 time_period: Range<OffsetDateTime>,
947 user_id: UserId,
948 ) -> Result<Vec<UserActivityPeriod>> {
949 const COALESCE_THRESHOLD: Duration = Duration::from_secs(30);
950
951 let query = "
952 SELECT
953 project_activity_periods.ended_at,
954 project_activity_periods.duration_millis,
955 project_activity_periods.project_id,
956 worktree_extensions.extension,
957 worktree_extensions.count
958 FROM project_activity_periods
959 LEFT OUTER JOIN
960 worktree_extensions
961 ON
962 project_activity_periods.project_id = worktree_extensions.project_id
963 WHERE
964 project_activity_periods.user_id = $1 AND
965 $2 < project_activity_periods.ended_at AND
966 project_activity_periods.ended_at <= $3
967 ORDER BY project_activity_periods.id ASC
968 ";
969
970 let mut rows = sqlx::query_as::<
971 _,
972 (
973 PrimitiveDateTime,
974 i32,
975 ProjectId,
976 Option<String>,
977 Option<i32>,
978 ),
979 >(query)
980 .bind(user_id)
981 .bind(time_period.start)
982 .bind(time_period.end)
983 .fetch(&self.pool);
984
985 let mut time_periods: HashMap<ProjectId, Vec<UserActivityPeriod>> = Default::default();
986 while let Some(row) = rows.next().await {
987 let (ended_at, duration_millis, project_id, extension, extension_count) = row?;
988 let ended_at = ended_at.assume_utc();
989 let duration = Duration::from_millis(duration_millis as u64);
990 let started_at = ended_at - duration;
991 let project_time_periods = time_periods.entry(project_id).or_default();
992
993 if let Some(prev_duration) = project_time_periods.last_mut() {
994 if started_at <= prev_duration.end + COALESCE_THRESHOLD
995 && ended_at >= prev_duration.start
996 {
997 prev_duration.end = cmp::max(prev_duration.end, ended_at);
998 } else {
999 project_time_periods.push(UserActivityPeriod {
1000 project_id,
1001 start: started_at,
1002 end: ended_at,
1003 extensions: Default::default(),
1004 });
1005 }
1006 } else {
1007 project_time_periods.push(UserActivityPeriod {
1008 project_id,
1009 start: started_at,
1010 end: ended_at,
1011 extensions: Default::default(),
1012 });
1013 }
1014
1015 if let Some((extension, extension_count)) = extension.zip(extension_count) {
1016 project_time_periods
1017 .last_mut()
1018 .unwrap()
1019 .extensions
1020 .insert(extension, extension_count as usize);
1021 }
1022 }
1023
1024 let mut durations = time_periods.into_values().flatten().collect::<Vec<_>>();
1025 durations.sort_unstable_by_key(|duration| duration.start);
1026 Ok(durations)
1027 }
1028
1029 // contacts
1030
1031 async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1032 let query = "
1033 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1034 FROM contacts
1035 WHERE user_id_a = $1 OR user_id_b = $1;
1036 ";
1037
1038 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1039 .bind(user_id)
1040 .fetch(&self.pool);
1041
1042 let mut contacts = vec![Contact::Accepted {
1043 user_id,
1044 should_notify: false,
1045 }];
1046 while let Some(row) = rows.next().await {
1047 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1048
1049 if user_id_a == user_id {
1050 if accepted {
1051 contacts.push(Contact::Accepted {
1052 user_id: user_id_b,
1053 should_notify: should_notify && a_to_b,
1054 });
1055 } else if a_to_b {
1056 contacts.push(Contact::Outgoing { user_id: user_id_b })
1057 } else {
1058 contacts.push(Contact::Incoming {
1059 user_id: user_id_b,
1060 should_notify,
1061 });
1062 }
1063 } else if accepted {
1064 contacts.push(Contact::Accepted {
1065 user_id: user_id_a,
1066 should_notify: should_notify && !a_to_b,
1067 });
1068 } else if a_to_b {
1069 contacts.push(Contact::Incoming {
1070 user_id: user_id_a,
1071 should_notify,
1072 });
1073 } else {
1074 contacts.push(Contact::Outgoing { user_id: user_id_a });
1075 }
1076 }
1077
1078 contacts.sort_unstable_by_key(|contact| contact.user_id());
1079
1080 Ok(contacts)
1081 }
1082
1083 async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1084 let (id_a, id_b) = if user_id_1 < user_id_2 {
1085 (user_id_1, user_id_2)
1086 } else {
1087 (user_id_2, user_id_1)
1088 };
1089
1090 let query = "
1091 SELECT 1 FROM contacts
1092 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
1093 LIMIT 1
1094 ";
1095 Ok(sqlx::query_scalar::<_, i32>(query)
1096 .bind(id_a.0)
1097 .bind(id_b.0)
1098 .fetch_optional(&self.pool)
1099 .await?
1100 .is_some())
1101 }
1102
1103 async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1104 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1105 (sender_id, receiver_id, true)
1106 } else {
1107 (receiver_id, sender_id, false)
1108 };
1109 let query = "
1110 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1111 VALUES ($1, $2, $3, 'f', 't')
1112 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1113 SET
1114 accepted = 't',
1115 should_notify = 'f'
1116 WHERE
1117 NOT contacts.accepted AND
1118 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1119 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1120 ";
1121 let result = sqlx::query(query)
1122 .bind(id_a.0)
1123 .bind(id_b.0)
1124 .bind(a_to_b)
1125 .execute(&self.pool)
1126 .await?;
1127
1128 if result.rows_affected() == 1 {
1129 Ok(())
1130 } else {
1131 Err(anyhow!("contact already requested"))?
1132 }
1133 }
1134
1135 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1136 let (id_a, id_b) = if responder_id < requester_id {
1137 (responder_id, requester_id)
1138 } else {
1139 (requester_id, responder_id)
1140 };
1141 let query = "
1142 DELETE FROM contacts
1143 WHERE user_id_a = $1 AND user_id_b = $2;
1144 ";
1145 let result = sqlx::query(query)
1146 .bind(id_a.0)
1147 .bind(id_b.0)
1148 .execute(&self.pool)
1149 .await?;
1150
1151 if result.rows_affected() == 1 {
1152 Ok(())
1153 } else {
1154 Err(anyhow!("no such contact"))?
1155 }
1156 }
1157
1158 async fn dismiss_contact_notification(
1159 &self,
1160 user_id: UserId,
1161 contact_user_id: UserId,
1162 ) -> Result<()> {
1163 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1164 (user_id, contact_user_id, true)
1165 } else {
1166 (contact_user_id, user_id, false)
1167 };
1168
1169 let query = "
1170 UPDATE contacts
1171 SET should_notify = 'f'
1172 WHERE
1173 user_id_a = $1 AND user_id_b = $2 AND
1174 (
1175 (a_to_b = $3 AND accepted) OR
1176 (a_to_b != $3 AND NOT accepted)
1177 );
1178 ";
1179
1180 let result = sqlx::query(query)
1181 .bind(id_a.0)
1182 .bind(id_b.0)
1183 .bind(a_to_b)
1184 .execute(&self.pool)
1185 .await?;
1186
1187 if result.rows_affected() == 0 {
1188 Err(anyhow!("no such contact request"))?;
1189 }
1190
1191 Ok(())
1192 }
1193
1194 async fn respond_to_contact_request(
1195 &self,
1196 responder_id: UserId,
1197 requester_id: UserId,
1198 accept: bool,
1199 ) -> Result<()> {
1200 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1201 (responder_id, requester_id, false)
1202 } else {
1203 (requester_id, responder_id, true)
1204 };
1205 let result = if accept {
1206 let query = "
1207 UPDATE contacts
1208 SET accepted = 't', should_notify = 't'
1209 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1210 ";
1211 sqlx::query(query)
1212 .bind(id_a.0)
1213 .bind(id_b.0)
1214 .bind(a_to_b)
1215 .execute(&self.pool)
1216 .await?
1217 } else {
1218 let query = "
1219 DELETE FROM contacts
1220 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1221 ";
1222 sqlx::query(query)
1223 .bind(id_a.0)
1224 .bind(id_b.0)
1225 .bind(a_to_b)
1226 .execute(&self.pool)
1227 .await?
1228 };
1229 if result.rows_affected() == 1 {
1230 Ok(())
1231 } else {
1232 Err(anyhow!("no such contact request"))?
1233 }
1234 }
1235
1236 // access tokens
1237
1238 async fn create_access_token_hash(
1239 &self,
1240 user_id: UserId,
1241 access_token_hash: &str,
1242 max_access_token_count: usize,
1243 ) -> Result<()> {
1244 let insert_query = "
1245 INSERT INTO access_tokens (user_id, hash)
1246 VALUES ($1, $2);
1247 ";
1248 let cleanup_query = "
1249 DELETE FROM access_tokens
1250 WHERE id IN (
1251 SELECT id from access_tokens
1252 WHERE user_id = $1
1253 ORDER BY id DESC
1254 OFFSET $3
1255 )
1256 ";
1257
1258 let mut tx = self.pool.begin().await?;
1259 sqlx::query(insert_query)
1260 .bind(user_id.0)
1261 .bind(access_token_hash)
1262 .execute(&mut tx)
1263 .await?;
1264 sqlx::query(cleanup_query)
1265 .bind(user_id.0)
1266 .bind(access_token_hash)
1267 .bind(max_access_token_count as i32)
1268 .execute(&mut tx)
1269 .await?;
1270 Ok(tx.commit().await?)
1271 }
1272
1273 async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1274 let query = "
1275 SELECT hash
1276 FROM access_tokens
1277 WHERE user_id = $1
1278 ORDER BY id DESC
1279 ";
1280 Ok(sqlx::query_scalar(query)
1281 .bind(user_id.0)
1282 .fetch_all(&self.pool)
1283 .await?)
1284 }
1285
1286 // orgs
1287
1288 #[allow(unused)] // Help rust-analyzer
1289 #[cfg(any(test, feature = "seed-support"))]
1290 async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
1291 let query = "
1292 SELECT *
1293 FROM orgs
1294 WHERE slug = $1
1295 ";
1296 Ok(sqlx::query_as(query)
1297 .bind(slug)
1298 .fetch_optional(&self.pool)
1299 .await?)
1300 }
1301
1302 #[cfg(any(test, feature = "seed-support"))]
1303 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
1304 let query = "
1305 INSERT INTO orgs (name, slug)
1306 VALUES ($1, $2)
1307 RETURNING id
1308 ";
1309 Ok(sqlx::query_scalar(query)
1310 .bind(name)
1311 .bind(slug)
1312 .fetch_one(&self.pool)
1313 .await
1314 .map(OrgId)?)
1315 }
1316
1317 #[cfg(any(test, feature = "seed-support"))]
1318 async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
1319 let query = "
1320 INSERT INTO org_memberships (org_id, user_id, admin)
1321 VALUES ($1, $2, $3)
1322 ON CONFLICT DO NOTHING
1323 ";
1324 Ok(sqlx::query(query)
1325 .bind(org_id.0)
1326 .bind(user_id.0)
1327 .bind(is_admin)
1328 .execute(&self.pool)
1329 .await
1330 .map(drop)?)
1331 }
1332
1333 // channels
1334
1335 #[cfg(any(test, feature = "seed-support"))]
1336 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
1337 let query = "
1338 INSERT INTO channels (owner_id, owner_is_user, name)
1339 VALUES ($1, false, $2)
1340 RETURNING id
1341 ";
1342 Ok(sqlx::query_scalar(query)
1343 .bind(org_id.0)
1344 .bind(name)
1345 .fetch_one(&self.pool)
1346 .await
1347 .map(ChannelId)?)
1348 }
1349
1350 #[allow(unused)] // Help rust-analyzer
1351 #[cfg(any(test, feature = "seed-support"))]
1352 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
1353 let query = "
1354 SELECT *
1355 FROM channels
1356 WHERE
1357 channels.owner_is_user = false AND
1358 channels.owner_id = $1
1359 ";
1360 Ok(sqlx::query_as(query)
1361 .bind(org_id.0)
1362 .fetch_all(&self.pool)
1363 .await?)
1364 }
1365
1366 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
1367 let query = "
1368 SELECT
1369 channels.*
1370 FROM
1371 channel_memberships, channels
1372 WHERE
1373 channel_memberships.user_id = $1 AND
1374 channel_memberships.channel_id = channels.id
1375 ";
1376 Ok(sqlx::query_as(query)
1377 .bind(user_id.0)
1378 .fetch_all(&self.pool)
1379 .await?)
1380 }
1381
1382 async fn can_user_access_channel(
1383 &self,
1384 user_id: UserId,
1385 channel_id: ChannelId,
1386 ) -> Result<bool> {
1387 let query = "
1388 SELECT id
1389 FROM channel_memberships
1390 WHERE user_id = $1 AND channel_id = $2
1391 LIMIT 1
1392 ";
1393 Ok(sqlx::query_scalar::<_, i32>(query)
1394 .bind(user_id.0)
1395 .bind(channel_id.0)
1396 .fetch_optional(&self.pool)
1397 .await
1398 .map(|e| e.is_some())?)
1399 }
1400
1401 #[cfg(any(test, feature = "seed-support"))]
1402 async fn add_channel_member(
1403 &self,
1404 channel_id: ChannelId,
1405 user_id: UserId,
1406 is_admin: bool,
1407 ) -> Result<()> {
1408 let query = "
1409 INSERT INTO channel_memberships (channel_id, user_id, admin)
1410 VALUES ($1, $2, $3)
1411 ON CONFLICT DO NOTHING
1412 ";
1413 Ok(sqlx::query(query)
1414 .bind(channel_id.0)
1415 .bind(user_id.0)
1416 .bind(is_admin)
1417 .execute(&self.pool)
1418 .await
1419 .map(drop)?)
1420 }
1421
1422 // messages
1423
1424 async fn create_channel_message(
1425 &self,
1426 channel_id: ChannelId,
1427 sender_id: UserId,
1428 body: &str,
1429 timestamp: OffsetDateTime,
1430 nonce: u128,
1431 ) -> Result<MessageId> {
1432 let query = "
1433 INSERT INTO channel_messages (channel_id, sender_id, body, sent_at, nonce)
1434 VALUES ($1, $2, $3, $4, $5)
1435 ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
1436 RETURNING id
1437 ";
1438 Ok(sqlx::query_scalar(query)
1439 .bind(channel_id.0)
1440 .bind(sender_id.0)
1441 .bind(body)
1442 .bind(timestamp)
1443 .bind(Uuid::from_u128(nonce))
1444 .fetch_one(&self.pool)
1445 .await
1446 .map(MessageId)?)
1447 }
1448
1449 async fn get_channel_messages(
1450 &self,
1451 channel_id: ChannelId,
1452 count: usize,
1453 before_id: Option<MessageId>,
1454 ) -> Result<Vec<ChannelMessage>> {
1455 let query = r#"
1456 SELECT * FROM (
1457 SELECT
1458 id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
1459 FROM
1460 channel_messages
1461 WHERE
1462 channel_id = $1 AND
1463 id < $2
1464 ORDER BY id DESC
1465 LIMIT $3
1466 ) as recent_messages
1467 ORDER BY id ASC
1468 "#;
1469 Ok(sqlx::query_as(query)
1470 .bind(channel_id.0)
1471 .bind(before_id.unwrap_or(MessageId::MAX))
1472 .bind(count as i64)
1473 .fetch_all(&self.pool)
1474 .await?)
1475 }
1476
1477 #[cfg(test)]
1478 async fn teardown(&self, url: &str) {
1479 use util::ResultExt;
1480
1481 let query = "
1482 SELECT pg_terminate_backend(pg_stat_activity.pid)
1483 FROM pg_stat_activity
1484 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
1485 ";
1486 sqlx::query(query).execute(&self.pool).await.log_err();
1487 self.pool.close().await;
1488 <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
1489 .await
1490 .log_err();
1491 }
1492
1493 #[cfg(test)]
1494 fn as_fake(&self) -> Option<&FakeDb> {
1495 None
1496 }
1497}
1498
/// Declares `$name` as a newtype over `i32` for use as a typed database id.
///
/// The generated type is `Copy`, totally ordered and hashable. It is
/// transparent to sqlx (`#[sqlx(transparent)]`) and to serde
/// (`#[serde(transparent)]`), and it displays like the wrapped integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest representable id; usable as an exclusive upper bound.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from its wire representation. This is a plain `as`
            /// cast, so values outside the `i32` range are truncated.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts the id to its wire representation. This is a plain `as`
            /// cast, so negative ids wrap.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Delegate so width/fill flags behave exactly as for `i32`.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1541
// Typed id for `User` records.
id_type!(UserId);
/// A user account as loaded from the database (derives `FromRow`).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// May be absent for some users.
    pub email_address: Option<String>,
    pub admin: bool,
    /// `None` until an invite code has been assigned (FakeDb creates users
    /// with `None`).
    pub invite_code: Option<String>,
    pub invite_count: i32,
    /// Updated through `set_user_connected_once`.
    pub connected_once: bool,
}
1553
// Typed id for `Project` records.
id_type!(ProjectId);
/// A project registered by a host user.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    /// Set once the project has been unregistered (FakeDb's
    /// `unregister_project` flips this to `true` instead of deleting).
    pub unregistered: bool,
}
1561
/// Per-user activity rollup returned by `get_top_users_activity_summary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivitySummary {
    pub id: UserId,
    pub github_login: String,
    /// One entry per project the user was active in (presumably within the
    /// queried time period — confirm against the query).
    pub project_activity: Vec<ProjectActivitySummary>,
}
1568
/// Activity for a single project inside a `UserActivitySummary`.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct ProjectActivitySummary {
    pub id: ProjectId,
    /// NOTE(review): presumably total active time in the project — confirm.
    pub duration: Duration,
    /// NOTE(review): presumably the peak number of concurrent collaborators — confirm.
    pub max_collaborators: usize,
}
1575
/// One span of a user's activity timeline (see `get_user_activity_timeline`).
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct UserActivityPeriod {
    pub project_id: ProjectId,
    /// Serialized as an ISO 8601 timestamp.
    #[serde(with = "time::serde::iso8601")]
    pub start: OffsetDateTime,
    /// Serialized as an ISO 8601 timestamp.
    #[serde(with = "time::serde::iso8601")]
    pub end: OffsetDateTime,
    /// NOTE(review): presumably file extension -> file count for the project's
    /// worktrees — confirm against the producing query.
    pub extensions: HashMap<String, usize>,
}
1585
// Typed id for `Org` records.
id_type!(OrgId);
/// An organization; looked up by `slug` in `find_org_by_slug`.
#[derive(FromRow)]
pub struct Org {
    pub id: OrgId,
    pub name: String,
    pub slug: String,
}
1593
// Typed id for `Channel` records.
id_type!(ChannelId);
/// A chat channel.
#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    /// An `OrgId` value when `owner_is_user` is false (see
    /// `create_org_channel`); presumably a `UserId` value otherwise — confirm.
    pub owner_id: i32,
    pub owner_is_user: bool,
}
1602
// Typed id for `ChannelMessage` records.
id_type!(MessageId);
/// A message posted to a channel.
#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
    pub id: MessageId,
    pub channel_id: ChannelId,
    pub sender_id: UserId,
    pub body: String,
    /// `get_channel_messages` selects this column `AT TIME ZONE 'UTC'`.
    pub sent_at: OffsetDateTime,
    /// Client-supplied value; inserting a duplicate nonce returns the
    /// existing message's id instead of creating a new row.
    pub nonce: Uuid,
}
1613
/// A contact relationship as seen from one user's point of view.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The request between the two users has been accepted.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A pending request this user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request `user_id` sent to this user.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1628
1629impl Contact {
1630 pub fn user_id(&self) -> UserId {
1631 match self {
1632 Contact::Accepted { user_id, .. } => *user_id,
1633 Contact::Outgoing { user_id } => *user_id,
1634 Contact::Incoming { user_id, .. } => *user_id,
1635 }
1636 }
1637}
1638
/// A pending contact request received from `requester_id`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    /// NOTE(review): presumably whether the recipient should still be
    /// notified — mirrors `Contact::Incoming::should_notify`; confirm.
    pub should_notify: bool,
}
1644
/// A waitlist signup, deserialized from incoming data and passed to
/// `create_signup`.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    // Platforms the person selected.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
}
1654
/// Aggregate waitlist counts returned by `get_waitlist_summary`.
///
/// Each field carries `#[sqlx(default)]`, so a column missing from the row
/// decodes as `0` instead of failing.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    /// NOTE(review): presumably the total number of signups — confirm.
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
}
1666
/// An invitation addressed to `email_address`.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    /// NOTE(review): presumably produced by `random_email_confirmation_code`
    /// (64-char nanoid) — confirm at the insertion site.
    pub email_confirmation_code: String,
}
1672
/// Parameters for the user created by `create_user_from_invite`.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub invite_count: i32,
}
1678
1679fn random_invite_code() -> String {
1680 nanoid::nanoid!(16)
1681}
1682
1683fn random_email_confirmation_code() -> String {
1684 nanoid::nanoid!(64)
1685}
1686
1687#[cfg(test)]
1688pub use test::*;
1689
1690#[cfg(test)]
1691mod test {
1692 use super::*;
1693 use anyhow::anyhow;
1694 use collections::BTreeMap;
1695 use gpui::executor::Background;
1696 use lazy_static::lazy_static;
1697 use parking_lot::Mutex;
1698 use rand::prelude::*;
1699 use sqlx::{
1700 migrate::{MigrateDatabase, Migrator},
1701 Postgres,
1702 };
1703 use std::{path::Path, sync::Arc};
1704 use util::post_inc;
1705
    /// In-memory `Db` implementation for tests. Each "table" is a mutex-guarded
    /// collection, and every operation first awaits a simulated random delay
    /// on `background`.
    pub struct FakeDb {
        background: Arc<Background>,
        pub users: Mutex<BTreeMap<UserId, User>>,
        pub projects: Mutex<BTreeMap<ProjectId, Project>>,
        // Keyed by (project, worktree id, extension); value is the stored count.
        pub worktree_extensions: Mutex<BTreeMap<(ProjectId, u64, String), u32>>,
        pub orgs: Mutex<BTreeMap<OrgId, Org>>,
        // Value is the member's admin flag.
        pub org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
        pub channels: Mutex<BTreeMap<ChannelId, Channel>>,
        // Value is the member's admin flag.
        pub channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
        pub channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
        pub contacts: Mutex<Vec<FakeContact>>,
        // Next ids to hand out; see `FakeDb::new` for their starting values.
        next_channel_message_id: Mutex<i32>,
        next_user_id: Mutex<i32>,
        next_org_id: Mutex<i32>,
        next_channel_id: Mutex<i32>,
        next_project_id: Mutex<i32>,
    }
1723
    /// One contact row in `FakeDb`, stored in request direction (requester ->
    /// responder) rather than in normalized id order.
    #[derive(Debug)]
    pub struct FakeContact {
        pub requester_id: UserId,
        pub responder_id: UserId,
        pub accepted: bool,
        pub should_notify: bool,
    }
1731
1732 impl FakeDb {
1733 pub fn new(background: Arc<Background>) -> Self {
1734 Self {
1735 background,
1736 users: Default::default(),
1737 next_user_id: Mutex::new(0),
1738 projects: Default::default(),
1739 worktree_extensions: Default::default(),
1740 next_project_id: Mutex::new(1),
1741 orgs: Default::default(),
1742 next_org_id: Mutex::new(1),
1743 org_memberships: Default::default(),
1744 channels: Default::default(),
1745 next_channel_id: Mutex::new(1),
1746 channel_memberships: Default::default(),
1747 channel_messages: Default::default(),
1748 next_channel_message_id: Mutex::new(1),
1749 contacts: Default::default(),
1750 }
1751 }
1752 }
1753
1754 #[async_trait]
1755 impl Db for FakeDb {
1756 async fn create_user(
1757 &self,
1758 github_login: &str,
1759 email_address: &str,
1760 admin: bool,
1761 ) -> Result<UserId> {
1762 self.background.simulate_random_delay().await;
1763
1764 let mut users = self.users.lock();
1765 if let Some(user) = users
1766 .values()
1767 .find(|user| user.github_login == github_login)
1768 {
1769 Ok(user.id)
1770 } else {
1771 let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
1772 users.insert(
1773 user_id,
1774 User {
1775 id: user_id,
1776 github_login: github_login.to_string(),
1777 email_address: Some(email_address.to_string()),
1778 admin,
1779 invite_code: None,
1780 invite_count: 0,
1781 connected_once: false,
1782 },
1783 );
1784 Ok(user_id)
1785 }
1786 }
1787
1788 async fn get_all_users(&self, _page: u32, _limit: u32) -> Result<Vec<User>> {
1789 unimplemented!()
1790 }
1791
1792 async fn create_users(&self, _users: Vec<(String, String, usize)>) -> Result<Vec<UserId>> {
1793 unimplemented!()
1794 }
1795
1796 async fn fuzzy_search_users(&self, _: &str, _: u32) -> Result<Vec<User>> {
1797 unimplemented!()
1798 }
1799
1800 async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
1801 self.background.simulate_random_delay().await;
1802 Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
1803 }
1804
1805 async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
1806 self.background.simulate_random_delay().await;
1807 let users = self.users.lock();
1808 Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
1809 }
1810
1811 async fn get_users_with_no_invites(&self, _: bool) -> Result<Vec<User>> {
1812 unimplemented!()
1813 }
1814
1815 async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
1816 self.background.simulate_random_delay().await;
1817 Ok(self
1818 .users
1819 .lock()
1820 .values()
1821 .find(|user| user.github_login == github_login)
1822 .cloned())
1823 }
1824
1825 async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
1826 unimplemented!()
1827 }
1828
1829 async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
1830 self.background.simulate_random_delay().await;
1831 let mut users = self.users.lock();
1832 let mut user = users
1833 .get_mut(&id)
1834 .ok_or_else(|| anyhow!("user not found"))?;
1835 user.connected_once = connected_once;
1836 Ok(())
1837 }
1838
1839 async fn destroy_user(&self, _id: UserId) -> Result<()> {
1840 unimplemented!()
1841 }
1842
1843 // signups
1844
1845 async fn create_signup(&self, _signup: Signup) -> Result<()> {
1846 unimplemented!()
1847 }
1848
1849 async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
1850 unimplemented!()
1851 }
1852
1853 async fn get_unsent_invites(&self, _count: usize) -> Result<Vec<Invite>> {
1854 unimplemented!()
1855 }
1856
1857 async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
1858 unimplemented!()
1859 }
1860
1861 async fn create_user_from_invite(
1862 &self,
1863 _invite: &Invite,
1864 _user: NewUserParams,
1865 ) -> Result<(UserId, Option<UserId>)> {
1866 unimplemented!()
1867 }
1868
1869 // invite codes
1870
1871 async fn set_invite_count_for_user(&self, _id: UserId, _count: u32) -> Result<()> {
1872 unimplemented!()
1873 }
1874
1875 async fn get_invite_code_for_user(&self, _id: UserId) -> Result<Option<(String, u32)>> {
1876 self.background.simulate_random_delay().await;
1877 Ok(None)
1878 }
1879
1880 async fn get_user_for_invite_code(&self, _code: &str) -> Result<User> {
1881 unimplemented!()
1882 }
1883
1884 async fn create_invite_from_code(
1885 &self,
1886 _code: &str,
1887 _email_address: &str,
1888 ) -> Result<Invite> {
1889 unimplemented!()
1890 }
1891
1892 // projects
1893
1894 async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
1895 self.background.simulate_random_delay().await;
1896 if !self.users.lock().contains_key(&host_user_id) {
1897 Err(anyhow!("no such user"))?;
1898 }
1899
1900 let project_id = ProjectId(post_inc(&mut *self.next_project_id.lock()));
1901 self.projects.lock().insert(
1902 project_id,
1903 Project {
1904 id: project_id,
1905 host_user_id,
1906 unregistered: false,
1907 },
1908 );
1909 Ok(project_id)
1910 }
1911
1912 async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
1913 self.background.simulate_random_delay().await;
1914 self.projects
1915 .lock()
1916 .get_mut(&project_id)
1917 .ok_or_else(|| anyhow!("no such project"))?
1918 .unregistered = true;
1919 Ok(())
1920 }
1921
1922 async fn update_worktree_extensions(
1923 &self,
1924 project_id: ProjectId,
1925 worktree_id: u64,
1926 extensions: HashMap<String, u32>,
1927 ) -> Result<()> {
1928 self.background.simulate_random_delay().await;
1929 if !self.projects.lock().contains_key(&project_id) {
1930 Err(anyhow!("no such project"))?;
1931 }
1932
1933 for (extension, count) in extensions {
1934 self.worktree_extensions
1935 .lock()
1936 .insert((project_id, worktree_id, extension), count);
1937 }
1938
1939 Ok(())
1940 }
1941
1942 async fn get_project_extensions(
1943 &self,
1944 _project_id: ProjectId,
1945 ) -> Result<HashMap<u64, HashMap<String, usize>>> {
1946 unimplemented!()
1947 }
1948
1949 async fn record_user_activity(
1950 &self,
1951 _time_period: Range<OffsetDateTime>,
1952 _active_projects: &[(UserId, ProjectId)],
1953 ) -> Result<()> {
1954 unimplemented!()
1955 }
1956
1957 async fn get_active_user_count(
1958 &self,
1959 _time_period: Range<OffsetDateTime>,
1960 _min_duration: Duration,
1961 _only_collaborative: bool,
1962 ) -> Result<usize> {
1963 unimplemented!()
1964 }
1965
1966 async fn get_top_users_activity_summary(
1967 &self,
1968 _time_period: Range<OffsetDateTime>,
1969 _limit: usize,
1970 ) -> Result<Vec<UserActivitySummary>> {
1971 unimplemented!()
1972 }
1973
1974 async fn get_user_activity_timeline(
1975 &self,
1976 _time_period: Range<OffsetDateTime>,
1977 _user_id: UserId,
1978 ) -> Result<Vec<UserActivityPeriod>> {
1979 unimplemented!()
1980 }
1981
1982 // contacts
1983
1984 async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
1985 self.background.simulate_random_delay().await;
1986 let mut contacts = vec![Contact::Accepted {
1987 user_id: id,
1988 should_notify: false,
1989 }];
1990
1991 for contact in self.contacts.lock().iter() {
1992 if contact.requester_id == id {
1993 if contact.accepted {
1994 contacts.push(Contact::Accepted {
1995 user_id: contact.responder_id,
1996 should_notify: contact.should_notify,
1997 });
1998 } else {
1999 contacts.push(Contact::Outgoing {
2000 user_id: contact.responder_id,
2001 });
2002 }
2003 } else if contact.responder_id == id {
2004 if contact.accepted {
2005 contacts.push(Contact::Accepted {
2006 user_id: contact.requester_id,
2007 should_notify: false,
2008 });
2009 } else {
2010 contacts.push(Contact::Incoming {
2011 user_id: contact.requester_id,
2012 should_notify: contact.should_notify,
2013 });
2014 }
2015 }
2016 }
2017
2018 contacts.sort_unstable_by_key(|contact| contact.user_id());
2019 Ok(contacts)
2020 }
2021
2022 async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
2023 self.background.simulate_random_delay().await;
2024 Ok(self.contacts.lock().iter().any(|contact| {
2025 contact.accepted
2026 && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
2027 || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
2028 }))
2029 }
2030
2031 async fn send_contact_request(
2032 &self,
2033 requester_id: UserId,
2034 responder_id: UserId,
2035 ) -> Result<()> {
2036 self.background.simulate_random_delay().await;
2037 let mut contacts = self.contacts.lock();
2038 for contact in contacts.iter_mut() {
2039 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2040 if contact.accepted {
2041 Err(anyhow!("contact already exists"))?;
2042 } else {
2043 Err(anyhow!("contact already requested"))?;
2044 }
2045 }
2046 if contact.responder_id == requester_id && contact.requester_id == responder_id {
2047 if contact.accepted {
2048 Err(anyhow!("contact already exists"))?;
2049 } else {
2050 contact.accepted = true;
2051 contact.should_notify = false;
2052 return Ok(());
2053 }
2054 }
2055 }
2056 contacts.push(FakeContact {
2057 requester_id,
2058 responder_id,
2059 accepted: false,
2060 should_notify: true,
2061 });
2062 Ok(())
2063 }
2064
2065 async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2066 self.background.simulate_random_delay().await;
2067 self.contacts.lock().retain(|contact| {
2068 !(contact.requester_id == requester_id && contact.responder_id == responder_id)
2069 });
2070 Ok(())
2071 }
2072
2073 async fn dismiss_contact_notification(
2074 &self,
2075 user_id: UserId,
2076 contact_user_id: UserId,
2077 ) -> Result<()> {
2078 self.background.simulate_random_delay().await;
2079 let mut contacts = self.contacts.lock();
2080 for contact in contacts.iter_mut() {
2081 if contact.requester_id == contact_user_id
2082 && contact.responder_id == user_id
2083 && !contact.accepted
2084 {
2085 contact.should_notify = false;
2086 return Ok(());
2087 }
2088 if contact.requester_id == user_id
2089 && contact.responder_id == contact_user_id
2090 && contact.accepted
2091 {
2092 contact.should_notify = false;
2093 return Ok(());
2094 }
2095 }
2096 Err(anyhow!("no such notification"))?
2097 }
2098
2099 async fn respond_to_contact_request(
2100 &self,
2101 responder_id: UserId,
2102 requester_id: UserId,
2103 accept: bool,
2104 ) -> Result<()> {
2105 self.background.simulate_random_delay().await;
2106 let mut contacts = self.contacts.lock();
2107 for (ix, contact) in contacts.iter_mut().enumerate() {
2108 if contact.requester_id == requester_id && contact.responder_id == responder_id {
2109 if contact.accepted {
2110 Err(anyhow!("contact already confirmed"))?;
2111 }
2112 if accept {
2113 contact.accepted = true;
2114 contact.should_notify = true;
2115 } else {
2116 contacts.remove(ix);
2117 }
2118 return Ok(());
2119 }
2120 }
2121 Err(anyhow!("no such contact request"))?
2122 }
2123
2124 async fn create_access_token_hash(
2125 &self,
2126 _user_id: UserId,
2127 _access_token_hash: &str,
2128 _max_access_token_count: usize,
2129 ) -> Result<()> {
2130 unimplemented!()
2131 }
2132
2133 async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
2134 unimplemented!()
2135 }
2136
2137 async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
2138 unimplemented!()
2139 }
2140
2141 async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
2142 self.background.simulate_random_delay().await;
2143 let mut orgs = self.orgs.lock();
2144 if orgs.values().any(|org| org.slug == slug) {
2145 Err(anyhow!("org already exists"))?
2146 } else {
2147 let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
2148 orgs.insert(
2149 org_id,
2150 Org {
2151 id: org_id,
2152 name: name.to_string(),
2153 slug: slug.to_string(),
2154 },
2155 );
2156 Ok(org_id)
2157 }
2158 }
2159
2160 async fn add_org_member(
2161 &self,
2162 org_id: OrgId,
2163 user_id: UserId,
2164 is_admin: bool,
2165 ) -> Result<()> {
2166 self.background.simulate_random_delay().await;
2167 if !self.orgs.lock().contains_key(&org_id) {
2168 Err(anyhow!("org does not exist"))?;
2169 }
2170 if !self.users.lock().contains_key(&user_id) {
2171 Err(anyhow!("user does not exist"))?;
2172 }
2173
2174 self.org_memberships
2175 .lock()
2176 .entry((org_id, user_id))
2177 .or_insert(is_admin);
2178 Ok(())
2179 }
2180
2181 async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
2182 self.background.simulate_random_delay().await;
2183 if !self.orgs.lock().contains_key(&org_id) {
2184 Err(anyhow!("org does not exist"))?;
2185 }
2186
2187 let mut channels = self.channels.lock();
2188 let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
2189 channels.insert(
2190 channel_id,
2191 Channel {
2192 id: channel_id,
2193 name: name.to_string(),
2194 owner_id: org_id.0,
2195 owner_is_user: false,
2196 },
2197 );
2198 Ok(channel_id)
2199 }
2200
2201 async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
2202 self.background.simulate_random_delay().await;
2203 Ok(self
2204 .channels
2205 .lock()
2206 .values()
2207 .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
2208 .cloned()
2209 .collect())
2210 }
2211
2212 async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
2213 self.background.simulate_random_delay().await;
2214 let channels = self.channels.lock();
2215 let memberships = self.channel_memberships.lock();
2216 Ok(channels
2217 .values()
2218 .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
2219 .cloned()
2220 .collect())
2221 }
2222
2223 async fn can_user_access_channel(
2224 &self,
2225 user_id: UserId,
2226 channel_id: ChannelId,
2227 ) -> Result<bool> {
2228 self.background.simulate_random_delay().await;
2229 Ok(self
2230 .channel_memberships
2231 .lock()
2232 .contains_key(&(channel_id, user_id)))
2233 }
2234
2235 async fn add_channel_member(
2236 &self,
2237 channel_id: ChannelId,
2238 user_id: UserId,
2239 is_admin: bool,
2240 ) -> Result<()> {
2241 self.background.simulate_random_delay().await;
2242 if !self.channels.lock().contains_key(&channel_id) {
2243 Err(anyhow!("channel does not exist"))?;
2244 }
2245 if !self.users.lock().contains_key(&user_id) {
2246 Err(anyhow!("user does not exist"))?;
2247 }
2248
2249 self.channel_memberships
2250 .lock()
2251 .entry((channel_id, user_id))
2252 .or_insert(is_admin);
2253 Ok(())
2254 }
2255
2256 async fn create_channel_message(
2257 &self,
2258 channel_id: ChannelId,
2259 sender_id: UserId,
2260 body: &str,
2261 timestamp: OffsetDateTime,
2262 nonce: u128,
2263 ) -> Result<MessageId> {
2264 self.background.simulate_random_delay().await;
2265 if !self.channels.lock().contains_key(&channel_id) {
2266 Err(anyhow!("channel does not exist"))?;
2267 }
2268 if !self.users.lock().contains_key(&sender_id) {
2269 Err(anyhow!("user does not exist"))?;
2270 }
2271
2272 let mut messages = self.channel_messages.lock();
2273 if let Some(message) = messages
2274 .values()
2275 .find(|message| message.nonce.as_u128() == nonce)
2276 {
2277 Ok(message.id)
2278 } else {
2279 let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
2280 messages.insert(
2281 message_id,
2282 ChannelMessage {
2283 id: message_id,
2284 channel_id,
2285 sender_id,
2286 body: body.to_string(),
2287 sent_at: timestamp,
2288 nonce: Uuid::from_u128(nonce),
2289 },
2290 );
2291 Ok(message_id)
2292 }
2293 }
2294
2295 async fn get_channel_messages(
2296 &self,
2297 channel_id: ChannelId,
2298 count: usize,
2299 before_id: Option<MessageId>,
2300 ) -> Result<Vec<ChannelMessage>> {
2301 self.background.simulate_random_delay().await;
2302 let mut messages = self
2303 .channel_messages
2304 .lock()
2305 .values()
2306 .rev()
2307 .filter(|message| {
2308 message.channel_id == channel_id
2309 && message.id < before_id.unwrap_or(MessageId::MAX)
2310 })
2311 .take(count)
2312 .cloned()
2313 .collect::<Vec<_>>();
2314 messages.sort_unstable_by_key(|message| message.id);
2315 Ok(messages)
2316 }
2317
2318 async fn teardown(&self, _: &str) {}
2319
2320 #[cfg(test)]
2321 fn as_fake(&self) -> Option<&FakeDb> {
2322 Some(self)
2323 }
2324 }
2325
    /// A database handle for tests, backed either by a throwaway Postgres
    /// database or by `FakeDb`.
    pub struct TestDb {
        // Always `Some` until `Drop` takes it to run teardown.
        pub db: Option<Arc<dyn Db>>,
        // Connection URL of the Postgres test database; empty for the fake.
        pub url: String,
    }
2330
2331 impl TestDb {
2332 #[allow(clippy::await_holding_lock)]
2333 pub async fn postgres() -> Self {
2334 lazy_static! {
2335 static ref LOCK: Mutex<()> = Mutex::new(());
2336 }
2337
2338 let _guard = LOCK.lock();
2339 let mut rng = StdRng::from_entropy();
2340 let name = format!("zed-test-{}", rng.gen::<u128>());
2341 let url = format!("postgres://postgres@localhost/{}", name);
2342 let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
2343 Postgres::create_database(&url)
2344 .await
2345 .expect("failed to create test db");
2346 let db = PostgresDb::new(&url, 5).await.unwrap();
2347 let migrator = Migrator::new(migrations_path).await.unwrap();
2348 migrator.run(&db.pool).await.unwrap();
2349 Self {
2350 db: Some(Arc::new(db)),
2351 url,
2352 }
2353 }
2354
2355 pub fn fake(background: Arc<Background>) -> Self {
2356 Self {
2357 db: Some(Arc::new(FakeDb::new(background))),
2358 url: Default::default(),
2359 }
2360 }
2361
2362 pub fn db(&self) -> &Arc<dyn Db> {
2363 self.db.as_ref().unwrap()
2364 }
2365 }
2366
2367 impl Drop for TestDb {
2368 fn drop(&mut self) {
2369 if let Some(db) = self.db.take() {
2370 futures::executor::block_on(db.teardown(&self.url));
2371 }
2372 }
2373 }
2374}