1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::HashMap;
5use futures::StreamExt;
6use rpc::{proto, ConnectionId};
7use serde::{Deserialize, Serialize};
8use sqlx::{
9 migrate::{Migrate as _, Migration, MigrationSource},
10 types::Uuid,
11 FromRow,
12};
13use std::{path::Path, time::Duration};
14use time::{OffsetDateTime, PrimitiveDateTime};
15
/// Database handle type used by the rest of the crate.
///
/// Unit-test builds (`cfg(test)`) use the SQLite backend.
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// Database handle type used by the rest of the crate.
///
/// Non-test builds use the Postgres backend.
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;
21
/// A connection pool for database backend `D`, plus hooks used only by tests.
pub struct Db<D: sqlx::Database> {
    /// Pool that every query in this module executes against.
    pool: sqlx::Pool<D>,
    /// Test-only executor. When set, `test_support!` awaits
    /// `simulate_random_delay()` on it before running each query body.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only Tokio runtime that `test_support!` uses to `block_on` query
    /// bodies. The constructors leave it `None`, and `test_support!` unwraps it,
    /// so test setup must populate it before queries run.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
29
/// Runs a query body so that it works in both test and non-test builds.
///
/// `$self` must be a `Db`, and the braced tokens are wrapped in an `async` block:
/// - Non-test builds simply `.await` the block.
/// - Test builds first await `simulate_random_delay()` on `$self.background`
///   (if it is set). They then drive the block to completion with
///   `$self.runtime.as_ref().unwrap().block_on(..)`, which panics if `runtime`
///   is `None`.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };

