1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::HashMap;
5use futures::StreamExt;
6use rpc::{proto, ConnectionId};
7use serde::{Deserialize, Serialize};
8use sqlx::{
9 migrate::{Migrate as _, Migration, MigrationSource},
10 types::Uuid,
11 FromRow,
12};
13use std::{path::Path, time::Duration};
14use time::{OffsetDateTime, PrimitiveDateTime};
15
/// Database handle type used by tests: an SQLite-backed [`Db`].
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// Database handle type used outside of tests: a Postgres-backed [`Db`].
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;
21
/// A pooled database handle, generic over the sqlx backend `D`.
pub struct Db<D: sqlx::Database> {
    /// Connection pool that the queries in this module run against.
    pool: sqlx::Pool<D>,
    /// Test-only executor; when set, `test_support!` awaits its
    /// `simulate_random_delay()` before running each query body.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only runtime that `test_support!` uses to `block_on` each query
    /// body. It must be `Some` before any query runs in a test build.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
29
/// Runs the `{ ... }` body (wrapped in an `async` block) as a database
/// operation on `$self`.
///
/// In test builds, the macro first awaits a simulated random delay when
/// `$self.background` is set. It then runs the body synchronously via
/// `$self.runtime.as_ref().unwrap().block_on(..)`, which panics if `runtime`
/// is `None`. In non-test builds, the body is simply awaited.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };

