1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::HashMap;
5use futures::StreamExt;
6use rpc::{proto, ConnectionId};
7use serde::{Deserialize, Serialize};
8use sqlx::{
9 migrate::{Migrate as _, Migration, MigrationSource},
10 types::Uuid,
11 FromRow,
12};
13use std::{path::Path, time::Duration};
14use time::{OffsetDateTime, PrimitiveDateTime};
15
/// Database type used when compiling tests: a [`Db`] backed by SQLite.
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// Database type used in non-test builds: a [`Db`] backed by Postgres.
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;
21
/// A connection pool for the database backend `D`, plus state that only
/// exists in test builds.
pub struct Db<D: sqlx::Database> {
    /// Pool that the queries in this file are executed against.
    pool: sqlx::Pool<D>,
    /// Test-only: when set, `test_support!` awaits `simulate_random_delay()`
    /// on it before running a query body.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only: runtime on which `test_support!` blocks while running a
    /// query body (and which `teardown` uses as well).
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
29
/// Wraps the body of a database method so that it can run under the test
/// executor.
///
/// The tokens inside the braces become the body of an `async` block. In
/// non-test builds that block is simply awaited. In test builds the macro
/// first awaits `simulate_random_delay()` on `$self.background` (when one is
/// set) and then drives the block to completion with `block_on` on
/// `$self.runtime`, panicking through `unwrap` if no runtime is set.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };

        if cfg!(test) {
            // `cfg!(test)` is a plain boolean, so this branch is still
            // compiled in non-test builds. The `#[cfg(test)]` statements
            // below are stripped there, leaving only this diverging call.
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(body)
        } else {
            body.await
        }
    }};
}
52
/// Backend-independent access to the number of rows a statement affected.
///
/// The generic `Db<D>` methods require `D::QueryResult: RowsAffected` so they
/// can inspect the outcome of an `execute` call without naming the backend.
pub trait RowsAffected {
    /// Returns the number of rows affected by the executed statement.
    fn rows_affected(&self) -> u64;
}
56
/// Test-only implementation for the SQLite query result.
#[cfg(test)]
impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    fn rows_affected(&self) -> u64 {
        // NOTE(review): this relies on method resolution preferring the
        // inherent `SqliteQueryResult::rows_affected` over this trait method;
        // otherwise the call would recurse.
        self.rows_affected()
    }
}
63
/// Implementation for the Postgres query result.
impl RowsAffected for sqlx::postgres::PgQueryResult {
    fn rows_affected(&self) -> u64 {
        // NOTE(review): this relies on method resolution preferring the
        // inherent `PgQueryResult::rows_affected` over this trait method;
        // otherwise the call would recurse.
        self.rows_affected()
    }
}
69
70#[cfg(test)]
71impl Db<sqlx::Sqlite> {
72 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
73 use std::str::FromStr as _;
74 let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
75 .unwrap()
76 .create_if_missing(true)
77 .shared_cache(true);
78 let pool = sqlx::sqlite::SqlitePoolOptions::new()
79 .min_connections(2)
80 .max_connections(max_connections)
81 .connect_with(options)
82 .await?;
83 Ok(Self {
84 pool,
85 background: None,
86 runtime: None,
87 })
88 }
89
90 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
91 test_support!(self, {
92 let query = "
93 SELECT users.*
94 FROM users
95 WHERE users.id IN (SELECT value from json_each($1))
96 ";
97 Ok(sqlx::query_as(query)
98 .bind(&serde_json::json!(ids))
99 .fetch_all(&self.pool)
100 .await?)
101 })
102 }
103
104 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
105 test_support!(self, {
106 let query = "
107 SELECT metrics_id
108 FROM users
109 WHERE id = $1
110 ";
111 Ok(sqlx::query_scalar(query)
112 .bind(id)
113 .fetch_one(&self.pool)
114 .await?)
115 })
116 }
117
118 pub async fn create_user(
119 &self,
120 email_address: &str,
121 admin: bool,
122 params: NewUserParams,
123 ) -> Result<NewUserResult> {
124 test_support!(self, {
125 let query = "
126 INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
127 VALUES ($1, $2, $3, $4, $5)
128 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
129 RETURNING id, metrics_id
130 ";
131
132 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
133 .bind(email_address)
134 .bind(params.github_login)
135 .bind(params.github_user_id)
136 .bind(admin)
137 .bind(Uuid::new_v4().to_string())
138 .fetch_one(&self.pool)
139 .await?;
140 Ok(NewUserResult {
141 user_id,
142 metrics_id,
143 signup_device_id: None,
144 inviting_user_id: None,
145 })
146 })
147 }
148
    /// Not implemented for the SQLite backend; calling it panics via
    /// `unimplemented!()`.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }
152
    /// Not implemented for the SQLite backend; calling it panics via
    /// `unimplemented!()`.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }
160
    /// Not implemented for the SQLite backend; calling it panics via
    /// `unimplemented!()`.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }
164
    /// Not implemented for the SQLite backend; calling it panics via
    /// `unimplemented!()`.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }
173
    /// Not implemented for the SQLite backend; calling it panics via
    /// `unimplemented!()`.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