        if cfg!(test) {
            // `cfg!(test)` is an ordinary boolean expression, so both branches are
            // compiled in every build. The `#[cfg]` attributes below strip the
            // statements that touch test-only fields out of non-test builds.
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(body)
        } else {
            body.await
        }
    }};
}
52
53pub trait RowsAffected {
54 fn rows_affected(&self) -> u64;
55}
56
57#[cfg(test)]
58impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
59 fn rows_affected(&self) -> u64 {
60 self.rows_affected()
61 }
62}
63
64impl RowsAffected for sqlx::postgres::PgQueryResult {
65 fn rows_affected(&self) -> u64 {
66 self.rows_affected()
67 }
68}
69
#[cfg(test)]
impl Db<sqlx::Sqlite> {
    /// Opens a SQLite-backed `Db` for tests.
    ///
    /// The database at `url` is created if missing and opened with a shared
    /// cache. The pool keeps between 2 and `max_connections` connections.
    ///
    /// # Errors
    /// Returns an error if `url` is not a valid SQLite connection string or the
    /// pool cannot connect. A malformed URL used to panic here.
    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
        use std::str::FromStr as _;
        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
            .create_if_missing(true)
            .shared_cache(true);
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(2)
            .max_connections(max_connections)
            .connect_with(options)
            .await?;
        Ok(Self {
            pool,
            background: None,
            runtime: None,
        })
    }

    /// Returns the users whose ids are in `ids`.
    ///
    /// SQLite has no array binding, so the ids are passed as a JSON array and
    /// expanded with `json_each`.
    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
        test_support!(self, {
            let query = "
                SELECT users.*
                FROM users
                WHERE users.id IN (SELECT value from json_each($1))
            ";
            Ok(sqlx::query_as(query)
                .bind(&serde_json::json!(ids))
                .fetch_all(&self.pool)
                .await?)
        })
    }

    /// Returns the stored `metrics_id` of user `id`. Errors if no such user exists.
    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
        test_support!(self, {
            let query = "
                SELECT metrics_id
                FROM users
                WHERE id = $1
            ";
            Ok(sqlx::query_scalar(query)
                .bind(id)
                .fetch_one(&self.pool)
                .await?)
        })
    }

    /// Inserts a user and returns its id and metrics id.
    ///
    /// The metrics id is generated here as a random UUID. The Postgres variant
    /// instead reads back a database-produced value. If `github_login` already
    /// exists, the existing row's id and metrics id are returned.
    pub async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult> {
        test_support!(self, {
            let query = "
                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id, metrics_id
            ";

            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
                .bind(email_address)
                .bind(params.github_login)
                .bind(params.github_user_id)
                .bind(admin)
                .bind(Uuid::new_v4().to_string())
                .fetch_one(&self.pool)
                .await?;
            Ok(NewUserResult {
                user_id,
                metrics_id,
                signup_device_id: None,
                inviting_user_id: None,
            })
        })
    }

    /// Not supported by the SQLite test backend; panics via `unimplemented!()`.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics via `unimplemented!()`.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics via `unimplemented!()`.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics via `unimplemented!()`.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics via `unimplemented!()`.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
}
178
179impl Db<sqlx::Postgres> {
180 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
181 let pool = sqlx::postgres::PgPoolOptions::new()
182 .max_connections(max_connections)
183 .connect(url)
184 .await?;
185 Ok(Self {
186 pool,
187 #[cfg(test)]
188 background: None,
189 #[cfg(test)]
190 runtime: None,
191 })
192 }
193
194 #[cfg(test)]
195 pub fn teardown(&self, url: &str) {
196 self.runtime.as_ref().unwrap().block_on(async {
197 use util::ResultExt;
198 let query = "
199 SELECT pg_terminate_backend(pg_stat_activity.pid)
200 FROM pg_stat_activity
201 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
202 ";
203 sqlx::query(query).execute(&self.pool).await.log_err();
204 self.pool.close().await;
205 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
206 .await
207 .log_err();
208 })
209 }
210
211 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
212 test_support!(self, {
213 let like_string = Self::fuzzy_like_string(name_query);
214 let query = "
215 SELECT users.*
216 FROM users
217 WHERE github_login ILIKE $1
218 ORDER BY github_login <-> $2
219 LIMIT $3
220 ";
221 Ok(sqlx::query_as(query)
222 .bind(like_string)
223 .bind(name_query)
224 .bind(limit as i32)
225 .fetch_all(&self.pool)
226 .await?)
227 })
228 }
229
230 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
231 test_support!(self, {
232 let query = "
233 SELECT users.*
234 FROM users
235 WHERE users.id = ANY ($1)
236 ";
237 Ok(sqlx::query_as(query)
238 .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
239 .fetch_all(&self.pool)
240 .await?)
241 })
242 }
243
244 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
245 test_support!(self, {
246 let query = "
247 SELECT metrics_id::text
248 FROM users
249 WHERE id = $1
250 ";
251 Ok(sqlx::query_scalar(query)
252 .bind(id)
253 .fetch_one(&self.pool)
254 .await?)
255 })
256 }
257
258 pub async fn create_user(
259 &self,
260 email_address: &str,
261 admin: bool,
262 params: NewUserParams,
263 ) -> Result<NewUserResult> {
264 test_support!(self, {
265 let query = "
266 INSERT INTO users (email_address, github_login, github_user_id, admin)
267 VALUES ($1, $2, $3, $4)
268 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
269 RETURNING id, metrics_id::text
270 ";
271
272 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
273 .bind(email_address)
274 .bind(params.github_login)
275 .bind(params.github_user_id)
276 .bind(admin)
277 .fetch_one(&self.pool)
278 .await?;
279 Ok(NewUserResult {
280 user_id,
281 metrics_id,
282 signup_device_id: None,
283 inviting_user_id: None,
284 })
285 })
286 }
287
288 pub async fn create_user_from_invite(
289 &self,
290 invite: &Invite,
291 user: NewUserParams,
292 ) -> Result<Option<NewUserResult>> {
293 test_support!(self, {
294 let mut tx = self.pool.begin().await?;
295
296 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
297 i32,
298 Option<UserId>,
299 Option<UserId>,
300 Option<String>,
301 ) = sqlx::query_as(
302 "
303 SELECT id, user_id, inviting_user_id, device_id
304 FROM signups
305 WHERE
306 email_address = $1 AND
307 email_confirmation_code = $2
308 ",
309 )
310 .bind(&invite.email_address)
311 .bind(&invite.email_confirmation_code)
312 .fetch_optional(&mut tx)
313 .await?
314 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
315
316 if existing_user_id.is_some() {
317 return Ok(None);
318 }
319
320 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
321 "
322 INSERT INTO users
323 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
324 VALUES
325 ($1, $2, $3, FALSE, $4, $5)
326 ON CONFLICT (github_login) DO UPDATE SET
327 email_address = excluded.email_address,
328 github_user_id = excluded.github_user_id,
329 admin = excluded.admin
330 RETURNING id, metrics_id::text
331 ",
332 )
333 .bind(&invite.email_address)
334 .bind(&user.github_login)
335 .bind(&user.github_user_id)
336 .bind(&user.invite_count)
337 .bind(random_invite_code())
338 .fetch_one(&mut tx)
339 .await?;
340
341 sqlx::query(
342 "
343 UPDATE signups
344 SET user_id = $1
345 WHERE id = $2
346 ",
347 )
348 .bind(&user_id)
349 .bind(&signup_id)
350 .execute(&mut tx)
351 .await?;
352
353 if let Some(inviting_user_id) = inviting_user_id {
354 let id: Option<UserId> = sqlx::query_scalar(
355 "
356 UPDATE users
357 SET invite_count = invite_count - 1
358 WHERE id = $1 AND invite_count > 0
359 RETURNING id
360 ",
361 )
362 .bind(&inviting_user_id)
363 .fetch_optional(&mut tx)
364 .await?;
365
366 if id.is_none() {
367 Err(Error::Http(
368 StatusCode::UNAUTHORIZED,
369 "no invites remaining".to_string(),
370 ))?;
371 }
372
373 sqlx::query(
374 "
375 INSERT INTO contacts
376 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
377 VALUES
378 ($1, $2, TRUE, TRUE, TRUE)
379 ON CONFLICT DO NOTHING
380 ",
381 )
382 .bind(inviting_user_id)
383 .bind(user_id)
384 .execute(&mut tx)
385 .await?;
386 }
387
388 tx.commit().await?;
389 Ok(Some(NewUserResult {
390 user_id,
391 metrics_id,
392 inviting_user_id,
393 signup_device_id,
394 }))
395 })
396 }
397
398 pub async fn create_signup(&self, signup: Signup) -> Result<()> {
399 test_support!(self, {
400 sqlx::query(
401 "
402 INSERT INTO signups
403 (
404 email_address,
405 email_confirmation_code,
406 email_confirmation_sent,
407 platform_linux,
408 platform_mac,
409 platform_windows,
410 platform_unknown,
411 editor_features,
412 programming_languages,
413 device_id
414 )
415 VALUES
416 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
417 RETURNING id
418 ",
419 )
420 .bind(&signup.email_address)
421 .bind(&random_email_confirmation_code())
422 .bind(&signup.platform_linux)
423 .bind(&signup.platform_mac)
424 .bind(&signup.platform_windows)
425 .bind(&signup.editor_features)
426 .bind(&signup.programming_languages)
427 .bind(&signup.device_id)
428 .execute(&self.pool)
429 .await?;
430 Ok(())
431 })
432 }
433
434 pub async fn create_invite_from_code(
435 &self,
436 code: &str,
437 email_address: &str,
438 device_id: Option<&str>,
439 ) -> Result<Invite> {
440 test_support!(self, {
441 let mut tx = self.pool.begin().await?;
442
443 let existing_user: Option<UserId> = sqlx::query_scalar(
444 "
445 SELECT id
446 FROM users
447 WHERE email_address = $1
448 ",
449 )
450 .bind(email_address)
451 .fetch_optional(&mut tx)
452 .await?;
453 if existing_user.is_some() {
454 Err(anyhow!("email address is already in use"))?;
455 }
456
457 let row: Option<(UserId, i32)> = sqlx::query_as(
458 "
459 SELECT id, invite_count
460 FROM users
461 WHERE invite_code = $1
462 ",
463 )
464 .bind(code)
465 .fetch_optional(&mut tx)
466 .await?;
467
468 let (inviter_id, invite_count) = match row {
469 Some(row) => row,
470 None => Err(Error::Http(
471 StatusCode::NOT_FOUND,
472 "invite code not found".to_string(),
473 ))?,
474 };
475
476 if invite_count == 0 {
477 Err(Error::Http(
478 StatusCode::UNAUTHORIZED,
479 "no invites remaining".to_string(),
480 ))?;
481 }
482
483 let email_confirmation_code: String = sqlx::query_scalar(
484 "
485 INSERT INTO signups
486 (
487 email_address,
488 email_confirmation_code,
489 email_confirmation_sent,
490 inviting_user_id,
491 platform_linux,
492 platform_mac,
493 platform_windows,
494 platform_unknown,
495 device_id
496 )
497 VALUES
498 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
499 ON CONFLICT (email_address)
500 DO UPDATE SET
501 inviting_user_id = excluded.inviting_user_id
502 RETURNING email_confirmation_code
503 ",
504 )
505 .bind(&email_address)
506 .bind(&random_email_confirmation_code())
507 .bind(&inviter_id)
508 .bind(&device_id)
509 .fetch_one(&mut tx)
510 .await?;
511
512 tx.commit().await?;
513
514 Ok(Invite {
515 email_address: email_address.into(),
516 email_confirmation_code,
517 })
518 })
519 }
520
521 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
522 test_support!(self, {
523 let emails = invites
524 .iter()
525 .map(|s| s.email_address.as_str())
526 .collect::<Vec<_>>();
527 sqlx::query(
528 "
529 UPDATE signups
530 SET email_confirmation_sent = TRUE
531 WHERE email_address = ANY ($1)
532 ",
533 )
534 .bind(&emails)
535 .execute(&self.pool)
536 .await?;
537 Ok(())
538 })
539 }
540}
541
542impl<D> Db<D>
543where
544 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
545 D::Connection: sqlx::migrate::Migrate,
546 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
547 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
548 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
549 D::QueryResult: RowsAffected,
550 String: sqlx::Type<D>,
551 i32: sqlx::Type<D>,
552 i64: sqlx::Type<D>,
553 bool: sqlx::Type<D>,
554 str: sqlx::Type<D>,
555 Uuid: sqlx::Type<D>,
556 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
557 OffsetDateTime: sqlx::Type<D>,
558 PrimitiveDateTime: sqlx::Type<D>,
559 usize: sqlx::ColumnIndex<D::Row>,
560 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
561 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
562 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
563 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
564 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
565 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
566 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
567 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
568 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
569 for<'a> Option<ProjectId>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
570 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
571 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
572 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
573{
574 pub async fn migrate(
575 &self,
576 migrations_path: &Path,
577 ignore_checksum_mismatch: bool,
578 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
579 let migrations = MigrationSource::resolve(migrations_path)
580 .await
581 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
582
583 let mut conn = self.pool.acquire().await?;
584
585 conn.ensure_migrations_table().await?;
586 let applied_migrations: HashMap<_, _> = conn
587 .list_applied_migrations()
588 .await?
589 .into_iter()
590 .map(|m| (m.version, m))
591 .collect();
592
593 let mut new_migrations = Vec::new();
594 for migration in migrations {
595 match applied_migrations.get(&migration.version) {
596 Some(applied_migration) => {
597 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
598 {
599 Err(anyhow!(
600 "checksum mismatch for applied migration {}",
601 migration.description
602 ))?;
603 }
604 }
605 None => {
606 let elapsed = conn.apply(&migration).await?;
607 new_migrations.push((migration, elapsed));
608 }
609 }
610 }
611
612 Ok(new_migrations)
613 }
614
615 pub fn fuzzy_like_string(string: &str) -> String {
616 let mut result = String::with_capacity(string.len() * 2 + 1);
617 for c in string.chars() {
618 if c.is_alphanumeric() {
619 result.push('%');
620 result.push(c);
621 }
622 }
623 result.push('%');
624 result
625 }
626
627 // users
628
629 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
630 test_support!(self, {
631 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
632 Ok(sqlx::query_as(query)
633 .bind(limit as i32)
634 .bind((page * limit) as i32)
635 .fetch_all(&self.pool)
636 .await?)
637 })
638 }
639
640 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
641 test_support!(self, {
642 let query = "
643 SELECT users.*
644 FROM users
645 WHERE id = $1
646 LIMIT 1
647 ";
648 Ok(sqlx::query_as(query)
649 .bind(&id)
650 .fetch_optional(&self.pool)
651 .await?)
652 })
653 }
654
655 pub async fn get_users_with_no_invites(
656 &self,
657 invited_by_another_user: bool,
658 ) -> Result<Vec<User>> {
659 test_support!(self, {
660 let query = format!(
661 "
662 SELECT users.*
663 FROM users
664 WHERE invite_count = 0
665 AND inviter_id IS{} NULL
666 ",
667 if invited_by_another_user { " NOT" } else { "" }
668 );
669
670 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
671 })
672 }
673
674 pub async fn get_user_by_github_account(
675 &self,
676 github_login: &str,
677 github_user_id: Option<i32>,
678 ) -> Result<Option<User>> {
679 test_support!(self, {
680 if let Some(github_user_id) = github_user_id {
681 let mut user = sqlx::query_as::<_, User>(
682 "
683 UPDATE users
684 SET github_login = $1
685 WHERE github_user_id = $2
686 RETURNING *
687 ",
688 )
689 .bind(github_login)
690 .bind(github_user_id)
691 .fetch_optional(&self.pool)
692 .await?;
693
694 if user.is_none() {
695 user = sqlx::query_as::<_, User>(
696 "
697 UPDATE users
698 SET github_user_id = $1
699 WHERE github_login = $2
700 RETURNING *
701 ",
702 )
703 .bind(github_user_id)
704 .bind(github_login)
705 .fetch_optional(&self.pool)
706 .await?;
707 }
708
709 Ok(user)
710 } else {
711 let user = sqlx::query_as(
712 "
713 SELECT * FROM users
714 WHERE github_login = $1
715 LIMIT 1
716 ",
717 )
718 .bind(github_login)
719 .fetch_optional(&self.pool)
720 .await?;
721 Ok(user)
722 }
723 })
724 }
725
726 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
727 test_support!(self, {
728 let query = "UPDATE users SET admin = $1 WHERE id = $2";
729 Ok(sqlx::query(query)
730 .bind(is_admin)
731 .bind(id.0)
732 .execute(&self.pool)
733 .await
734 .map(drop)?)
735 })
736 }
737
738 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
739 test_support!(self, {
740 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
741 Ok(sqlx::query(query)
742 .bind(connected_once)
743 .bind(id.0)
744 .execute(&self.pool)
745 .await
746 .map(drop)?)
747 })
748 }
749
750 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
751 test_support!(self, {
752 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
753 sqlx::query(query)
754 .bind(id.0)
755 .execute(&self.pool)
756 .await
757 .map(drop)?;
758 let query = "DELETE FROM users WHERE id = $1;";
759 Ok(sqlx::query(query)
760 .bind(id.0)
761 .execute(&self.pool)
762 .await
763 .map(drop)?)
764 })
765 }
766
767 // signups
768
769 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
770 test_support!(self, {
771 Ok(sqlx::query_as(
772 "
773 SELECT
774 COUNT(*) as count,
775 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
776 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
777 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
778 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
779 FROM (
780 SELECT *
781 FROM signups
782 WHERE
783 NOT email_confirmation_sent
784 ) AS unsent
785 ",
786 )
787 .fetch_one(&self.pool)
788 .await?)
789 })
790 }
791
792 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
793 test_support!(self, {
794 Ok(sqlx::query_as(
795 "
796 SELECT
797 email_address, email_confirmation_code
798 FROM signups
799 WHERE
800 NOT email_confirmation_sent AND
801 (platform_mac OR platform_unknown)
802 LIMIT $1
803 ",
804 )
805 .bind(count as i32)
806 .fetch_all(&self.pool)
807 .await?)
808 })
809 }
810
811 // invite codes
812
813 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
814 test_support!(self, {
815 let mut tx = self.pool.begin().await?;
816 if count > 0 {
817 sqlx::query(
818 "
819 UPDATE users
820 SET invite_code = $1
821 WHERE id = $2 AND invite_code IS NULL
822 ",
823 )
824 .bind(random_invite_code())
825 .bind(id)
826 .execute(&mut tx)
827 .await?;
828 }
829
830 sqlx::query(
831 "
832 UPDATE users
833 SET invite_count = $1
834 WHERE id = $2
835 ",
836 )
837 .bind(count as i32)
838 .bind(id)
839 .execute(&mut tx)
840 .await?;
841 tx.commit().await?;
842 Ok(())
843 })
844 }
845
846 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
847 test_support!(self, {
848 let result: Option<(String, i32)> = sqlx::query_as(
849 "
850 SELECT invite_code, invite_count
851 FROM users
852 WHERE id = $1 AND invite_code IS NOT NULL
853 ",
854 )
855 .bind(id)
856 .fetch_optional(&self.pool)
857 .await?;
858 if let Some((code, count)) = result {
859 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
860 } else {
861 Ok(None)
862 }
863 })
864 }
865
866 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
867 test_support!(self, {
868 sqlx::query_as(
869 "
870 SELECT *
871 FROM users
872 WHERE invite_code = $1
873 ",
874 )
875 .bind(code)
876 .fetch_optional(&self.pool)
877 .await?
878 .ok_or_else(|| {
879 Error::Http(
880 StatusCode::NOT_FOUND,
881 "that invite code does not exist".to_string(),
882 )
883 })
884 })
885 }
886
887 pub async fn create_room(
888 &self,
889 user_id: UserId,
890 connection_id: ConnectionId,
891 ) -> Result<proto::Room> {
892 test_support!(self, {
893 let mut tx = self.pool.begin().await?;
894 let live_kit_room = nanoid::nanoid!(30);
895 let room_id = sqlx::query_scalar(
896 "
897 INSERT INTO rooms (live_kit_room, version)
898 VALUES ($1, $2)
899 RETURNING id
900 ",
901 )
902 .bind(&live_kit_room)
903 .bind(0)
904 .fetch_one(&mut tx)
905 .await
906 .map(RoomId)?;
907
908 sqlx::query(
909 "
910 INSERT INTO room_participants (room_id, user_id, connection_id, calling_user_id)
911 VALUES ($1, $2, $3, $4)
912 ",
913 )
914 .bind(room_id)
915 .bind(user_id)
916 .bind(connection_id.0 as i32)
917 .bind(user_id)
918 .execute(&mut tx)
919 .await?;
920
921 self.commit_room_transaction(room_id, tx).await
922 })
923 }
924
925 pub async fn call(
926 &self,
927 room_id: RoomId,
928 calling_user_id: UserId,
929 called_user_id: UserId,
930 initial_project_id: Option<ProjectId>,
931 ) -> Result<(proto::Room, proto::IncomingCall)> {
932 test_support!(self, {
933 let mut tx = self.pool.begin().await?;
934 sqlx::query(
935 "
936 INSERT INTO room_participants (room_id, user_id, calling_user_id, initial_project_id)
937 VALUES ($1, $2, $3, $4)
938 ",
939 )
940 .bind(room_id)
941 .bind(called_user_id)
942 .bind(calling_user_id)
943 .bind(initial_project_id)
944 .execute(&mut tx)
945 .await?;
946
947 let room = self.commit_room_transaction(room_id, tx).await?;
948 let incoming_call = Self::build_incoming_call(&room, called_user_id)
949 .ok_or_else(|| anyhow!("failed to build incoming call"))?;
950 Ok((room, incoming_call))
951 })
952 }
953
954 pub async fn incoming_call_for_user(
955 &self,
956 user_id: UserId,
957 ) -> Result<Option<proto::IncomingCall>> {
958 test_support!(self, {
959 let mut tx = self.pool.begin().await?;
960 let room_id = sqlx::query_scalar::<_, RoomId>(
961 "
962 SELECT room_id
963 FROM room_participants
964 WHERE user_id = $1 AND connection_id IS NULL
965 ",
966 )
967 .bind(user_id)
968 .fetch_optional(&mut tx)
969 .await?;
970
971 if let Some(room_id) = room_id {
972 let room = self.get_room(room_id, &mut tx).await?;
973 Ok(Self::build_incoming_call(&room, user_id))
974 } else {
975 Ok(None)
976 }
977 })
978 }
979
980 fn build_incoming_call(
981 room: &proto::Room,
982 called_user_id: UserId,
983 ) -> Option<proto::IncomingCall> {
984 let pending_participant = room
985 .pending_participants
986 .iter()
987 .find(|participant| participant.user_id == called_user_id.to_proto())?;
988
989 Some(proto::IncomingCall {
990 room_id: room.id,
991 calling_user_id: pending_participant.calling_user_id,
992 participant_user_ids: room
993 .participants
994 .iter()
995 .map(|participant| participant.user_id)
996 .collect(),
997 initial_project: room.participants.iter().find_map(|participant| {
998 let initial_project_id = pending_participant.initial_project_id?;
999 participant
1000 .projects
1001 .iter()
1002 .find(|project| project.id == initial_project_id)
1003 .cloned()
1004 }),
1005 })
1006 }
1007
1008 pub async fn call_failed(
1009 &self,
1010 room_id: RoomId,
1011 called_user_id: UserId,
1012 ) -> Result<proto::Room> {
1013 test_support!(self, {
1014 let mut tx = self.pool.begin().await?;
1015 sqlx::query(
1016 "
1017 DELETE FROM room_participants
1018 WHERE room_id = $1 AND user_id = $2
1019 ",
1020 )
1021 .bind(room_id)
1022 .bind(called_user_id)
1023 .execute(&mut tx)
1024 .await?;
1025
1026 self.commit_room_transaction(room_id, tx).await
1027 })
1028 }
1029
1030 pub async fn decline_call(&self, room_id: RoomId, user_id: UserId) -> Result<proto::Room> {
1031 test_support!(self, {
1032 let mut tx = self.pool.begin().await?;
1033 sqlx::query(
1034 "
1035 DELETE FROM room_participants
1036 WHERE room_id = $1 AND user_id = $2 AND connection_id IS NULL
1037 ",
1038 )
1039 .bind(room_id)
1040 .bind(user_id)
1041 .execute(&mut tx)
1042 .await?;
1043
1044 self.commit_room_transaction(room_id, tx).await
1045 })
1046 }
1047
1048 pub async fn join_room(
1049 &self,
1050 room_id: RoomId,
1051 user_id: UserId,
1052 connection_id: ConnectionId,
1053 ) -> Result<proto::Room> {
1054 test_support!(self, {
1055 let mut tx = self.pool.begin().await?;
1056 sqlx::query(
1057 "
1058 UPDATE room_participants
1059 SET connection_id = $1
1060 WHERE room_id = $2 AND user_id = $3
1061 RETURNING 1
1062 ",
1063 )
1064 .bind(connection_id.0 as i32)
1065 .bind(room_id)
1066 .bind(user_id)
1067 .fetch_one(&mut tx)
1068 .await?;
1069 self.commit_room_transaction(room_id, tx).await
1070 })
1071 }
1072
1073 pub async fn update_room_participant_location(
1074 &self,
1075 room_id: RoomId,
1076 user_id: UserId,
1077 location: proto::ParticipantLocation,
1078 ) -> Result<proto::Room> {
1079 test_support!(self, {
1080 let mut tx = self.pool.begin().await?;
1081
1082 let location_kind;
1083 let location_project_id;
1084 match location
1085 .variant
1086 .ok_or_else(|| anyhow!("invalid location"))?
1087 {
1088 proto::participant_location::Variant::SharedProject(project) => {
1089 location_kind = 0;
1090 location_project_id = Some(ProjectId::from_proto(project.id));
1091 }
1092 proto::participant_location::Variant::UnsharedProject(_) => {
1093 location_kind = 1;
1094 location_project_id = None;
1095 }
1096 proto::participant_location::Variant::External(_) => {
1097 location_kind = 2;
1098 location_project_id = None;
1099 }
1100 }
1101
1102 sqlx::query(
1103 "
1104 UPDATE room_participants
1105 SET location_kind = $1 AND location_project_id = $2
1106 WHERE room_id = $1 AND user_id = $2
1107 ",
1108 )
1109 .bind(location_kind)
1110 .bind(location_project_id)
1111 .bind(room_id)
1112 .bind(user_id)
1113 .execute(&mut tx)
1114 .await?;
1115
1116 self.commit_room_transaction(room_id, tx).await
1117 })
1118 }
1119
1120 async fn commit_room_transaction(
1121 &self,
1122 room_id: RoomId,
1123 mut tx: sqlx::Transaction<'_, D>,
1124 ) -> Result<proto::Room> {
1125 sqlx::query(
1126 "
1127 UPDATE rooms
1128 SET version = version + 1
1129 WHERE id = $1
1130 ",
1131 )
1132 .bind(room_id)
1133 .execute(&mut tx)
1134 .await?;
1135 let room = self.get_room(room_id, &mut tx).await?;
1136 tx.commit().await?;
1137
1138 Ok(room)
1139 }
1140
1141 async fn get_room(
1142 &self,
1143 room_id: RoomId,
1144 tx: &mut sqlx::Transaction<'_, D>,
1145 ) -> Result<proto::Room> {
1146 let room: Room = sqlx::query_as(
1147 "
1148 SELECT *
1149 FROM rooms
1150 WHERE id = $1
1151 ",
1152 )
1153 .bind(room_id)
1154 .fetch_one(&mut *tx)
1155 .await?;
1156
1157 let mut db_participants =
1158 sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<ProjectId>, UserId, Option<ProjectId>)>(
1159 "
1160 SELECT user_id, connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
1161 FROM room_participants
1162 WHERE room_id = $1
1163 ",
1164 )
1165 .bind(room_id)
1166 .fetch(&mut *tx);
1167
1168 let mut participants = Vec::new();
1169 let mut pending_participants = Vec::new();
1170 while let Some(participant) = db_participants.next().await {
1171 let (
1172 user_id,
1173 connection_id,
1174 _location_kind,
1175 _location_project_id,
1176 calling_user_id,
1177 initial_project_id,
1178 ) = participant?;
1179 if let Some(connection_id) = connection_id {
1180 participants.push(proto::Participant {
1181 user_id: user_id.to_proto(),
1182 peer_id: connection_id as u32,
1183 projects: Default::default(),
1184 location: Some(proto::ParticipantLocation {
1185 variant: Some(proto::participant_location::Variant::External(
1186 Default::default(),
1187 )),
1188 }),
1189 });
1190 } else {
1191 pending_participants.push(proto::PendingParticipant {
1192 user_id: user_id.to_proto(),
1193 calling_user_id: calling_user_id.to_proto(),
1194 initial_project_id: initial_project_id.map(|id| id.to_proto()),
1195 });
1196 }
1197 }
1198 drop(db_participants);
1199
1200 for participant in &mut participants {
1201 let mut entries = sqlx::query_as::<_, (ProjectId, String)>(
1202 "
1203 SELECT projects.id, worktrees.root_name
1204 FROM projects
1205 LEFT JOIN worktrees ON projects.id = worktrees.project_id
1206 WHERE room_id = $1 AND host_user_id = $2
1207 ",
1208 )
1209 .bind(room_id)
1210 .fetch(&mut *tx);
1211
1212 let mut projects = HashMap::default();
1213 while let Some(entry) = entries.next().await {
1214 let (project_id, worktree_root_name) = entry?;
1215 let participant_project =
1216 projects
1217 .entry(project_id)
1218 .or_insert(proto::ParticipantProject {
1219 id: project_id.to_proto(),
1220 worktree_root_names: Default::default(),
1221 });
1222 participant_project
1223 .worktree_root_names
1224 .push(worktree_root_name);
1225 }
1226
1227 participant.projects = projects.into_values().collect();
1228 }
1229 Ok(proto::Room {
1230 id: room.id.to_proto(),
1231 version: room.version as u64,
1232 live_kit_room: room.live_kit_room,
1233 participants,
1234 pending_participants,
1235 })
1236 }
1237
1238 // projects
1239
1240 pub async fn share_project(
1241 &self,
1242 user_id: UserId,
1243 connection_id: ConnectionId,
1244 room_id: RoomId,
1245 worktrees: &[proto::WorktreeMetadata],
1246 ) -> Result<(ProjectId, proto::Room)> {
1247 test_support!(self, {
1248 let mut tx = self.pool.begin().await?;
1249 let project_id = sqlx::query_scalar(
1250 "
1251 INSERT INTO projects (host_user_id, room_id)
1252 VALUES ($1)
1253 RETURNING id
1254 ",
1255 )
1256 .bind(user_id)
1257 .bind(room_id)
1258 .fetch_one(&mut tx)
1259 .await
1260 .map(ProjectId)?;
1261
1262 for worktree in worktrees {
1263 sqlx::query(
1264 "
1265 INSERT INTO worktrees (id, project_id, root_name)
1266 ",
1267 )
1268 .bind(worktree.id as i32)
1269 .bind(project_id)
1270 .bind(&worktree.root_name)
1271 .execute(&mut tx)
1272 .await?;
1273 }
1274
1275 sqlx::query(
1276 "
1277 INSERT INTO project_collaborators (
1278 project_id,
1279 connection_id,
1280 user_id,
1281 replica_id,
1282 is_host
1283 )
1284 VALUES ($1, $2, $3, $4, $5)
1285 ",
1286 )
1287 .bind(project_id)
1288 .bind(connection_id.0 as i32)
1289 .bind(user_id)
1290 .bind(0)
1291 .bind(true)
1292 .execute(&mut tx)
1293 .await?;
1294
1295 let room = self.commit_room_transaction(room_id, tx).await?;
1296 Ok((project_id, room))
1297 })
1298 }
1299
    /// Intended to stop sharing `project_id` (going by its name).
    ///
    /// NOTE(review): not implemented; calling this panics via `todo!()`.
    /// The commented-out code below sets `projects.unregistered = TRUE`.
    /// TODO: confirm that column still exists in the current schema before
    /// reviving it.
    pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> {
        todo!()
        // test_support!(self, {
        //     sqlx::query(
        //         "
        //         UPDATE projects
        //         SET unregistered = TRUE
        //         WHERE id = $1
        //         ",
        //     )
        //     .bind(project_id)
        //     .execute(&self.pool)
        //     .await?;
        //     Ok(())
        // })
    }