        if cfg!(test) {
            // Non-test builds still type-check this branch, but `cfg!(test)`
            // is false there, so it can never execute.
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(body)
        } else {
            body.await
        }
    }};
}
52
/// Abstracts over backend-specific query result types so that generic code
/// can ask how many rows a statement touched.
pub trait RowsAffected {
    /// Number of rows affected by the executed statement.
    fn rows_affected(&self) -> u64;
}
56
/// Exposes SQLite's affected-row count through [`RowsAffected`].
#[cfg(test)]
impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    fn rows_affected(&self) -> u64 {
        // A type-qualified path picks the inherent method before the trait
        // method of the same name, so this delegates rather than recursing.
        sqlx::sqlite::SqliteQueryResult::rows_affected(self)
    }
}
63
64impl RowsAffected for sqlx::postgres::PgQueryResult {
65 fn rows_affected(&self) -> u64 {
66 self.rows_affected()
67 }
68}
69
/// SQLite-backed queries, compiled only for tests.
///
/// These mirror the Postgres implementations where SQL differs between the
/// backends; operations not yet needed by tests are left unimplemented.
#[cfg(test)]
impl Db<sqlx::Sqlite> {
    /// Opens (creating if missing) the SQLite database at `url` with a pool
    /// holding between 2 and `max_connections` connections.
    ///
    /// Returns an error if `url` is not valid SQLite connect options or the
    /// pool cannot connect. The test-only `background` and `runtime` fields
    /// start out as `None`.
    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
        use std::str::FromStr as _;
        // A malformed URL is reported to the caller instead of panicking.
        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
            .create_if_missing(true)
            .shared_cache(true);
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(2)
            .max_connections(max_connections)
            .connect_with(options)
            .await?;
        Ok(Self {
            pool,
            background: None,
            runtime: None,
        })
    }

    /// Returns the users whose ids appear in `ids` (no ordering is applied).
    ///
    /// SQLite has no array parameters, so the ids are bound as a JSON array
    /// and expanded with `json_each`.
    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
        test_support!(self, {
            let query = "
                SELECT users.*
                FROM users
                WHERE users.id IN (SELECT value from json_each($1))
            ";
            Ok(sqlx::query_as(query)
                .bind(&serde_json::json!(ids))
                .fetch_all(&self.pool)
                .await?)
        })
    }

    /// Returns the stored `metrics_id` of user `id`; errors if no such user.
    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
        test_support!(self, {
            let query = "
                SELECT metrics_id
                FROM users
                WHERE id = $1
            ";
            Ok(sqlx::query_scalar(query)
                .bind(id)
                .fetch_one(&self.pool)
                .await?)
        })
    }

    /// Inserts a user with a freshly generated UUID metrics id.
    ///
    /// If a user with the same GitHub login already exists, the conflict
    /// clause performs a no-op update, so the existing row's id and metrics
    /// id are returned and its other columns are left untouched.
    pub async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult> {
        test_support!(self, {
            let query = "
                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id, metrics_id
            ";

            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
                .bind(email_address)
                .bind(params.github_login)
                .bind(params.github_user_id)
                .bind(admin)
                .bind(Uuid::new_v4().to_string())
                .fetch_one(&self.pool)
                .await?;
            Ok(NewUserResult {
                user_id,
                metrics_id,
                signup_device_id: None,
                inviting_user_id: None,
            })
        })
    }

    /// Not implemented for the SQLite test backend; panics if called.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }

    /// Not implemented for the SQLite test backend; panics if called.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }

    /// Not implemented for the SQLite test backend; panics if called.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }

    /// Not implemented for the SQLite test backend; panics if called.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }

    /// Not implemented for the SQLite test backend; panics if called.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
}
178
179impl Db<sqlx::Postgres> {
180 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
181 let pool = sqlx::postgres::PgPoolOptions::new()
182 .max_connections(max_connections)
183 .connect(url)
184 .await?;
185 Ok(Self {
186 pool,
187 #[cfg(test)]
188 background: None,
189 #[cfg(test)]
190 runtime: None,
191 })
192 }
193
194 #[cfg(test)]
195 pub fn teardown(&self, url: &str) {
196 self.runtime.as_ref().unwrap().block_on(async {
197 use util::ResultExt;
198 let query = "
199 SELECT pg_terminate_backend(pg_stat_activity.pid)
200 FROM pg_stat_activity
201 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
202 ";
203 sqlx::query(query).execute(&self.pool).await.log_err();
204 self.pool.close().await;
205 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
206 .await
207 .log_err();
208 })
209 }
210
211 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
212 test_support!(self, {
213 let like_string = Self::fuzzy_like_string(name_query);
214 let query = "
215 SELECT users.*
216 FROM users
217 WHERE github_login ILIKE $1
218 ORDER BY github_login <-> $2
219 LIMIT $3
220 ";
221 Ok(sqlx::query_as(query)
222 .bind(like_string)
223 .bind(name_query)
224 .bind(limit as i32)
225 .fetch_all(&self.pool)
226 .await?)
227 })
228 }
229
230 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
231 test_support!(self, {
232 let query = "
233 SELECT users.*
234 FROM users
235 WHERE users.id = ANY ($1)
236 ";
237 Ok(sqlx::query_as(query)
238 .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
239 .fetch_all(&self.pool)
240 .await?)
241 })
242 }
243
244 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
245 test_support!(self, {
246 let query = "
247 SELECT metrics_id::text
248 FROM users
249 WHERE id = $1
250 ";
251 Ok(sqlx::query_scalar(query)
252 .bind(id)
253 .fetch_one(&self.pool)
254 .await?)
255 })
256 }
257
258 pub async fn create_user(
259 &self,
260 email_address: &str,
261 admin: bool,
262 params: NewUserParams,
263 ) -> Result<NewUserResult> {
264 test_support!(self, {
265 let query = "
266 INSERT INTO users (email_address, github_login, github_user_id, admin)
267 VALUES ($1, $2, $3, $4)
268 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
269 RETURNING id, metrics_id::text
270 ";
271
272 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
273 .bind(email_address)
274 .bind(params.github_login)
275 .bind(params.github_user_id)
276 .bind(admin)
277 .fetch_one(&self.pool)
278 .await?;
279 Ok(NewUserResult {
280 user_id,
281 metrics_id,
282 signup_device_id: None,
283 inviting_user_id: None,
284 })
285 })
286 }
287
288 pub async fn create_user_from_invite(
289 &self,
290 invite: &Invite,
291 user: NewUserParams,
292 ) -> Result<Option<NewUserResult>> {
293 test_support!(self, {
294 let mut tx = self.pool.begin().await?;
295
296 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
297 i32,
298 Option<UserId>,
299 Option<UserId>,
300 Option<String>,
301 ) = sqlx::query_as(
302 "
303 SELECT id, user_id, inviting_user_id, device_id
304 FROM signups
305 WHERE
306 email_address = $1 AND
307 email_confirmation_code = $2
308 ",
309 )
310 .bind(&invite.email_address)
311 .bind(&invite.email_confirmation_code)
312 .fetch_optional(&mut tx)
313 .await?
314 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
315
316 if existing_user_id.is_some() {
317 return Ok(None);
318 }
319
320 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
321 "
322 INSERT INTO users
323 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
324 VALUES
325 ($1, $2, $3, FALSE, $4, $5)
326 ON CONFLICT (github_login) DO UPDATE SET
327 email_address = excluded.email_address,
328 github_user_id = excluded.github_user_id,
329 admin = excluded.admin
330 RETURNING id, metrics_id::text
331 ",
332 )
333 .bind(&invite.email_address)
334 .bind(&user.github_login)
335 .bind(&user.github_user_id)
336 .bind(&user.invite_count)
337 .bind(random_invite_code())
338 .fetch_one(&mut tx)
339 .await?;
340
341 sqlx::query(
342 "
343 UPDATE signups
344 SET user_id = $1
345 WHERE id = $2
346 ",
347 )
348 .bind(&user_id)
349 .bind(&signup_id)
350 .execute(&mut tx)
351 .await?;
352
353 if let Some(inviting_user_id) = inviting_user_id {
354 let id: Option<UserId> = sqlx::query_scalar(
355 "
356 UPDATE users
357 SET invite_count = invite_count - 1
358 WHERE id = $1 AND invite_count > 0
359 RETURNING id
360 ",
361 )
362 .bind(&inviting_user_id)
363 .fetch_optional(&mut tx)
364 .await?;
365
366 if id.is_none() {
367 Err(Error::Http(
368 StatusCode::UNAUTHORIZED,
369 "no invites remaining".to_string(),
370 ))?;
371 }
372
373 sqlx::query(
374 "
375 INSERT INTO contacts
376 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
377 VALUES
378 ($1, $2, TRUE, TRUE, TRUE)
379 ON CONFLICT DO NOTHING
380 ",
381 )
382 .bind(inviting_user_id)
383 .bind(user_id)
384 .execute(&mut tx)
385 .await?;
386 }
387
388 tx.commit().await?;
389 Ok(Some(NewUserResult {
390 user_id,
391 metrics_id,
392 inviting_user_id,
393 signup_device_id,
394 }))
395 })
396 }
397
398 pub async fn create_signup(&self, signup: Signup) -> Result<()> {
399 test_support!(self, {
400 sqlx::query(
401 "
402 INSERT INTO signups
403 (
404 email_address,
405 email_confirmation_code,
406 email_confirmation_sent,
407 platform_linux,
408 platform_mac,
409 platform_windows,
410 platform_unknown,
411 editor_features,
412 programming_languages,
413 device_id
414 )
415 VALUES
416 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
417 RETURNING id
418 ",
419 )
420 .bind(&signup.email_address)
421 .bind(&random_email_confirmation_code())
422 .bind(&signup.platform_linux)
423 .bind(&signup.platform_mac)
424 .bind(&signup.platform_windows)
425 .bind(&signup.editor_features)
426 .bind(&signup.programming_languages)
427 .bind(&signup.device_id)
428 .execute(&self.pool)
429 .await?;
430 Ok(())
431 })
432 }
433
434 pub async fn create_invite_from_code(
435 &self,
436 code: &str,
437 email_address: &str,
438 device_id: Option<&str>,
439 ) -> Result<Invite> {
440 test_support!(self, {
441 let mut tx = self.pool.begin().await?;
442
443 let existing_user: Option<UserId> = sqlx::query_scalar(
444 "
445 SELECT id
446 FROM users
447 WHERE email_address = $1
448 ",
449 )
450 .bind(email_address)
451 .fetch_optional(&mut tx)
452 .await?;
453 if existing_user.is_some() {
454 Err(anyhow!("email address is already in use"))?;
455 }
456
457 let row: Option<(UserId, i32)> = sqlx::query_as(
458 "
459 SELECT id, invite_count
460 FROM users
461 WHERE invite_code = $1
462 ",
463 )
464 .bind(code)
465 .fetch_optional(&mut tx)
466 .await?;
467
468 let (inviter_id, invite_count) = match row {
469 Some(row) => row,
470 None => Err(Error::Http(
471 StatusCode::NOT_FOUND,
472 "invite code not found".to_string(),
473 ))?,
474 };
475
476 if invite_count == 0 {
477 Err(Error::Http(
478 StatusCode::UNAUTHORIZED,
479 "no invites remaining".to_string(),
480 ))?;
481 }
482
483 let email_confirmation_code: String = sqlx::query_scalar(
484 "
485 INSERT INTO signups
486 (
487 email_address,
488 email_confirmation_code,
489 email_confirmation_sent,
490 inviting_user_id,
491 platform_linux,
492 platform_mac,
493 platform_windows,
494 platform_unknown,
495 device_id
496 )
497 VALUES
498 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
499 ON CONFLICT (email_address)
500 DO UPDATE SET
501 inviting_user_id = excluded.inviting_user_id
502 RETURNING email_confirmation_code
503 ",
504 )
505 .bind(&email_address)
506 .bind(&random_email_confirmation_code())
507 .bind(&inviter_id)
508 .bind(&device_id)
509 .fetch_one(&mut tx)
510 .await?;
511
512 tx.commit().await?;
513
514 Ok(Invite {
515 email_address: email_address.into(),
516 email_confirmation_code,
517 })
518 })
519 }
520
521 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
522 test_support!(self, {
523 let emails = invites
524 .iter()
525 .map(|s| s.email_address.as_str())
526 .collect::<Vec<_>>();
527 sqlx::query(
528 "
529 UPDATE signups
530 SET email_confirmation_sent = TRUE
531 WHERE email_address = ANY ($1)
532 ",
533 )
534 .bind(&emails)
535 .execute(&self.pool)
536 .await?;
537 Ok(())
538 })
539 }
540}
541
542impl<D> Db<D>
543where
544 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
545 D::Connection: sqlx::migrate::Migrate,
546 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
547 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
548 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
549 D::QueryResult: RowsAffected,
550 String: sqlx::Type<D>,
551 i32: sqlx::Type<D>,
552 i64: sqlx::Type<D>,
553 bool: sqlx::Type<D>,
554 str: sqlx::Type<D>,
555 Uuid: sqlx::Type<D>,
556 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
557 OffsetDateTime: sqlx::Type<D>,
558 PrimitiveDateTime: sqlx::Type<D>,
559 usize: sqlx::ColumnIndex<D::Row>,
560 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
561 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
562 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
563 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
564 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
565 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
566 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
567 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
568 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
569 for<'a> Option<ProjectId>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
570 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
571 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
572 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
573{
574 pub async fn migrate(
575 &self,
576 migrations_path: &Path,
577 ignore_checksum_mismatch: bool,
578 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
579 let migrations = MigrationSource::resolve(migrations_path)
580 .await
581 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
582
583 let mut conn = self.pool.acquire().await?;
584
585 conn.ensure_migrations_table().await?;
586 let applied_migrations: HashMap<_, _> = conn
587 .list_applied_migrations()
588 .await?
589 .into_iter()
590 .map(|m| (m.version, m))
591 .collect();
592
593 let mut new_migrations = Vec::new();
594 for migration in migrations {
595 match applied_migrations.get(&migration.version) {
596 Some(applied_migration) => {
597 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
598 {
599 Err(anyhow!(
600 "checksum mismatch for applied migration {}",
601 migration.description
602 ))?;
603 }
604 }
605 None => {
606 let elapsed = conn.apply(&migration).await?;
607 new_migrations.push((migration, elapsed));
608 }
609 }
610 }
611
612 Ok(new_migrations)
613 }
614
615 pub fn fuzzy_like_string(string: &str) -> String {
616 let mut result = String::with_capacity(string.len() * 2 + 1);
617 for c in string.chars() {
618 if c.is_alphanumeric() {
619 result.push('%');
620 result.push(c);
621 }
622 }
623 result.push('%');
624 result
625 }
626
627 // users
628
629 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
630 test_support!(self, {
631 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
632 Ok(sqlx::query_as(query)
633 .bind(limit as i32)
634 .bind((page * limit) as i32)
635 .fetch_all(&self.pool)
636 .await?)
637 })
638 }
639
640 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
641 test_support!(self, {
642 let query = "
643 SELECT users.*
644 FROM users
645 WHERE id = $1
646 LIMIT 1
647 ";
648 Ok(sqlx::query_as(query)
649 .bind(&id)
650 .fetch_optional(&self.pool)
651 .await?)
652 })
653 }
654
655 pub async fn get_users_with_no_invites(
656 &self,
657 invited_by_another_user: bool,
658 ) -> Result<Vec<User>> {
659 test_support!(self, {
660 let query = format!(
661 "
662 SELECT users.*
663 FROM users
664 WHERE invite_count = 0
665 AND inviter_id IS{} NULL
666 ",
667 if invited_by_another_user { " NOT" } else { "" }
668 );
669
670 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
671 })
672 }
673
674 pub async fn get_user_by_github_account(
675 &self,
676 github_login: &str,
677 github_user_id: Option<i32>,
678 ) -> Result<Option<User>> {
679 test_support!(self, {
680 if let Some(github_user_id) = github_user_id {
681 let mut user = sqlx::query_as::<_, User>(
682 "
683 UPDATE users
684 SET github_login = $1
685 WHERE github_user_id = $2
686 RETURNING *
687 ",
688 )
689 .bind(github_login)
690 .bind(github_user_id)
691 .fetch_optional(&self.pool)
692 .await?;
693
694 if user.is_none() {
695 user = sqlx::query_as::<_, User>(
696 "
697 UPDATE users
698 SET github_user_id = $1
699 WHERE github_login = $2
700 RETURNING *
701 ",
702 )
703 .bind(github_user_id)
704 .bind(github_login)
705 .fetch_optional(&self.pool)
706 .await?;
707 }
708
709 Ok(user)
710 } else {
711 let user = sqlx::query_as(
712 "
713 SELECT * FROM users
714 WHERE github_login = $1
715 LIMIT 1
716 ",
717 )
718 .bind(github_login)
719 .fetch_optional(&self.pool)
720 .await?;
721 Ok(user)
722 }
723 })
724 }
725
726 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
727 test_support!(self, {
728 let query = "UPDATE users SET admin = $1 WHERE id = $2";
729 Ok(sqlx::query(query)
730 .bind(is_admin)
731 .bind(id.0)
732 .execute(&self.pool)
733 .await
734 .map(drop)?)
735 })
736 }
737
738 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
739 test_support!(self, {
740 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
741 Ok(sqlx::query(query)
742 .bind(connected_once)
743 .bind(id.0)
744 .execute(&self.pool)
745 .await
746 .map(drop)?)
747 })
748 }
749
750 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
751 test_support!(self, {
752 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
753 sqlx::query(query)
754 .bind(id.0)
755 .execute(&self.pool)
756 .await
757 .map(drop)?;
758 let query = "DELETE FROM users WHERE id = $1;";
759 Ok(sqlx::query(query)
760 .bind(id.0)
761 .execute(&self.pool)
762 .await
763 .map(drop)?)
764 })
765 }
766
767 // signups
768
769 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
770 test_support!(self, {
771 Ok(sqlx::query_as(
772 "
773 SELECT
774 COUNT(*) as count,
775 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
776 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
777 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
778 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
779 FROM (
780 SELECT *
781 FROM signups
782 WHERE
783 NOT email_confirmation_sent
784 ) AS unsent
785 ",
786 )
787 .fetch_one(&self.pool)
788 .await?)
789 })
790 }
791
792 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
793 test_support!(self, {
794 Ok(sqlx::query_as(
795 "
796 SELECT
797 email_address, email_confirmation_code
798 FROM signups
799 WHERE
800 NOT email_confirmation_sent AND
801 (platform_mac OR platform_unknown)
802 LIMIT $1
803 ",
804 )
805 .bind(count as i32)
806 .fetch_all(&self.pool)
807 .await?)
808 })
809 }
810
811 // invite codes
812
813 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
814 test_support!(self, {
815 let mut tx = self.pool.begin().await?;
816 if count > 0 {
817 sqlx::query(
818 "
819 UPDATE users
820 SET invite_code = $1
821 WHERE id = $2 AND invite_code IS NULL
822 ",
823 )
824 .bind(random_invite_code())
825 .bind(id)
826 .execute(&mut tx)
827 .await?;
828 }
829
830 sqlx::query(
831 "
832 UPDATE users
833 SET invite_count = $1
834 WHERE id = $2
835 ",
836 )
837 .bind(count as i32)
838 .bind(id)
839 .execute(&mut tx)
840 .await?;
841 tx.commit().await?;
842 Ok(())
843 })
844 }
845
846 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
847 test_support!(self, {
848 let result: Option<(String, i32)> = sqlx::query_as(
849 "
850 SELECT invite_code, invite_count
851 FROM users
852 WHERE id = $1 AND invite_code IS NOT NULL
853 ",
854 )
855 .bind(id)
856 .fetch_optional(&self.pool)
857 .await?;
858 if let Some((code, count)) = result {
859 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
860 } else {
861 Ok(None)
862 }
863 })
864 }
865
866 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
867 test_support!(self, {
868 sqlx::query_as(
869 "
870 SELECT *
871 FROM users
872 WHERE invite_code = $1
873 ",
874 )
875 .bind(code)
876 .fetch_optional(&self.pool)
877 .await?
878 .ok_or_else(|| {
879 Error::Http(
880 StatusCode::NOT_FOUND,
881 "that invite code does not exist".to_string(),
882 )
883 })
884 })
885 }
886
887 pub async fn create_room(
888 &self,
889 user_id: UserId,
890 connection_id: ConnectionId,
891 ) -> Result<proto::Room> {
892 test_support!(self, {
893 let mut tx = self.pool.begin().await?;
894 let live_kit_room = nanoid::nanoid!(30);
895 let room_id = sqlx::query_scalar(
896 "
897 INSERT INTO rooms (live_kit_room, version)
898 VALUES ($1, $2)
899 RETURNING id
900 ",
901 )
902 .bind(&live_kit_room)
903 .bind(0)
904 .fetch_one(&mut tx)
905 .await
906 .map(RoomId)?;
907
908 sqlx::query(
909 "
910 INSERT INTO room_participants (room_id, user_id, connection_id, calling_user_id)
911 VALUES ($1, $2, $3, $4)
912 ",
913 )
914 .bind(room_id)
915 .bind(user_id)
916 .bind(connection_id.0 as i32)
917 .bind(user_id)
918 .execute(&mut tx)
919 .await?;
920
921 self.commit_room_transaction(room_id, tx).await
922 })
923 }
924
925 pub async fn call(
926 &self,
927 room_id: RoomId,
928 calling_user_id: UserId,
929 called_user_id: UserId,
930 initial_project_id: Option<ProjectId>,
931 ) -> Result<(proto::Room, proto::IncomingCall)> {
932 test_support!(self, {
933 let mut tx = self.pool.begin().await?;
934 sqlx::query(
935 "
936 INSERT INTO room_participants (room_id, user_id, calling_user_id, initial_project_id)
937 VALUES ($1, $2, $3, $4)
938 ",
939 )
940 .bind(room_id)
941 .bind(called_user_id)
942 .bind(calling_user_id)
943 .bind(initial_project_id)
944 .execute(&mut tx)
945 .await?;
946
947 let room = self.commit_room_transaction(room_id, tx).await?;
948 let incoming_call = Self::build_incoming_call(&room, called_user_id)
949 .ok_or_else(|| anyhow!("failed to build incoming call"))?;
950 Ok((room, incoming_call))
951 })
952 }
953
954 pub async fn incoming_call_for_user(
955 &self,
956 user_id: UserId,
957 ) -> Result<Option<proto::IncomingCall>> {
958 test_support!(self, {
959 let mut tx = self.pool.begin().await?;
960 let room_id = sqlx::query_scalar::<_, RoomId>(
961 "
962 SELECT room_id
963 FROM room_participants
964 WHERE user_id = $1 AND connection_id IS NULL
965 ",
966 )
967 .bind(user_id)
968 .fetch_optional(&mut tx)
969 .await?;
970
971 if let Some(room_id) = room_id {
972 let room = self.get_room(room_id, &mut tx).await?;
973 Ok(Self::build_incoming_call(&room, user_id))
974 } else {
975 Ok(None)
976 }
977 })
978 }
979
980 fn build_incoming_call(
981 room: &proto::Room,
982 called_user_id: UserId,
983 ) -> Option<proto::IncomingCall> {
984 let pending_participant = room
985 .pending_participants
986 .iter()
987 .find(|participant| participant.user_id == called_user_id.to_proto())?;
988
989 Some(proto::IncomingCall {
990 room_id: room.id,
991 calling_user_id: pending_participant.calling_user_id,
992 participant_user_ids: room
993 .participants
994 .iter()
995 .map(|participant| participant.user_id)
996 .collect(),
997 initial_project: room.participants.iter().find_map(|participant| {
998 let initial_project_id = pending_participant.initial_project_id?;
999 participant
1000 .projects
1001 .iter()
1002 .find(|project| project.id == initial_project_id)
1003 .cloned()
1004 }),
1005 })
1006 }
1007
1008 pub async fn call_failed(
1009 &self,
1010 room_id: RoomId,
1011 called_user_id: UserId,
1012 ) -> Result<proto::Room> {
1013 test_support!(self, {
1014 let mut tx = self.pool.begin().await?;
1015 sqlx::query(
1016 "
1017 DELETE FROM room_participants
1018 WHERE room_id = $1 AND user_id = $2
1019 ",
1020 )
1021 .bind(room_id)
1022 .bind(called_user_id)
1023 .execute(&mut tx)
1024 .await?;
1025
1026 self.commit_room_transaction(room_id, tx).await
1027 })
1028 }
1029
1030 pub async fn decline_call(&self, room_id: RoomId, user_id: UserId) -> Result<proto::Room> {
1031 test_support!(self, {
1032 let mut tx = self.pool.begin().await?;
1033 sqlx::query(
1034 "
1035 DELETE FROM room_participants
1036 WHERE room_id = $1 AND user_id = $2 AND connection_id IS NULL
1037 ",
1038 )
1039 .bind(room_id)
1040 .bind(user_id)
1041 .execute(&mut tx)
1042 .await?;
1043
1044 self.commit_room_transaction(room_id, tx).await
1045 })
1046 }
1047
1048 pub async fn join_room(
1049 &self,
1050 room_id: RoomId,
1051 user_id: UserId,
1052 connection_id: ConnectionId,
1053 ) -> Result<proto::Room> {
1054 test_support!(self, {
1055 let mut tx = self.pool.begin().await?;
1056 sqlx::query(
1057 "
1058 UPDATE room_participants
1059 SET connection_id = $1
1060 WHERE room_id = $2 AND user_id = $3
1061 RETURNING 1
1062 ",
1063 )
1064 .bind(connection_id.0 as i32)
1065 .bind(room_id)
1066 .bind(user_id)
1067 .fetch_one(&mut tx)
1068 .await?;
1069 self.commit_room_transaction(room_id, tx).await
1070 })
1071 }
1072
1073 pub async fn leave_room(
1074 &self,
1075 room_id: RoomId,
1076 connection_id: ConnectionId,
1077 ) -> Result<LeftRoom> {
1078 test_support!(self, {
1079 let mut tx = self.pool.begin().await?;
1080
1081 // Leave room.
1082 let user_id: UserId = sqlx::query_scalar(
1083 "
1084 DELETE FROM room_participants
1085 WHERE room_id = $1 AND connection_id = $2
1086 RETURNING user_id
1087 ",
1088 )
1089 .bind(room_id)
1090 .bind(connection_id.0 as i32)
1091 .fetch_one(&mut tx)
1092 .await?;
1093
1094 // Cancel pending calls initiated by the leaving user.
1095 let canceled_calls_to_user_ids: Vec<UserId> = sqlx::query_scalar(
1096 "
1097 DELETE FROM room_participants
1098 WHERE calling_user_id = $1 AND connection_id IS NULL
1099 RETURNING user_id
1100 ",
1101 )
1102 .bind(room_id)
1103 .bind(connection_id.0 as i32)
1104 .fetch_all(&mut tx)
1105 .await?;
1106
1107 let mut project_collaborators = sqlx::query_as::<_, ProjectCollaborator>(
1108 "
1109 SELECT project_collaborators.*
1110 FROM projects, project_collaborators
1111 WHERE
1112 projects.room_id = $1 AND
1113 projects.user_id = $2 AND
1114 projects.id = project_collaborators.project_id
1115 ",
1116 )
1117 .bind(room_id)
1118 .bind(user_id)
1119 .fetch(&mut tx);
1120
1121 let mut left_projects = HashMap::default();
1122 while let Some(collaborator) = project_collaborators.next().await {
1123 let collaborator = collaborator?;
1124 let left_project =
1125 left_projects
1126 .entry(collaborator.project_id)
1127 .or_insert(LeftProject {
1128 id: collaborator.project_id,
1129 host_user_id: Default::default(),
1130 connection_ids: Default::default(),
1131 });
1132
1133 let collaborator_connection_id = ConnectionId(collaborator.connection_id as u32);
1134 if collaborator_connection_id != connection_id || collaborator.is_host {
1135 left_project.connection_ids.push(collaborator_connection_id);
1136 }
1137
1138 if collaborator.is_host {
1139 left_project.host_user_id = collaborator.user_id;
1140 }
1141 }
1142 drop(project_collaborators);
1143
1144 sqlx::query(
1145 "
1146 DELETE FROM projects
1147 WHERE room_id = $1 AND user_id = $2
1148 ",
1149 )
1150 .bind(room_id)
1151 .bind(user_id)
1152 .execute(&mut tx)
1153 .await?;
1154
1155 let room = self.commit_room_transaction(room_id, tx).await?;
1156 Ok(LeftRoom {
1157 room,
1158 left_projects,
1159 canceled_calls_to_user_ids,
1160 })
1161 })
1162 }
1163
1164 pub async fn update_room_participant_location(
1165 &self,
1166 room_id: RoomId,
1167 user_id: UserId,
1168 location: proto::ParticipantLocation,
1169 ) -> Result<proto::Room> {
1170 test_support!(self, {
1171 let mut tx = self.pool.begin().await?;
1172
1173 let location_kind;
1174 let location_project_id;
1175 match location
1176 .variant
1177 .ok_or_else(|| anyhow!("invalid location"))?
1178 {
1179 proto::participant_location::Variant::SharedProject(project) => {
1180 location_kind = 0;
1181 location_project_id = Some(ProjectId::from_proto(project.id));
1182 }
1183 proto::participant_location::Variant::UnsharedProject(_) => {
1184 location_kind = 1;
1185 location_project_id = None;
1186 }
1187 proto::participant_location::Variant::External(_) => {
1188 location_kind = 2;
1189 location_project_id = None;
1190 }
1191 }
1192
1193 sqlx::query(
1194 "
1195 UPDATE room_participants
1196 SET location_kind = $1 AND location_project_id = $2
1197 WHERE room_id = $1 AND user_id = $2
1198 ",
1199 )
1200 .bind(location_kind)
1201 .bind(location_project_id)
1202 .bind(room_id)
1203 .bind(user_id)
1204 .execute(&mut tx)
1205 .await?;
1206
1207 self.commit_room_transaction(room_id, tx).await
1208 })
1209 }
1210
1211 async fn commit_room_transaction(
1212 &self,
1213 room_id: RoomId,
1214 mut tx: sqlx::Transaction<'_, D>,
1215 ) -> Result<proto::Room> {
1216 sqlx::query(
1217 "
1218 UPDATE rooms
1219 SET version = version + 1
1220 WHERE id = $1
1221 ",
1222 )
1223 .bind(room_id)
1224 .execute(&mut tx)
1225 .await?;
1226 let room = self.get_room(room_id, &mut tx).await?;
1227 tx.commit().await?;
1228
1229 Ok(room)
1230 }
1231
1232 async fn get_room(
1233 &self,
1234 room_id: RoomId,
1235 tx: &mut sqlx::Transaction<'_, D>,
1236 ) -> Result<proto::Room> {
1237 let room: Room = sqlx::query_as(
1238 "
1239 SELECT *
1240 FROM rooms
1241 WHERE id = $1
1242 ",
1243 )
1244 .bind(room_id)
1245 .fetch_one(&mut *tx)
1246 .await?;
1247
1248 let mut db_participants =
1249 sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<ProjectId>, UserId, Option<ProjectId>)>(
1250 "
1251 SELECT user_id, connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
1252 FROM room_participants
1253 WHERE room_id = $1
1254 ",
1255 )
1256 .bind(room_id)
1257 .fetch(&mut *tx);
1258
1259 let mut participants = Vec::new();
1260 let mut pending_participants = Vec::new();
1261 while let Some(participant) = db_participants.next().await {
1262 let (
1263 user_id,
1264 connection_id,
1265 _location_kind,
1266 _location_project_id,
1267 calling_user_id,
1268 initial_project_id,
1269 ) = participant?;
1270 if let Some(connection_id) = connection_id {
1271 participants.push(proto::Participant {
1272 user_id: user_id.to_proto(),
1273 peer_id: connection_id as u32,
1274 projects: Default::default(),
1275 location: Some(proto::ParticipantLocation {
1276 variant: Some(proto::participant_location::Variant::External(
1277 Default::default(),
1278 )),
1279 }),
1280 });
1281 } else {
1282 pending_participants.push(proto::PendingParticipant {
1283 user_id: user_id.to_proto(),
1284 calling_user_id: calling_user_id.to_proto(),
1285 initial_project_id: initial_project_id.map(|id| id.to_proto()),
1286 });
1287 }
1288 }
1289 drop(db_participants);
1290
1291 for participant in &mut participants {
1292 let mut entries = sqlx::query_as::<_, (ProjectId, String)>(
1293 "
1294 SELECT projects.id, worktrees.root_name
1295 FROM projects
1296 LEFT JOIN worktrees ON projects.id = worktrees.project_id
1297 WHERE room_id = $1 AND host_user_id = $2
1298 ",
1299 )
1300 .bind(room_id)
1301 .fetch(&mut *tx);
1302
1303 let mut projects = HashMap::default();
1304 while let Some(entry) = entries.next().await {
1305 let (project_id, worktree_root_name) = entry?;
1306 let participant_project =
1307 projects
1308 .entry(project_id)
1309 .or_insert(proto::ParticipantProject {
1310 id: project_id.to_proto(),
1311 worktree_root_names: Default::default(),
1312 });
1313 participant_project
1314 .worktree_root_names
1315 .push(worktree_root_name);
1316 }
1317
1318 participant.projects = projects.into_values().collect();
1319 }
1320 Ok(proto::Room {
1321 id: room.id.to_proto(),
1322 version: room.version as u64,
1323 live_kit_room: room.live_kit_room,
1324 participants,
1325 pending_participants,
1326 })
1327 }
1328
1329 // projects
1330
1331 pub async fn share_project(
1332 &self,
1333 user_id: UserId,
1334 connection_id: ConnectionId,
1335 room_id: RoomId,
1336 worktrees: &[proto::WorktreeMetadata],
1337 ) -> Result<(ProjectId, proto::Room)> {
1338 test_support!(self, {
1339 let mut tx = self.pool.begin().await?;
1340 let project_id = sqlx::query_scalar(
1341 "
1342 INSERT INTO projects (host_user_id, room_id)
1343 VALUES ($1)
1344 RETURNING id
1345 ",
1346 )
1347 .bind(user_id)
1348 .bind(room_id)
1349 .fetch_one(&mut tx)
1350 .await
1351 .map(ProjectId)?;
1352
1353 for worktree in worktrees {
1354 sqlx::query(
1355 "
1356 INSERT INTO worktrees (id, project_id, root_name)
1357 ",
1358 )
1359 .bind(worktree.id as i32)
1360 .bind(project_id)
1361 .bind(&worktree.root_name)
1362 .execute(&mut tx)
1363 .await?;
1364 }
1365
1366 sqlx::query(
1367 "
1368 INSERT INTO project_collaborators (
1369 project_id,
1370 connection_id,
1371 user_id,
1372 replica_id,
1373 is_host
1374 )
1375 VALUES ($1, $2, $3, $4, $5)
1376 ",
1377 )
1378 .bind(project_id)
1379 .bind(connection_id.0 as i32)
1380 .bind(user_id)
1381 .bind(0)
1382 .bind(true)
1383 .execute(&mut tx)
1384 .await?;
1385
1386 let room = self.commit_room_transaction(room_id, tx).await?;
1387 Ok((project_id, room))
1388 })
1389 }
1390
    /// Stops sharing `project_id`.
    ///
    /// Not implemented yet.
    ///
    /// # Panics
    /// Always panics via `todo!()`.
    pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> {
        todo!()
        // NOTE(review): earlier implementation kept for reference. It flagged the project as
        // `unregistered` instead of deleting it; confirm this still matches the current schema
        // before reviving it.
        // test_support!(self, {
        //     sqlx::query(
        //         "
        //         UPDATE projects
        //         SET unregistered = TRUE
        //         WHERE id = $1
        //         ",
        //     )
        //     .bind(project_id)
        //     .execute(&self.pool)
        //     .await?;
        //     Ok(())
        // })
    }