177}
178
179impl Db<sqlx::Postgres> {
180 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
181 let pool = sqlx::postgres::PgPoolOptions::new()
182 .max_connections(max_connections)
183 .connect(url)
184 .await?;
185 Ok(Self {
186 pool,
187 #[cfg(test)]
188 background: None,
189 #[cfg(test)]
190 runtime: None,
191 })
192 }
193
194 #[cfg(test)]
195 pub fn teardown(&self, url: &str) {
196 self.runtime.as_ref().unwrap().block_on(async {
197 use util::ResultExt;
198 let query = "
199 SELECT pg_terminate_backend(pg_stat_activity.pid)
200 FROM pg_stat_activity
201 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
202 ";
203 sqlx::query(query).execute(&self.pool).await.log_err();
204 self.pool.close().await;
205 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
206 .await
207 .log_err();
208 })
209 }
210
211 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
212 test_support!(self, {
213 let like_string = Self::fuzzy_like_string(name_query);
214 let query = "
215 SELECT users.*
216 FROM users
217 WHERE github_login ILIKE $1
218 ORDER BY github_login <-> $2
219 LIMIT $3
220 ";
221 Ok(sqlx::query_as(query)
222 .bind(like_string)
223 .bind(name_query)
224 .bind(limit as i32)
225 .fetch_all(&self.pool)
226 .await?)
227 })
228 }
229
230 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
231 test_support!(self, {
232 let query = "
233 SELECT users.*
234 FROM users
235 WHERE users.id = ANY ($1)
236 ";
237 Ok(sqlx::query_as(query)
238 .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
239 .fetch_all(&self.pool)
240 .await?)
241 })
242 }
243
244 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
245 test_support!(self, {
246 let query = "
247 SELECT metrics_id::text
248 FROM users
249 WHERE id = $1
250 ";
251 Ok(sqlx::query_scalar(query)
252 .bind(id)
253 .fetch_one(&self.pool)
254 .await?)
255 })
256 }
257
258 pub async fn create_user(
259 &self,
260 email_address: &str,
261 admin: bool,
262 params: NewUserParams,
263 ) -> Result<NewUserResult> {
264 test_support!(self, {
265 let query = "
266 INSERT INTO users (email_address, github_login, github_user_id, admin)
267 VALUES ($1, $2, $3, $4)
268 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
269 RETURNING id, metrics_id::text
270 ";
271
272 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
273 .bind(email_address)
274 .bind(params.github_login)
275 .bind(params.github_user_id)
276 .bind(admin)
277 .fetch_one(&self.pool)
278 .await?;
279 Ok(NewUserResult {
280 user_id,
281 metrics_id,
282 signup_device_id: None,
283 inviting_user_id: None,
284 })
285 })
286 }
287
288 pub async fn create_user_from_invite(
289 &self,
290 invite: &Invite,
291 user: NewUserParams,
292 ) -> Result<Option<NewUserResult>> {
293 test_support!(self, {
294 let mut tx = self.pool.begin().await?;
295
296 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
297 i32,
298 Option<UserId>,
299 Option<UserId>,
300 Option<String>,
301 ) = sqlx::query_as(
302 "
303 SELECT id, user_id, inviting_user_id, device_id
304 FROM signups
305 WHERE
306 email_address = $1 AND
307 email_confirmation_code = $2
308 ",
309 )
310 .bind(&invite.email_address)
311 .bind(&invite.email_confirmation_code)
312 .fetch_optional(&mut tx)
313 .await?
314 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
315
316 if existing_user_id.is_some() {
317 return Ok(None);
318 }
319
320 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
321 "
322 INSERT INTO users
323 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
324 VALUES
325 ($1, $2, $3, FALSE, $4, $5)
326 ON CONFLICT (github_login) DO UPDATE SET
327 email_address = excluded.email_address,
328 github_user_id = excluded.github_user_id,
329 admin = excluded.admin
330 RETURNING id, metrics_id::text
331 ",
332 )
333 .bind(&invite.email_address)
334 .bind(&user.github_login)
335 .bind(&user.github_user_id)
336 .bind(&user.invite_count)
337 .bind(random_invite_code())
338 .fetch_one(&mut tx)
339 .await?;
340
341 sqlx::query(
342 "
343 UPDATE signups
344 SET user_id = $1
345 WHERE id = $2
346 ",
347 )
348 .bind(&user_id)
349 .bind(&signup_id)
350 .execute(&mut tx)
351 .await?;
352
353 if let Some(inviting_user_id) = inviting_user_id {
354 let id: Option<UserId> = sqlx::query_scalar(
355 "
356 UPDATE users
357 SET invite_count = invite_count - 1
358 WHERE id = $1 AND invite_count > 0
359 RETURNING id
360 ",
361 )
362 .bind(&inviting_user_id)
363 .fetch_optional(&mut tx)
364 .await?;
365
366 if id.is_none() {
367 Err(Error::Http(
368 StatusCode::UNAUTHORIZED,
369 "no invites remaining".to_string(),
370 ))?;
371 }
372
373 sqlx::query(
374 "
375 INSERT INTO contacts
376 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
377 VALUES
378 ($1, $2, TRUE, TRUE, TRUE)
379 ON CONFLICT DO NOTHING
380 ",
381 )
382 .bind(inviting_user_id)
383 .bind(user_id)
384 .execute(&mut tx)
385 .await?;
386 }
387
388 tx.commit().await?;
389 Ok(Some(NewUserResult {
390 user_id,
391 metrics_id,
392 inviting_user_id,
393 signup_device_id,
394 }))
395 })
396 }
397
398 pub async fn create_signup(&self, signup: Signup) -> Result<()> {
399 test_support!(self, {
400 sqlx::query(
401 "
402 INSERT INTO signups
403 (
404 email_address,
405 email_confirmation_code,
406 email_confirmation_sent,
407 platform_linux,
408 platform_mac,
409 platform_windows,
410 platform_unknown,
411 editor_features,
412 programming_languages,
413 device_id
414 )
415 VALUES
416 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
417 RETURNING id
418 ",
419 )
420 .bind(&signup.email_address)
421 .bind(&random_email_confirmation_code())
422 .bind(&signup.platform_linux)
423 .bind(&signup.platform_mac)
424 .bind(&signup.platform_windows)
425 .bind(&signup.editor_features)
426 .bind(&signup.programming_languages)
427 .bind(&signup.device_id)
428 .execute(&self.pool)
429 .await?;
430 Ok(())
431 })
432 }
433
434 pub async fn create_invite_from_code(
435 &self,
436 code: &str,
437 email_address: &str,
438 device_id: Option<&str>,
439 ) -> Result<Invite> {
440 test_support!(self, {
441 let mut tx = self.pool.begin().await?;
442
443 let existing_user: Option<UserId> = sqlx::query_scalar(
444 "
445 SELECT id
446 FROM users
447 WHERE email_address = $1
448 ",
449 )
450 .bind(email_address)
451 .fetch_optional(&mut tx)
452 .await?;
453 if existing_user.is_some() {
454 Err(anyhow!("email address is already in use"))?;
455 }
456
457 let row: Option<(UserId, i32)> = sqlx::query_as(
458 "
459 SELECT id, invite_count
460 FROM users
461 WHERE invite_code = $1
462 ",
463 )
464 .bind(code)
465 .fetch_optional(&mut tx)
466 .await?;
467
468 let (inviter_id, invite_count) = match row {
469 Some(row) => row,
470 None => Err(Error::Http(
471 StatusCode::NOT_FOUND,
472 "invite code not found".to_string(),
473 ))?,
474 };
475
476 if invite_count == 0 {
477 Err(Error::Http(
478 StatusCode::UNAUTHORIZED,
479 "no invites remaining".to_string(),
480 ))?;
481 }
482
483 let email_confirmation_code: String = sqlx::query_scalar(
484 "
485 INSERT INTO signups
486 (
487 email_address,
488 email_confirmation_code,
489 email_confirmation_sent,
490 inviting_user_id,
491 platform_linux,
492 platform_mac,
493 platform_windows,
494 platform_unknown,
495 device_id
496 )
497 VALUES
498 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
499 ON CONFLICT (email_address)
500 DO UPDATE SET
501 inviting_user_id = excluded.inviting_user_id
502 RETURNING email_confirmation_code
503 ",
504 )
505 .bind(&email_address)
506 .bind(&random_email_confirmation_code())
507 .bind(&inviter_id)
508 .bind(&device_id)
509 .fetch_one(&mut tx)
510 .await?;
511
512 tx.commit().await?;
513
514 Ok(Invite {
515 email_address: email_address.into(),
516 email_confirmation_code,
517 })
518 })
519 }
520
521 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
522 test_support!(self, {
523 let emails = invites
524 .iter()
525 .map(|s| s.email_address.as_str())
526 .collect::<Vec<_>>();
527 sqlx::query(
528 "
529 UPDATE signups
530 SET email_confirmation_sent = TRUE
531 WHERE email_address = ANY ($1)
532 ",
533 )
534 .bind(&emails)
535 .execute(&self.pool)
536 .await?;
537 Ok(())
538 })
539 }
540}
541
542impl<D> Db<D>
543where
544 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
545 D::Connection: sqlx::migrate::Migrate,
546 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
547 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
548 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
549 D::QueryResult: RowsAffected,
550 String: sqlx::Type<D>,
551 i32: sqlx::Type<D>,
552 i64: sqlx::Type<D>,
553 bool: sqlx::Type<D>,
554 str: sqlx::Type<D>,
555 Uuid: sqlx::Type<D>,
556 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
557 OffsetDateTime: sqlx::Type<D>,
558 PrimitiveDateTime: sqlx::Type<D>,
559 usize: sqlx::ColumnIndex<D::Row>,
560 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
561 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
562 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
563 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
564 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
565 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
566 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
567 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
568 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
569 for<'a> Option<ProjectId>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
570 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
571 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
572 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
573{
574 pub async fn migrate(
575 &self,
576 migrations_path: &Path,
577 ignore_checksum_mismatch: bool,
578 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
579 let migrations = MigrationSource::resolve(migrations_path)
580 .await
581 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
582
583 let mut conn = self.pool.acquire().await?;
584
585 conn.ensure_migrations_table().await?;
586 let applied_migrations: HashMap<_, _> = conn
587 .list_applied_migrations()
588 .await?
589 .into_iter()
590 .map(|m| (m.version, m))
591 .collect();
592
593 let mut new_migrations = Vec::new();
594 for migration in migrations {
595 match applied_migrations.get(&migration.version) {
596 Some(applied_migration) => {
597 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
598 {
599 Err(anyhow!(
600 "checksum mismatch for applied migration {}",
601 migration.description
602 ))?;
603 }
604 }
605 None => {
606 let elapsed = conn.apply(&migration).await?;
607 new_migrations.push((migration, elapsed));
608 }
609 }
610 }
611
612 Ok(new_migrations)
613 }
614
615 pub fn fuzzy_like_string(string: &str) -> String {
616 let mut result = String::with_capacity(string.len() * 2 + 1);
617 for c in string.chars() {
618 if c.is_alphanumeric() {
619 result.push('%');
620 result.push(c);
621 }
622 }
623 result.push('%');
624 result
625 }
626
627 // users
628
629 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
630 test_support!(self, {
631 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
632 Ok(sqlx::query_as(query)
633 .bind(limit as i32)
634 .bind((page * limit) as i32)
635 .fetch_all(&self.pool)
636 .await?)
637 })
638 }
639
640 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
641 test_support!(self, {
642 let query = "
643 SELECT users.*
644 FROM users
645 WHERE id = $1
646 LIMIT 1
647 ";
648 Ok(sqlx::query_as(query)
649 .bind(&id)
650 .fetch_optional(&self.pool)
651 .await?)
652 })
653 }
654
655 pub async fn get_users_with_no_invites(
656 &self,
657 invited_by_another_user: bool,
658 ) -> Result<Vec<User>> {
659 test_support!(self, {
660 let query = format!(
661 "
662 SELECT users.*
663 FROM users
664 WHERE invite_count = 0
665 AND inviter_id IS{} NULL
666 ",
667 if invited_by_another_user { " NOT" } else { "" }
668 );
669
670 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
671 })
672 }
673
674 pub async fn get_user_by_github_account(
675 &self,
676 github_login: &str,
677 github_user_id: Option<i32>,
678 ) -> Result<Option<User>> {
679 test_support!(self, {
680 if let Some(github_user_id) = github_user_id {
681 let mut user = sqlx::query_as::<_, User>(
682 "
683 UPDATE users
684 SET github_login = $1
685 WHERE github_user_id = $2
686 RETURNING *
687 ",
688 )
689 .bind(github_login)
690 .bind(github_user_id)
691 .fetch_optional(&self.pool)
692 .await?;
693
694 if user.is_none() {
695 user = sqlx::query_as::<_, User>(
696 "
697 UPDATE users
698 SET github_user_id = $1
699 WHERE github_login = $2
700 RETURNING *
701 ",
702 )
703 .bind(github_user_id)
704 .bind(github_login)
705 .fetch_optional(&self.pool)
706 .await?;
707 }
708
709 Ok(user)
710 } else {
711 let user = sqlx::query_as(
712 "
713 SELECT * FROM users
714 WHERE github_login = $1
715 LIMIT 1
716 ",
717 )
718 .bind(github_login)
719 .fetch_optional(&self.pool)
720 .await?;
721 Ok(user)
722 }
723 })
724 }
725
726 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
727 test_support!(self, {
728 let query = "UPDATE users SET admin = $1 WHERE id = $2";
729 Ok(sqlx::query(query)
730 .bind(is_admin)
731 .bind(id.0)
732 .execute(&self.pool)
733 .await
734 .map(drop)?)
735 })
736 }
737
738 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
739 test_support!(self, {
740 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
741 Ok(sqlx::query(query)
742 .bind(connected_once)
743 .bind(id.0)
744 .execute(&self.pool)
745 .await
746 .map(drop)?)
747 })
748 }
749
750 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
751 test_support!(self, {
752 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
753 sqlx::query(query)
754 .bind(id.0)
755 .execute(&self.pool)
756 .await
757 .map(drop)?;
758 let query = "DELETE FROM users WHERE id = $1;";
759 Ok(sqlx::query(query)
760 .bind(id.0)
761 .execute(&self.pool)
762 .await
763 .map(drop)?)
764 })
765 }
766
767 // signups
768
769 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
770 test_support!(self, {
771 Ok(sqlx::query_as(
772 "
773 SELECT
774 COUNT(*) as count,
775 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
776 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
777 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
778 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
779 FROM (
780 SELECT *
781 FROM signups
782 WHERE
783 NOT email_confirmation_sent
784 ) AS unsent
785 ",
786 )
787 .fetch_one(&self.pool)
788 .await?)
789 })
790 }
791
792 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
793 test_support!(self, {
794 Ok(sqlx::query_as(
795 "
796 SELECT
797 email_address, email_confirmation_code
798 FROM signups
799 WHERE
800 NOT email_confirmation_sent AND
801 (platform_mac OR platform_unknown)
802 LIMIT $1
803 ",
804 )
805 .bind(count as i32)
806 .fetch_all(&self.pool)
807 .await?)
808 })
809 }
810
811 // invite codes
812
813 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
814 test_support!(self, {
815 let mut tx = self.pool.begin().await?;
816 if count > 0 {
817 sqlx::query(
818 "
819 UPDATE users
820 SET invite_code = $1
821 WHERE id = $2 AND invite_code IS NULL
822 ",
823 )
824 .bind(random_invite_code())
825 .bind(id)
826 .execute(&mut tx)
827 .await?;
828 }
829
830 sqlx::query(
831 "
832 UPDATE users
833 SET invite_count = $1
834 WHERE id = $2
835 ",
836 )
837 .bind(count as i32)
838 .bind(id)
839 .execute(&mut tx)
840 .await?;
841 tx.commit().await?;
842 Ok(())
843 })
844 }
845
846 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
847 test_support!(self, {
848 let result: Option<(String, i32)> = sqlx::query_as(
849 "
850 SELECT invite_code, invite_count
851 FROM users
852 WHERE id = $1 AND invite_code IS NOT NULL
853 ",
854 )
855 .bind(id)
856 .fetch_optional(&self.pool)
857 .await?;
858 if let Some((code, count)) = result {
859 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
860 } else {
861 Ok(None)
862 }
863 })
864 }
865
866 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
867 test_support!(self, {
868 sqlx::query_as(
869 "
870 SELECT *
871 FROM users
872 WHERE invite_code = $1
873 ",
874 )
875 .bind(code)
876 .fetch_optional(&self.pool)
877 .await?
878 .ok_or_else(|| {
879 Error::Http(
880 StatusCode::NOT_FOUND,
881 "that invite code does not exist".to_string(),
882 )
883 })
884 })
885 }
886
887 pub async fn create_room(
888 &self,
889 user_id: UserId,
890 connection_id: ConnectionId,
891 ) -> Result<proto::Room> {
892 test_support!(self, {
893 let mut tx = self.pool.begin().await?;
894 let live_kit_room = nanoid::nanoid!(30);
895 let room_id = sqlx::query_scalar(
896 "
897 INSERT INTO rooms (live_kit_room, version)
898 VALUES ($1, $2)
899 RETURNING id
900 ",
901 )
902 .bind(&live_kit_room)
903 .bind(0)
904 .fetch_one(&mut tx)
905 .await
906 .map(RoomId)?;
907
908 sqlx::query(
909 "
910 INSERT INTO room_participants (room_id, user_id, connection_id)
911 VALUES ($1, $2, $3)
912 ",
913 )
914 .bind(room_id)
915 .bind(user_id)
916 .bind(connection_id.0 as i32)
917 .execute(&mut tx)
918 .await?;
919
920 sqlx::query(
921 "
922 INSERT INTO calls (room_id, calling_user_id, called_user_id, answering_connection_id)
923 VALUES ($1, $2, $3, $4)
924 ",
925 )
926 .bind(room_id)
927 .bind(user_id)
928 .bind(user_id)
929 .bind(connection_id.0 as i32)
930 .execute(&mut tx)
931 .await?;
932
933 self.commit_room_transaction(room_id, tx).await
934 })
935 }
936
937 pub async fn call(
938 &self,
939 room_id: RoomId,
940 calling_user_id: UserId,
941 called_user_id: UserId,
942 initial_project_id: Option<ProjectId>,
943 ) -> Result<proto::Room> {
944 test_support!(self, {
945 let mut tx = self.pool.begin().await?;
946 sqlx::query(
947 "
948 INSERT INTO calls (room_id, calling_user_id, called_user_id, initial_project_id)
949 VALUES ($1, $2, $3, $4)
950 ",
951 )
952 .bind(room_id)
953 .bind(calling_user_id)
954 .bind(called_user_id)
955 .bind(initial_project_id)
956 .execute(&mut tx)
957 .await?;
958
959 sqlx::query(
960 "
961 INSERT INTO room_participants (room_id, user_id)
962 VALUES ($1, $2)
963 ",
964 )
965 .bind(room_id)
966 .bind(called_user_id)
967 .execute(&mut tx)
968 .await?;
969
970 self.commit_room_transaction(room_id, tx).await
971 })
972 }
973
974 pub async fn call_failed(
975 &self,
976 room_id: RoomId,
977 called_user_id: UserId,
978 ) -> Result<proto::Room> {
979 test_support!(self, {
980 let mut tx = self.pool.begin().await?;
981 sqlx::query(
982 "
983 DELETE FROM calls
984 WHERE room_id = $1 AND called_user_id = $2
985 ",
986 )
987 .bind(room_id)
988 .bind(called_user_id)
989 .execute(&mut tx)
990 .await?;
991
992 sqlx::query(
993 "
994 DELETE FROM room_participants
995 WHERE room_id = $1 AND user_id = $2
996 ",
997 )
998 .bind(room_id)
999 .bind(called_user_id)
1000 .execute(&mut tx)
1001 .await?;
1002
1003 self.commit_room_transaction(room_id, tx).await
1004 })
1005 }
1006
1007 pub async fn update_room_participant_location(
1008 &self,
1009 room_id: RoomId,
1010 user_id: UserId,
1011 location: proto::ParticipantLocation,
1012 ) -> Result<proto::Room> {
1013 test_support!(self, {
1014 let mut tx = self.pool.begin().await?;
1015
1016 let location_kind;
1017 let location_project_id;
1018 match location
1019 .variant
1020 .ok_or_else(|| anyhow!("invalid location"))?
1021 {
1022 proto::participant_location::Variant::SharedProject(project) => {
1023 location_kind = 0;
1024 location_project_id = Some(ProjectId::from_proto(project.id));
1025 }
1026 proto::participant_location::Variant::UnsharedProject(_) => {
1027 location_kind = 1;
1028 location_project_id = None;
1029 }
1030 proto::participant_location::Variant::External(_) => {
1031 location_kind = 2;
1032 location_project_id = None;
1033 }
1034 }
1035
1036 sqlx::query(
1037 "
1038 UPDATE room_participants
1039 SET location_kind = $1 AND location_project_id = $2
1040 WHERE room_id = $1 AND user_id = $2
1041 ",
1042 )
1043 .bind(location_kind)
1044 .bind(location_project_id)
1045 .bind(room_id)
1046 .bind(user_id)
1047 .execute(&mut tx)
1048 .await?;
1049
1050 self.commit_room_transaction(room_id, tx).await
1051 })
1052 }
1053
1054 async fn commit_room_transaction(
1055 &self,
1056 room_id: RoomId,
1057 mut tx: sqlx::Transaction<'_, D>,
1058 ) -> Result<proto::Room> {
1059 sqlx::query(
1060 "
1061 UPDATE rooms
1062 SET version = version + 1
1063 WHERE id = $1
1064 ",
1065 )
1066 .bind(room_id)
1067 .execute(&mut tx)
1068 .await?;
1069
1070 let room: Room = sqlx::query_as(
1071 "
1072 SELECT *
1073 FROM rooms
1074 WHERE id = $1
1075 ",
1076 )
1077 .bind(room_id)
1078 .fetch_one(&mut tx)
1079 .await?;
1080
1081 let mut db_participants =
1082 sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<i32>)>(
1083 "
1084 SELECT user_id, connection_id, location_kind, location_project_id
1085 FROM room_participants
1086 WHERE room_id = $1
1087 ",
1088 )
1089 .bind(room_id)
1090 .fetch(&mut tx);
1091
1092 let mut participants = Vec::new();
1093 let mut pending_participant_user_ids = Vec::new();
1094 while let Some(participant) = db_participants.next().await {
1095 let (user_id, connection_id, _location_kind, _location_project_id) = participant?;
1096 if let Some(connection_id) = connection_id {
1097 participants.push(proto::Participant {
1098 user_id: user_id.to_proto(),
1099 peer_id: connection_id as u32,
1100 projects: Default::default(),
1101 location: Some(proto::ParticipantLocation {
1102 variant: Some(proto::participant_location::Variant::External(
1103 Default::default(),
1104 )),
1105 }),
1106 });
1107 } else {
1108 pending_participant_user_ids.push(user_id.to_proto());
1109 }
1110 }
1111 drop(db_participants);
1112
1113 for participant in &mut participants {
1114 let mut entries = sqlx::query_as::<_, (ProjectId, String)>(
1115 "
1116 SELECT projects.id, worktrees.root_name
1117 FROM projects
1118 LEFT JOIN worktrees ON projects.id = worktrees.project_id
1119 WHERE room_id = $1 AND host_user_id = $2
1120 ",
1121 )
1122 .bind(room_id)
1123 .fetch(&mut tx);
1124
1125 let mut projects = HashMap::default();
1126 while let Some(entry) = entries.next().await {
1127 let (project_id, worktree_root_name) = entry?;
1128 let participant_project =
1129 projects
1130 .entry(project_id)
1131 .or_insert(proto::ParticipantProject {
1132 id: project_id.to_proto(),
1133 worktree_root_names: Default::default(),
1134 });
1135 participant_project
1136 .worktree_root_names
1137 .push(worktree_root_name);
1138 }
1139
1140 participant.projects = projects.into_values().collect();
1141 }
1142
1143 tx.commit().await?;
1144
1145 Ok(proto::Room {
1146 id: room.id.to_proto(),
1147 version: room.version as u64,
1148 live_kit_room: room.live_kit_room,
1149 participants,
1150 pending_participant_user_ids,
1151 })
1152 }
1153
1154 // projects
1155
1156 pub async fn share_project(
1157 &self,
1158 user_id: UserId,
1159 connection_id: ConnectionId,
1160 room_id: RoomId,
1161 worktrees: &[proto::WorktreeMetadata],
1162 ) -> Result<(ProjectId, proto::Room)> {
1163 test_support!(self, {
1164 let mut tx = self.pool.begin().await?;
1165 let project_id = sqlx::query_scalar(
1166 "
1167 INSERT INTO projects (host_user_id, room_id)
1168 VALUES ($1)
1169 RETURNING id
1170 ",
1171 )
1172 .bind(user_id)
1173 .bind(room_id)
1174 .fetch_one(&mut tx)
1175 .await
1176 .map(ProjectId)?;
1177
1178 for worktree in worktrees {
1179 sqlx::query(
1180 "
1181 INSERT INTO worktrees (id, project_id, root_name)
1182 ",
1183 )
1184 .bind(worktree.id as i32)
1185 .bind(project_id)
1186 .bind(&worktree.root_name)
1187 .execute(&mut tx)
1188 .await?;
1189 }
1190
1191 sqlx::query(
1192 "
1193 INSERT INTO project_collaborators (
1194 project_id,
1195 connection_id,
1196 user_id,
1197 replica_id,
1198 is_host
1199 )
1200 VALUES ($1, $2, $3, $4, $5)
1201 ",
1202 )
1203 .bind(project_id)
1204 .bind(connection_id.0 as i32)
1205 .bind(user_id)
1206 .bind(0)
1207 .bind(true)
1208 .execute(&mut tx)
1209 .await?;
1210
1211 let room = self.commit_room_transaction(room_id, tx).await?;
1212 Ok((project_id, room))
1213 })
1214 }
1215
    /// Not implemented yet: calling this panics via `todo!()`.
    ///
    /// The commented-out body below is a draft that marks the project row as
    /// `unregistered`; it is currently disabled.
    pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> {
        todo!()
        // test_support!(self, {
        //     sqlx::query(
        //         "
        //         UPDATE projects
        //         SET unregistered = TRUE
        //         WHERE id = $1
        //         ",
        //     )
        //     .bind(project_id)
        //     .execute(&self.pool)
        //     .await?;
        //     Ok(())
        // })
    }