1316
1317 // contacts
1318
1319 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1320 test_support!(self, {
1321 let query = "
1322 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1323 FROM contacts
1324 WHERE user_id_a = $1 OR user_id_b = $1;
1325 ";
1326
1327 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1328 .bind(user_id)
1329 .fetch(&self.pool);
1330
1331 let mut contacts = Vec::new();
1332 while let Some(row) = rows.next().await {
1333 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1334
1335 if user_id_a == user_id {
1336 if accepted {
1337 contacts.push(Contact::Accepted {
1338 user_id: user_id_b,
1339 should_notify: should_notify && a_to_b,
1340 });
1341 } else if a_to_b {
1342 contacts.push(Contact::Outgoing { user_id: user_id_b })
1343 } else {
1344 contacts.push(Contact::Incoming {
1345 user_id: user_id_b,
1346 should_notify,
1347 });
1348 }
1349 } else if accepted {
1350 contacts.push(Contact::Accepted {
1351 user_id: user_id_a,
1352 should_notify: should_notify && !a_to_b,
1353 });
1354 } else if a_to_b {
1355 contacts.push(Contact::Incoming {
1356 user_id: user_id_a,
1357 should_notify,
1358 });
1359 } else {
1360 contacts.push(Contact::Outgoing { user_id: user_id_a });
1361 }
1362 }
1363
1364 contacts.sort_unstable_by_key(|contact| contact.user_id());
1365
1366 Ok(contacts)
1367 })
1368 }
1369
1370 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1371 test_support!(self, {
1372 let (id_a, id_b) = if user_id_1 < user_id_2 {
1373 (user_id_1, user_id_2)
1374 } else {
1375 (user_id_2, user_id_1)
1376 };
1377
1378 let query = "
1379 SELECT 1 FROM contacts
1380 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
1381 LIMIT 1
1382 ";
1383 Ok(sqlx::query_scalar::<_, i32>(query)
1384 .bind(id_a.0)
1385 .bind(id_b.0)
1386 .fetch_optional(&self.pool)
1387 .await?
1388 .is_some())
1389 })
1390 }
1391
1392 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1393 test_support!(self, {
1394 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1395 (sender_id, receiver_id, true)
1396 } else {
1397 (receiver_id, sender_id, false)
1398 };
1399 let query = "
1400 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1401 VALUES ($1, $2, $3, FALSE, TRUE)
1402 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1403 SET
1404 accepted = TRUE,
1405 should_notify = FALSE
1406 WHERE
1407 NOT contacts.accepted AND
1408 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1409 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1410 ";
1411 let result = sqlx::query(query)
1412 .bind(id_a.0)
1413 .bind(id_b.0)
1414 .bind(a_to_b)
1415 .execute(&self.pool)
1416 .await?;
1417
1418 if result.rows_affected() == 1 {
1419 Ok(())
1420 } else {
1421 Err(anyhow!("contact already requested"))?
1422 }
1423 })
1424 }
1425
1426 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1427 test_support!(self, {
1428 let (id_a, id_b) = if responder_id < requester_id {
1429 (responder_id, requester_id)
1430 } else {
1431 (requester_id, responder_id)
1432 };
1433 let query = "
1434 DELETE FROM contacts
1435 WHERE user_id_a = $1 AND user_id_b = $2;
1436 ";
1437 let result = sqlx::query(query)
1438 .bind(id_a.0)
1439 .bind(id_b.0)
1440 .execute(&self.pool)
1441 .await?;
1442
1443 if result.rows_affected() == 1 {
1444 Ok(())
1445 } else {
1446 Err(anyhow!("no such contact"))?
1447 }
1448 })
1449 }
1450
1451 pub async fn dismiss_contact_notification(
1452 &self,
1453 user_id: UserId,
1454 contact_user_id: UserId,
1455 ) -> Result<()> {
1456 test_support!(self, {
1457 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1458 (user_id, contact_user_id, true)
1459 } else {
1460 (contact_user_id, user_id, false)
1461 };
1462
1463 let query = "
1464 UPDATE contacts
1465 SET should_notify = FALSE
1466 WHERE
1467 user_id_a = $1 AND user_id_b = $2 AND
1468 (
1469 (a_to_b = $3 AND accepted) OR
1470 (a_to_b != $3 AND NOT accepted)
1471 );
1472 ";
1473
1474 let result = sqlx::query(query)
1475 .bind(id_a.0)
1476 .bind(id_b.0)
1477 .bind(a_to_b)
1478 .execute(&self.pool)
1479 .await?;
1480
1481 if result.rows_affected() == 0 {
1482 Err(anyhow!("no such contact request"))?;
1483 }
1484
1485 Ok(())
1486 })
1487 }
1488
1489 pub async fn respond_to_contact_request(
1490 &self,
1491 responder_id: UserId,
1492 requester_id: UserId,
1493 accept: bool,
1494 ) -> Result<()> {
1495 test_support!(self, {
1496 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1497 (responder_id, requester_id, false)
1498 } else {
1499 (requester_id, responder_id, true)
1500 };
1501 let result = if accept {
1502 let query = "
1503 UPDATE contacts
1504 SET accepted = TRUE, should_notify = TRUE
1505 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1506 ";
1507 sqlx::query(query)
1508 .bind(id_a.0)
1509 .bind(id_b.0)
1510 .bind(a_to_b)
1511 .execute(&self.pool)
1512 .await?
1513 } else {
1514 let query = "
1515 DELETE FROM contacts
1516 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1517 ";
1518 sqlx::query(query)
1519 .bind(id_a.0)
1520 .bind(id_b.0)
1521 .bind(a_to_b)
1522 .execute(&self.pool)
1523 .await?
1524 };
1525 if result.rows_affected() == 1 {
1526 Ok(())
1527 } else {
1528 Err(anyhow!("no such contact request"))?
1529 }
1530 })
1531 }
1532
1533 // access tokens
1534
1535 pub async fn create_access_token_hash(
1536 &self,
1537 user_id: UserId,
1538 access_token_hash: &str,
1539 max_access_token_count: usize,
1540 ) -> Result<()> {
1541 test_support!(self, {
1542 let insert_query = "
1543 INSERT INTO access_tokens (user_id, hash)
1544 VALUES ($1, $2);
1545 ";
1546 let cleanup_query = "
1547 DELETE FROM access_tokens
1548 WHERE id IN (
1549 SELECT id from access_tokens
1550 WHERE user_id = $1
1551 ORDER BY id DESC
1552 LIMIT 10000
1553 OFFSET $3
1554 )
1555 ";
1556
1557 let mut tx = self.pool.begin().await?;
1558 sqlx::query(insert_query)
1559 .bind(user_id.0)
1560 .bind(access_token_hash)
1561 .execute(&mut tx)
1562 .await?;
1563 sqlx::query(cleanup_query)
1564 .bind(user_id.0)
1565 .bind(access_token_hash)
1566 .bind(max_access_token_count as i32)
1567 .execute(&mut tx)
1568 .await?;
1569 Ok(tx.commit().await?)
1570 })
1571 }
1572
1573 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1574 test_support!(self, {
1575 let query = "
1576 SELECT hash
1577 FROM access_tokens
1578 WHERE user_id = $1
1579 ORDER BY id DESC
1580 ";
1581 Ok(sqlx::query_scalar(query)
1582 .bind(user_id.0)
1583 .fetch_all(&self.pool)
1584 .await?)
1585 })
1586 }
1587}
1588
/// Declares a newtype id wrapping an `i32`.
///
/// The generated type:
/// - is encoded transparently by sqlx and serde;
/// - is ordered, hashable and `Copy`;
/// - converts to and from the `u64` representation used in protocol messages.
macro_rules! id_type {
    ($name:ident) => {
        /// Strongly typed `i32` id; stored and serialized exactly like the wrapped integer.
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            Hash,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Serialize,
            Deserialize,
            sqlx::Type,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from its protocol `u64` form.
            ///
            /// The value is narrowed with `as`, i.e. truncated to its low 32 bits.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts the id to its protocol `u64` form.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Forward to i32's Display so width/fill flags are honored.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1631
1632id_type!(UserId);
1633#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1634pub struct User {
1635 pub id: UserId,
1636 pub github_login: String,
1637 pub github_user_id: Option<i32>,
1638 pub email_address: Option<String>,
1639 pub admin: bool,
1640 pub invite_code: Option<String>,
1641 pub invite_count: i32,
1642 pub connected_once: bool,
1643}
1644
1645id_type!(RoomId);
1646#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1647pub struct Room {
1648 pub id: RoomId,
1649 pub version: i32,
1650 pub live_kit_room: String,
1651}
1652
1653#[derive(Clone, Debug, Default, FromRow, PartialEq)]
1654pub struct Call {
1655 pub room_id: RoomId,
1656 pub calling_user_id: UserId,
1657 pub called_user_id: UserId,
1658 pub answering_connection_id: Option<i32>,
1659 pub initial_project_id: Option<ProjectId>,
1660}
1661
1662id_type!(ProjectId);
1663#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1664pub struct Project {
1665 pub id: ProjectId,
1666 pub host_user_id: UserId,
1667 pub unregistered: bool,
1668}
1669
1670#[derive(Clone, Debug, PartialEq, Eq)]
1671pub enum Contact {
1672 Accepted {
1673 user_id: UserId,
1674 should_notify: bool,
1675 },
1676 Outgoing {
1677 user_id: UserId,
1678 },
1679 Incoming {
1680 user_id: UserId,
1681 should_notify: bool,
1682 },
1683}
1684
1685impl Contact {
1686 pub fn user_id(&self) -> UserId {
1687 match self {
1688 Contact::Accepted { user_id, .. } => *user_id,
1689 Contact::Outgoing { user_id } => *user_id,
1690 Contact::Incoming { user_id, .. } => *user_id,
1691 }
1692 }
1693}
1694
1695#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
1696pub struct IncomingContactRequest {
1697 pub requester_id: UserId,
1698 pub should_notify: bool,
1699}
1700
1701#[derive(Clone, Deserialize)]
1702pub struct Signup {
1703 pub email_address: String,
1704 pub platform_mac: bool,
1705 pub platform_windows: bool,
1706 pub platform_linux: bool,
1707 pub editor_features: Vec<String>,
1708 pub programming_languages: Vec<String>,
1709 pub device_id: Option<String>,
1710}
1711
1712#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
1713pub struct WaitlistSummary {
1714 #[sqlx(default)]
1715 pub count: i64,
1716 #[sqlx(default)]
1717 pub linux_count: i64,
1718 #[sqlx(default)]
1719 pub mac_count: i64,
1720 #[sqlx(default)]
1721 pub windows_count: i64,
1722 #[sqlx(default)]
1723 pub unknown_count: i64,
1724}
1725
1726#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
1727pub struct Invite {
1728 pub email_address: String,
1729 pub email_confirmation_code: String,
1730}
1731
1732#[derive(Debug, Serialize, Deserialize)]
1733pub struct NewUserParams {
1734 pub github_login: String,
1735 pub github_user_id: i32,
1736 pub invite_count: i32,
1737}
1738
1739#[derive(Debug)]
1740pub struct NewUserResult {
1741 pub user_id: UserId,
1742 pub metrics_id: String,
1743 pub inviting_user_id: Option<UserId>,
1744 pub signup_device_id: Option<String>,
1745}
1746
1747fn random_invite_code() -> String {
1748 nanoid::nanoid!(16)
1749}
1750
1751fn random_email_confirmation_code() -> String {
1752 nanoid::nanoid!(64)
1753}
1754
#[cfg(test)]
pub use test::*;

#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// Builds the single-threaded Tokio runtime (I/O + timers enabled) that a
    /// test database stores and uses to run its queries.
    fn build_test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    /// A uniquely named, in-memory SQLite database with migrations applied.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// A connection detached from the pool.
        /// NOTE(review): presumably kept so the shared in-memory database stays
        /// alive for the lifetime of the test — confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A uniquely named Postgres database on localhost with migrations applied;
    /// it is torn down when this value is dropped.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        pub url: String,
    }

    impl SqliteTestDb {
        /// Creates and migrates a fresh in-memory database, wiring in
        /// `background` for simulated delays.
        pub fn new(background: Arc<Background>) -> Self {
            let url = format!(
                "file:zed-test-{}?mode=memory",
                StdRng::from_entropy().gen::<u128>()
            );
            let runtime = build_test_runtime();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the database handle.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates and migrates a fresh Postgres database, wiring in
        /// `background` for simulated delays.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Held until the end of this function so concurrent setups in this
            // process run one at a time.
            let _guard = LOCK.lock();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                StdRng::from_entropy().gen::<u128>()
            );
            let runtime = build_test_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the database handle.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        fn drop(&mut self) {
            let test_db = self.db.take().unwrap();
            test_db.teardown(&self.url);
        }
    }
}