1407
1408 // contacts
1409
1410 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1411 test_support!(self, {
1412 let query = "
1413 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
1414 FROM contacts
1415 WHERE user_id_a = $1 OR user_id_b = $1;
1416 ";
1417
1418 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
1419 .bind(user_id)
1420 .fetch(&self.pool);
1421
1422 let mut contacts = Vec::new();
1423 while let Some(row) = rows.next().await {
1424 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
1425
1426 if user_id_a == user_id {
1427 if accepted {
1428 contacts.push(Contact::Accepted {
1429 user_id: user_id_b,
1430 should_notify: should_notify && a_to_b,
1431 });
1432 } else if a_to_b {
1433 contacts.push(Contact::Outgoing { user_id: user_id_b })
1434 } else {
1435 contacts.push(Contact::Incoming {
1436 user_id: user_id_b,
1437 should_notify,
1438 });
1439 }
1440 } else if accepted {
1441 contacts.push(Contact::Accepted {
1442 user_id: user_id_a,
1443 should_notify: should_notify && !a_to_b,
1444 });
1445 } else if a_to_b {
1446 contacts.push(Contact::Incoming {
1447 user_id: user_id_a,
1448 should_notify,
1449 });
1450 } else {
1451 contacts.push(Contact::Outgoing { user_id: user_id_a });
1452 }
1453 }
1454
1455 contacts.sort_unstable_by_key(|contact| contact.user_id());
1456
1457 Ok(contacts)
1458 })
1459 }
1460
1461 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1462 test_support!(self, {
1463 let (id_a, id_b) = if user_id_1 < user_id_2 {
1464 (user_id_1, user_id_2)
1465 } else {
1466 (user_id_2, user_id_1)
1467 };
1468
1469 let query = "
1470 SELECT 1 FROM contacts
1471 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
1472 LIMIT 1
1473 ";
1474 Ok(sqlx::query_scalar::<_, i32>(query)
1475 .bind(id_a.0)
1476 .bind(id_b.0)
1477 .fetch_optional(&self.pool)
1478 .await?
1479 .is_some())
1480 })
1481 }
1482
1483 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1484 test_support!(self, {
1485 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1486 (sender_id, receiver_id, true)
1487 } else {
1488 (receiver_id, sender_id, false)
1489 };
1490 let query = "
1491 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1492 VALUES ($1, $2, $3, FALSE, TRUE)
1493 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1494 SET
1495 accepted = TRUE,
1496 should_notify = FALSE
1497 WHERE
1498 NOT contacts.accepted AND
1499 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1500 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1501 ";
1502 let result = sqlx::query(query)
1503 .bind(id_a.0)
1504 .bind(id_b.0)
1505 .bind(a_to_b)
1506 .execute(&self.pool)
1507 .await?;
1508
1509 if result.rows_affected() == 1 {
1510 Ok(())
1511 } else {
1512 Err(anyhow!("contact already requested"))?
1513 }
1514 })
1515 }
1516
1517 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1518 test_support!(self, {
1519 let (id_a, id_b) = if responder_id < requester_id {
1520 (responder_id, requester_id)
1521 } else {
1522 (requester_id, responder_id)
1523 };
1524 let query = "
1525 DELETE FROM contacts
1526 WHERE user_id_a = $1 AND user_id_b = $2;
1527 ";
1528 let result = sqlx::query(query)
1529 .bind(id_a.0)
1530 .bind(id_b.0)
1531 .execute(&self.pool)
1532 .await?;
1533
1534 if result.rows_affected() == 1 {
1535 Ok(())
1536 } else {
1537 Err(anyhow!("no such contact"))?
1538 }
1539 })
1540 }
1541
1542 pub async fn dismiss_contact_notification(
1543 &self,
1544 user_id: UserId,
1545 contact_user_id: UserId,
1546 ) -> Result<()> {
1547 test_support!(self, {
1548 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1549 (user_id, contact_user_id, true)
1550 } else {
1551 (contact_user_id, user_id, false)
1552 };
1553
1554 let query = "
1555 UPDATE contacts
1556 SET should_notify = FALSE
1557 WHERE
1558 user_id_a = $1 AND user_id_b = $2 AND
1559 (
1560 (a_to_b = $3 AND accepted) OR
1561 (a_to_b != $3 AND NOT accepted)
1562 );
1563 ";
1564
1565 let result = sqlx::query(query)
1566 .bind(id_a.0)
1567 .bind(id_b.0)
1568 .bind(a_to_b)
1569 .execute(&self.pool)
1570 .await?;
1571
1572 if result.rows_affected() == 0 {
1573 Err(anyhow!("no such contact request"))?;
1574 }
1575
1576 Ok(())
1577 })
1578 }
1579
1580 pub async fn respond_to_contact_request(
1581 &self,
1582 responder_id: UserId,
1583 requester_id: UserId,
1584 accept: bool,
1585 ) -> Result<()> {
1586 test_support!(self, {
1587 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1588 (responder_id, requester_id, false)
1589 } else {
1590 (requester_id, responder_id, true)
1591 };
1592 let result = if accept {
1593 let query = "
1594 UPDATE contacts
1595 SET accepted = TRUE, should_notify = TRUE
1596 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1597 ";
1598 sqlx::query(query)
1599 .bind(id_a.0)
1600 .bind(id_b.0)
1601 .bind(a_to_b)
1602 .execute(&self.pool)
1603 .await?
1604 } else {
1605 let query = "
1606 DELETE FROM contacts
1607 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1608 ";
1609 sqlx::query(query)
1610 .bind(id_a.0)
1611 .bind(id_b.0)
1612 .bind(a_to_b)
1613 .execute(&self.pool)
1614 .await?
1615 };
1616 if result.rows_affected() == 1 {
1617 Ok(())
1618 } else {
1619 Err(anyhow!("no such contact request"))?
1620 }
1621 })
1622 }
1623
1624 // access tokens
1625
1626 pub async fn create_access_token_hash(
1627 &self,
1628 user_id: UserId,
1629 access_token_hash: &str,
1630 max_access_token_count: usize,
1631 ) -> Result<()> {
1632 test_support!(self, {
1633 let insert_query = "
1634 INSERT INTO access_tokens (user_id, hash)
1635 VALUES ($1, $2);
1636 ";
1637 let cleanup_query = "
1638 DELETE FROM access_tokens
1639 WHERE id IN (
1640 SELECT id from access_tokens
1641 WHERE user_id = $1
1642 ORDER BY id DESC
1643 LIMIT 10000
1644 OFFSET $3
1645 )
1646 ";
1647
1648 let mut tx = self.pool.begin().await?;
1649 sqlx::query(insert_query)
1650 .bind(user_id.0)
1651 .bind(access_token_hash)
1652 .execute(&mut tx)
1653 .await?;
1654 sqlx::query(cleanup_query)
1655 .bind(user_id.0)
1656 .bind(access_token_hash)
1657 .bind(max_access_token_count as i32)
1658 .execute(&mut tx)
1659 .await?;
1660 Ok(tx.commit().await?)
1661 })
1662 }
1663
1664 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1665 test_support!(self, {
1666 let query = "
1667 SELECT hash
1668 FROM access_tokens
1669 WHERE user_id = $1
1670 ORDER BY id DESC
1671 ";
1672 Ok(sqlx::query_scalar(query)
1673 .bind(user_id.0)
1674 .fetch_all(&self.pool)
1675 .await?)
1676 })
1677 }
1678}
1679
/// Declares a newtype identifier wrapping an `i32` database key.
///
/// The generated type is `#[sqlx(transparent)]` and `#[serde(transparent)]`, so it is bound,
/// decoded, and serialized as the bare integer. It also gets:
/// - `MAX`: the largest representable id;
/// - `from_proto` / `to_proto`: `as`-casts between the inner `i32` and the `u64` used on the
///   wire (no range checking);
/// - a `Display` impl that formats exactly like the inner integer.
macro_rules! id_type {
    ($id_name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $id_name(pub i32);

        impl $id_name {
            #[allow(unused)]
            pub const MAX: Self = $id_name(i32::MAX);

            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $id_name(raw)
            }

            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                let $id_name(raw) = self;
                raw as u64
            }
        }

        impl std::fmt::Display for $id_name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Delegate so width/padding flags behave exactly as for a plain `i32`.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1722