1232
1233 // contacts
1234
1235 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1236 test_support!(self, {
1237 let query = "
1238 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1239 FROM contacts
1240 WHERE user_id_a = $1 OR user_id_b = $1;
1241 ";
1242
1243 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1244 .bind(user_id)
1245 .fetch(&self.pool);
1246
1247 let mut contacts = Vec::new();
1248 while let Some(row) = rows.next().await {
1249 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1250
1251 if user_id_a == user_id {
1252 if accepted {
1253 contacts.push(Contact::Accepted {
1254 user_id: user_id_b,
1255 should_notify: should_notify && a_to_b,
1256 });
1257 } else if a_to_b {
1258 contacts.push(Contact::Outgoing { user_id: user_id_b })
1259 } else {
1260 contacts.push(Contact::Incoming {
1261 user_id: user_id_b,
1262 should_notify,
1263 });
1264 }
1265 } else if accepted {
1266 contacts.push(Contact::Accepted {
1267 user_id: user_id_a,
1268 should_notify: should_notify && !a_to_b,
1269 });
1270 } else if a_to_b {
1271 contacts.push(Contact::Incoming {
1272 user_id: user_id_a,
1273 should_notify,
1274 });
1275 } else {
1276 contacts.push(Contact::Outgoing { user_id: user_id_a });
1277 }
1278 }
1279
1280 contacts.sort_unstable_by_key(|contact| contact.user_id());
1281
1282 Ok(contacts)
1283 })
1284 }
1285
1286 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1287 test_support!(self, {
1288 let (id_a, id_b) = if user_id_1 < user_id_2 {
1289 (user_id_1, user_id_2)
1290 } else {
1291 (user_id_2, user_id_1)
1292 };
1293
1294 let query = "
1295 SELECT 1 FROM contacts
1296 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
1297 LIMIT 1
1298 ";
1299 Ok(sqlx::query_scalar::<_, i32>(query)
1300 .bind(id_a.0)
1301 .bind(id_b.0)
1302 .fetch_optional(&self.pool)
1303 .await?
1304 .is_some())
1305 })
1306 }
1307
1308 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1309 test_support!(self, {
1310 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1311 (sender_id, receiver_id, true)
1312 } else {
1313 (receiver_id, sender_id, false)
1314 };
1315 let query = "
1316 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1317 VALUES ($1, $2, $3, FALSE, TRUE)
1318 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1319 SET
1320 accepted = TRUE,
1321 should_notify = FALSE
1322 WHERE
1323 NOT contacts.accepted AND
1324 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1325 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1326 ";
1327 let result = sqlx::query(query)
1328 .bind(id_a.0)
1329 .bind(id_b.0)
1330 .bind(a_to_b)
1331 .execute(&self.pool)
1332 .await?;
1333
1334 if result.rows_affected() == 1 {
1335 Ok(())
1336 } else {
1337 Err(anyhow!("contact already requested"))?
1338 }
1339 })
1340 }
1341
1342 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1343 test_support!(self, {
1344 let (id_a, id_b) = if responder_id < requester_id {
1345 (responder_id, requester_id)
1346 } else {
1347 (requester_id, responder_id)
1348 };
1349 let query = "
1350 DELETE FROM contacts
1351 WHERE user_id_a = $1 AND user_id_b = $2;
1352 ";
1353 let result = sqlx::query(query)
1354 .bind(id_a.0)
1355 .bind(id_b.0)
1356 .execute(&self.pool)
1357 .await?;
1358
1359 if result.rows_affected() == 1 {
1360 Ok(())
1361 } else {
1362 Err(anyhow!("no such contact"))?
1363 }
1364 })
1365 }
1366
1367 pub async fn dismiss_contact_notification(
1368 &self,
1369 user_id: UserId,
1370 contact_user_id: UserId,
1371 ) -> Result<()> {
1372 test_support!(self, {
1373 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1374 (user_id, contact_user_id, true)
1375 } else {
1376 (contact_user_id, user_id, false)
1377 };
1378
1379 let query = "
1380 UPDATE contacts
1381 SET should_notify = FALSE
1382 WHERE
1383 user_id_a = $1 AND user_id_b = $2 AND
1384 (
1385 (a_to_b = $3 AND accepted) OR
1386 (a_to_b != $3 AND NOT accepted)
1387 );
1388 ";
1389
1390 let result = sqlx::query(query)
1391 .bind(id_a.0)
1392 .bind(id_b.0)
1393 .bind(a_to_b)
1394 .execute(&self.pool)
1395 .await?;
1396
1397 if result.rows_affected() == 0 {
1398 Err(anyhow!("no such contact request"))?;
1399 }
1400
1401 Ok(())
1402 })
1403 }
1404
1405 pub async fn respond_to_contact_request(
1406 &self,
1407 responder_id: UserId,
1408 requester_id: UserId,
1409 accept: bool,
1410 ) -> Result<()> {
1411 test_support!(self, {
1412 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1413 (responder_id, requester_id, false)
1414 } else {
1415 (requester_id, responder_id, true)
1416 };
1417 let result = if accept {
1418 let query = "
1419 UPDATE contacts
1420 SET accepted = TRUE, should_notify = TRUE
1421 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1422 ";
1423 sqlx::query(query)
1424 .bind(id_a.0)
1425 .bind(id_b.0)
1426 .bind(a_to_b)
1427 .execute(&self.pool)
1428 .await?
1429 } else {
1430 let query = "
1431 DELETE FROM contacts
1432 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1433 ";
1434 sqlx::query(query)
1435 .bind(id_a.0)
1436 .bind(id_b.0)
1437 .bind(a_to_b)
1438 .execute(&self.pool)
1439 .await?
1440 };
1441 if result.rows_affected() == 1 {
1442 Ok(())
1443 } else {
1444 Err(anyhow!("no such contact request"))?
1445 }
1446 })
1447 }
1448
1449 // access tokens
1450
1451 pub async fn create_access_token_hash(
1452 &self,
1453 user_id: UserId,
1454 access_token_hash: &str,
1455 max_access_token_count: usize,
1456 ) -> Result<()> {
1457 test_support!(self, {
1458 let insert_query = "
1459 INSERT INTO access_tokens (user_id, hash)
1460 VALUES ($1, $2);
1461 ";
1462 let cleanup_query = "
1463 DELETE FROM access_tokens
1464 WHERE id IN (
1465 SELECT id from access_tokens
1466 WHERE user_id = $1
1467 ORDER BY id DESC
1468 LIMIT 10000
1469 OFFSET $3
1470 )
1471 ";
1472
1473 let mut tx = self.pool.begin().await?;
1474 sqlx::query(insert_query)
1475 .bind(user_id.0)
1476 .bind(access_token_hash)
1477 .execute(&mut tx)
1478 .await?;
1479 sqlx::query(cleanup_query)
1480 .bind(user_id.0)
1481 .bind(access_token_hash)
1482 .bind(max_access_token_count as i32)
1483 .execute(&mut tx)
1484 .await?;
1485 Ok(tx.commit().await?)
1486 })
1487 }
1488
1489 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1490 test_support!(self, {
1491 let query = "
1492 SELECT hash
1493 FROM access_tokens
1494 WHERE user_id = $1
1495 ORDER BY id DESC
1496 ";
1497 Ok(sqlx::query_scalar(query)
1498 .bind(user_id.0)
1499 .fetch_all(&self.pool)
1500 .await?)
1501 })
1502 }
1503}
1504
/// Declares a public newtype named `$name` that wraps an `i32` identifier.
///
/// The generated type derives the usual value-type traits (copy, comparison, hashing,
/// default) and is transparent to both sqlx and serde, so it is encoded exactly like the
/// inner `i32`. It also gets a `MAX` constant, `from_proto`/`to_proto` conversions to
/// and from `u64`, and a `Display` impl that forwards to the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id (wraps `i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Builds an id from its `u64` wire representation.
            // NOTE(review): `as i32` truncates, so values above `i32::MAX` do not
            // round-trip.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Converts the id to its `u64` wire representation.
            // NOTE(review): `as u64` sign-extends, so a negative id becomes a very
            // large `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            // Formats exactly like the wrapped integer, honouring formatter flags.
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1547
// Identifier newtype for users.
id_type!(UserId);
/// A user record that can be loaded from a query row (`FromRow`) and serialized.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    /// Identifier of this user.
    pub id: UserId,
    /// GitHub login name of the user.
    pub github_login: String,
    /// Numeric GitHub user id, when one is recorded.
    pub github_user_id: Option<i32>,
    /// Email address, when one is recorded.
    pub email_address: Option<String>,
    /// Admin flag (presumably grants administrative access — confirm against callers).
    pub admin: bool,
    /// Invite code associated with the user, if any.
    pub invite_code: Option<String>,
    /// Invite counter (presumably the number of invites the user may still send — TODO confirm).
    pub invite_count: i32,
    /// Presumably set once the user has connected at least one time — TODO confirm.
    pub connected_once: bool,
}
1560
// Identifier newtype for rooms.
id_type!(RoomId);
/// A room record that can be loaded from a query row (`FromRow`) and serialized.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Room {
    /// Identifier of this room.
    pub id: RoomId,
    /// Version number of the room record (semantics not visible here — TODO confirm).
    pub version: i32,
    /// Name of the associated LiveKit room (presumably — confirm against callers).
    pub live_kit_room: String,
}
1568
// Identifier newtype for projects.
id_type!(ProjectId);
/// A project record that can be loaded from a query row (`FromRow`) and serialized.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    /// Identifier of this project.
    pub id: ProjectId,
    /// Id of the user hosting the project.
    pub host_user_id: UserId,
    /// Presumably true once the project has been unregistered — TODO confirm.
    pub unregistered: bool,
}
1576
/// A contact relationship as seen from one user's perspective.
///
/// Every variant carries the `user_id` of the contact; `Contact::user_id` returns it
/// regardless of the variant.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// A contact whose request has been accepted.
    Accepted {
        user_id: UserId,
        /// Whether a notification about this contact is still pending.
        should_notify: bool,
    },
    /// A pending request directed at `user_id` (presumably sent by the viewing user).
    Outgoing {
        user_id: UserId,
    },
    /// A pending request coming from `user_id`.
    Incoming {
        user_id: UserId,
        /// Whether a notification about this request is still pending.
        should_notify: bool,
    },
}
1591
1592impl Contact {
1593 pub fn user_id(&self) -> UserId {
1594 match self {
1595 Contact::Accepted { user_id, .. } => *user_id,
1596 Contact::Outgoing { user_id } => *user_id,
1597 Contact::Incoming { user_id, .. } => *user_id,
1598 }
1599 }
1600}
1601
/// An incoming contact request: who sent it and whether a notification is pending.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// Id of the user who sent the request.
    pub requester_id: UserId,
    /// Whether a notification about this request is still pending.
    pub should_notify: bool,
}
1607
/// Deserializable signup payload.
#[derive(Clone, Deserialize)]
pub struct Signup {
    /// Email address supplied with the signup.
    pub email_address: String,
    /// Platform flag for macOS.
    pub platform_mac: bool,
    /// Platform flag for Windows.
    pub platform_windows: bool,
    /// Platform flag for Linux.
    pub platform_linux: bool,
    /// Editor features listed in the signup (free-form strings).
    pub editor_features: Vec<String>,
    /// Programming languages listed in the signup (free-form strings).
    pub programming_languages: Vec<String>,
    /// Optional device identifier supplied with the signup.
    pub device_id: Option<String>,
}
1618
/// Aggregate counts describing the waitlist, loadable from a single query row.
///
/// Every field is marked `#[sqlx(default)]`, so a column missing from the row falls
/// back to the field type's default value instead of failing the row conversion.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    /// Overall count.
    #[sqlx(default)]
    pub count: i64,
    /// Count attributed to Linux.
    #[sqlx(default)]
    pub linux_count: i64,
    /// Count attributed to macOS.
    #[sqlx(default)]
    pub mac_count: i64,
    /// Count attributed to Windows.
    #[sqlx(default)]
    pub windows_count: i64,
    /// Count with no known platform.
    #[sqlx(default)]
    pub unknown_count: i64,
}
1632
/// An invite: an email address paired with its email confirmation code.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    /// Email address the invite belongs to.
    pub email_address: String,
    /// Confirmation code tied to `email_address`.
    pub email_confirmation_code: String,
}
1638
/// Parameters describing a user to be created.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    /// GitHub login name of the new user.
    pub github_login: String,
    /// Numeric GitHub user id of the new user.
    pub github_user_id: i32,
    /// Initial invite counter for the new user.
    pub invite_count: i32,
}
1645
/// Outcome of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    /// Id of the user.
    pub user_id: UserId,
    /// Metrics identifier associated with the user.
    pub metrics_id: String,
    /// Id of the user who invited this one, if any.
    pub inviting_user_id: Option<UserId>,
    /// Device id recorded at signup, if any.
    pub signup_device_id: Option<String>,
}
1653
1654fn random_invite_code() -> String {
1655 nanoid::nanoid!(16)
1656}
1657
1658fn random_email_confirmation_code() -> String {
1659 nanoid::nanoid!(64)
1660}
1661
1662#[cfg(test)]
1663pub use test::*;
1664
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// Test fixture wrapping a migrated, in-memory SQLite [`Db`].
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// Connection detached from the pool during construction.
        /// NOTE(review): presumably held so the in-memory database stays alive — confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// Test fixture wrapping a freshly created and migrated Postgres [`Db`].
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        /// URL of the database created for this fixture; passed to `teardown` on drop.
        pub url: String,
    }

    /// Builds the current-thread Tokio runtime (I/O and timers enabled) that drives the
    /// database futures of a test fixture.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    impl SqliteTestDb {
        /// Creates a uniquely named in-memory SQLite database, runs the SQLite
        /// migrations on it, and attaches `background` and the runtime to the [`Db`].
        ///
        /// Panics if the database cannot be opened or migrated.
        pub fn new(background: Arc<Background>) -> Self {
            let url = format!(
                "file:zed-test-{}?mode=memory",
                StdRng::from_entropy().gen::<u128>()
            );
            let runtime = current_thread_runtime();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(migrations_path.as_ref(), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            let db = Some(Arc::new(db));
            Self { db, conn }
        }

        /// Returns the wrapped database handle; panics if it has been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates a uniquely named Postgres database on localhost, runs the migrations
        /// on it, and attaches `background` and the runtime to the [`Db`].
        ///
        /// A process-wide lock is held for the whole call, so fixtures are constructed
        /// one at a time. Panics if the database cannot be created or migrated.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                StdRng::from_entropy().gen::<u128>()
            );
            let runtime = current_thread_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            let db = Some(Arc::new(db));
            Self { db, url }
        }

        /// Returns the wrapped database handle; panics if it has been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        /// Takes the database handle out of the fixture and tears the database down.
        fn drop(&mut self) {
            self.db.take().unwrap().teardown(&self.url);
        }
    }
}