// `UserId`: `i32`-backed id newtype generated by `id_type!`.
id_type!(UserId);
/// A user record, decoded from query rows by column name via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    pub invite_code: Option<String>,
    pub invite_count: i32,
    pub connected_once: bool,
}
1735
// `RoomId`: `i32`-backed id newtype generated by `id_type!`.
id_type!(RoomId);
/// A room row, decoded via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Room {
    pub id: RoomId,
    /// Copied (as `u64`) into `proto::Room::version` when the room state is built.
    pub version: i32,
    /// Copied into `proto::Room::live_kit_room` when the room state is built.
    pub live_kit_room: String,
}
1743
/// A call from `calling_user_id` to `called_user_id` within `room_id`, decoded via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
pub struct Call {
    pub room_id: RoomId,
    pub calling_user_id: UserId,
    pub called_user_id: UserId,
    // NOTE(review): presumably `None` until the called user answers; confirm against callers.
    pub answering_connection_id: Option<i32>,
    pub initial_project_id: Option<ProjectId>,
}
1752
// `ProjectId`: `i32`-backed id newtype generated by `id_type!`.
id_type!(ProjectId);
/// A project row, decoded via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    // NOTE(review): the commented-out `unshare_project` body set this column to TRUE;
    // confirm whether it is still maintained.
    pub unregistered: bool,
}
1760
/// A connection participating in a project.
///
/// Mirrors the columns that `share_project` inserts into `project_collaborators`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
pub struct ProjectCollaborator {
    pub project_id: ProjectId,
    pub connection_id: i32,
    pub user_id: UserId,
    /// `share_project` registers the host with replica id `0`.
    pub replica_id: i32,
    pub is_host: bool,
}
1769
/// A project affected by a connection leaving a room.
// NOTE(review): presumably `connection_ids` are the remaining collaborators to notify; confirm
// against the code that builds this value.
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub connection_ids: Vec<ConnectionId>,
}
1775
/// Outcome of a participant leaving a room: the updated room state, the projects that were
/// affected (keyed by project id), and the users whose pending calls were canceled.
// NOTE(review): field meanings inferred from names; confirm against the leave-room code.
pub struct LeftRoom {
    pub room: proto::Room,
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
1781
/// A contact relationship, seen from one user's point of view (as built by `get_contacts`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both users accepted. `should_notify` is only ever true for the user who sent the
    /// original request.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A pending request this user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request `user_id` sent to this user.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1796
1797impl Contact {
1798 pub fn user_id(&self) -> UserId {
1799 match self {
1800 Contact::Accepted { user_id, .. } => *user_id,
1801 Contact::Outgoing { user_id } => *user_id,
1802 Contact::Incoming { user_id, .. } => *user_id,
1803 }
1804 }
1805}
1806
/// A pending contact request addressed to the current user.
///
/// The derived ordering compares `requester_id` first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1812
/// Waitlist signup payload, deserialized from incoming requests.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
1823
/// Aggregate waitlist counts.
///
/// Every field is `#[sqlx(default)]`: when a query does not return a column, `FromRow` leaves
/// the field at its default of `0` instead of failing.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1837
/// An invite: the address it targets and the confirmation code sent to that address.
// NOTE(review): `random_email_confirmation_code` presumably generates the code; confirm.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1843
/// Parameters for creating a new user account.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
1850
/// Result of creating a new user.
// NOTE(review): presumably `inviting_user_id` / `signup_device_id` are `None` when the user
// was not created from an invite or signup; confirm against the creating code.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
1858
1859fn random_invite_code() -> String {
1860 nanoid::nanoid!(16)
1861}
1862
1863fn random_email_confirmation_code() -> String {
1864 nanoid::nanoid!(64)
1865}
1866
// In test builds, re-export the test fixtures (`SqliteTestDb`, `PostgresTestDb`) so they are
// reachable through this module's path.
#[cfg(test)]
pub use test::*;
1869
/// Test fixtures that stand up throwaway databases with migrations applied.
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// An in-memory SQLite database for tests.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        // A connection detached from the pool and held for the fixture's lifetime.
        // NOTE(review): presumably this keeps the shared-cache in-memory database alive
        // independently of pool connection churn; confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A freshly created Postgres database for tests, torn down on drop.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        pub url: String,
    }

    impl SqliteTestDb {
        /// Creates a uniquely named in-memory SQLite database and applies
        /// `migrations.sqlite`.
        ///
        /// A dedicated current-thread tokio runtime and `background` are attached to the
        /// `Db` so `test_support!` can block on queries and inject random delays.
        ///
        /// # Panics
        /// Panics if the runtime, connection, or migrations fail.
        pub fn new(background: Arc<Background>) -> Self {
            let mut rng = StdRng::from_entropy();
            // Random name so concurrent tests never share a database.
            let url = format!("file:zed-test-{}?mode=memory", rng.gen::<u128>());
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(migrations_path.as_ref(), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            // Installed after setup; `runtime` is moved in here, so it can no longer be used
            // directly above this point.
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the wrapped database.
        ///
        /// # Panics
        /// Panics if `db` has been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates a uniquely named Postgres database on `localhost` and applies
        /// `migrations`.
        ///
        /// A dedicated current-thread tokio runtime and `background` are attached to the
        /// `Db` so `test_support!` can block on queries and inject random delays.
        ///
        /// # Panics
        /// Panics if database creation, connection, or migrations fail.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Held until the end of this function, so database creation and migration are
            // serialized across concurrently constructed fixtures.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the wrapped database.
        ///
        /// # Panics
        /// Panics if `db` has been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        // NOTE(review): `teardown` is defined elsewhere; presumably it drops the database
        // created in `new`. Confirm.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            db.teardown(&self.url);
        }
    }
}