1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::HashMap;
5use futures::StreamExt;
6use serde::{Deserialize, Serialize};
7use sqlx::{
8 migrate::{Migrate as _, Migration, MigrationSource},
9 types::Uuid,
10 FromRow,
11};
12use std::{path::Path, time::Duration};
13use time::{OffsetDateTime, PrimitiveDateTime};
14
/// Default database handle in test builds: a [`Db`] backed by SQLite.
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// Default database handle in non-test builds: a [`Db`] backed by Postgres.
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;
20
/// A connection pool for database `D`, plus test-only scheduling hooks.
pub struct Db<D: sqlx::Database> {
    /// Pool that the queries in this file execute against.
    pool: sqlx::Pool<D>,
    /// Test-only executor; when set, `test_support!` awaits its
    /// `simulate_random_delay()` before running a query body.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only Tokio runtime that `test_support!` blocks on to run query bodies.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
28
/// Wraps the body of a database method so tests can drive it synchronously.
///
/// `$self` is the `Db` receiver and the braced tokens become the body of an
/// `async` block. In non-test builds that block is simply awaited. In test
/// builds the macro first awaits `simulate_random_delay()` on
/// `$self.background` (when it is set) and then runs the block to completion
/// with `block_on` on `$self.runtime`, which is unwrapped and so must be `Some`.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };

        if cfg!(test) {
            // `cfg!(test)` is false in non-test builds, so this branch is never
            // taken there. The attribute-gated statements keep the test-only
            // fields from being referenced in builds where they do not exist.
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            // Run the query body on the dedicated Tokio runtime.
            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(body)
        } else {
            body.await
        }
    }};
}
51
/// Backend-independent access to the number of rows a statement affected.
///
/// Used as a bound on `D::QueryResult` so generic `Db<D>` methods can inspect
/// the outcome of an `execute` call.
pub trait RowsAffected {
    /// Returns the number of rows affected by the executed statement.
    fn rows_affected(&self) -> u64;
}
55
#[cfg(test)]
impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    /// Forwards to the inherent `SqliteQueryResult::rows_affected`.
    fn rows_affected(&self) -> u64 {
        // Spelled as a path call so it is obvious this is the inherent method
        // and not a recursive call into this trait method.
        sqlx::sqlite::SqliteQueryResult::rows_affected(self)
    }
}
62
63impl RowsAffected for sqlx::postgres::PgQueryResult {
64 fn rows_affected(&self) -> u64 {
65 self.rows_affected()
66 }
67}
68
69#[cfg(test)]
70impl Db<sqlx::Sqlite> {
71 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
72 use std::str::FromStr as _;
73 let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
74 .unwrap()
75 .create_if_missing(true)
76 .shared_cache(true);
77 let pool = sqlx::sqlite::SqlitePoolOptions::new()
78 .min_connections(2)
79 .max_connections(max_connections)
80 .connect_with(options)
81 .await?;
82 Ok(Self {
83 pool,
84 background: None,
85 runtime: None,
86 })
87 }
88
89 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
90 test_support!(self, {
91 let query = "
92 SELECT users.*
93 FROM users
94 WHERE users.id IN (SELECT value from json_each($1))
95 ";
96 Ok(sqlx::query_as(query)
97 .bind(&serde_json::json!(ids))
98 .fetch_all(&self.pool)
99 .await?)
100 })
101 }
102
103 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
104 test_support!(self, {
105 let query = "
106 SELECT metrics_id
107 FROM users
108 WHERE id = $1
109 ";
110 Ok(sqlx::query_scalar(query)
111 .bind(id)
112 .fetch_one(&self.pool)
113 .await?)
114 })
115 }
116
117 pub async fn create_user(
118 &self,
119 email_address: &str,
120 admin: bool,
121 params: NewUserParams,
122 ) -> Result<NewUserResult> {
123 test_support!(self, {
124 let query = "
125 INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
126 VALUES ($1, $2, $3, $4, $5)
127 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
128 RETURNING id, metrics_id
129 ";
130
131 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
132 .bind(email_address)
133 .bind(params.github_login)
134 .bind(params.github_user_id)
135 .bind(admin)
136 .bind(Uuid::new_v4().to_string())
137 .fetch_one(&self.pool)
138 .await?;
139 Ok(NewUserResult {
140 user_id,
141 metrics_id,
142 signup_device_id: None,
143 inviting_user_id: None,
144 })
145 })
146 }
147
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    /// Always panics via `unimplemented!()`.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }
151
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    /// Always panics via `unimplemented!()`.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }
159
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    /// Always panics via `unimplemented!()`.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }
163
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    /// Always panics via `unimplemented!()`.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }
172
    /// Not implemented for the SQLite backend.
    ///
    /// # Panics
    /// Always panics via `unimplemented!()`.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
176}
177
178impl Db<sqlx::Postgres> {
179 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
180 let pool = sqlx::postgres::PgPoolOptions::new()
181 .max_connections(max_connections)
182 .connect(url)
183 .await?;
184 Ok(Self {
185 pool,
186 #[cfg(test)]
187 background: None,
188 #[cfg(test)]
189 runtime: None,
190 })
191 }
192
193 #[cfg(test)]
194 pub fn teardown(&self, url: &str) {
195 self.runtime.as_ref().unwrap().block_on(async {
196 use util::ResultExt;
197 let query = "
198 SELECT pg_terminate_backend(pg_stat_activity.pid)
199 FROM pg_stat_activity
200 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
201 ";
202 sqlx::query(query).execute(&self.pool).await.log_err();
203 self.pool.close().await;
204 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
205 .await
206 .log_err();
207 })
208 }
209
210 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
211 test_support!(self, {
212 let like_string = Self::fuzzy_like_string(name_query);
213 let query = "
214 SELECT users.*
215 FROM users
216 WHERE github_login ILIKE $1
217 ORDER BY github_login <-> $2
218 LIMIT $3
219 ";
220 Ok(sqlx::query_as(query)
221 .bind(like_string)
222 .bind(name_query)
223 .bind(limit as i32)
224 .fetch_all(&self.pool)
225 .await?)
226 })
227 }
228
229 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
230 test_support!(self, {
231 let query = "
232 SELECT users.*
233 FROM users
234 WHERE users.id = ANY ($1)
235 ";
236 Ok(sqlx::query_as(query)
237 .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
238 .fetch_all(&self.pool)
239 .await?)
240 })
241 }
242
243 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
244 test_support!(self, {
245 let query = "
246 SELECT metrics_id::text
247 FROM users
248 WHERE id = $1
249 ";
250 Ok(sqlx::query_scalar(query)
251 .bind(id)
252 .fetch_one(&self.pool)
253 .await?)
254 })
255 }
256
257 pub async fn create_user(
258 &self,
259 email_address: &str,
260 admin: bool,
261 params: NewUserParams,
262 ) -> Result<NewUserResult> {
263 test_support!(self, {
264 let query = "
265 INSERT INTO users (email_address, github_login, github_user_id, admin)
266 VALUES ($1, $2, $3, $4)
267 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
268 RETURNING id, metrics_id::text
269 ";
270
271 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
272 .bind(email_address)
273 .bind(params.github_login)
274 .bind(params.github_user_id)
275 .bind(admin)
276 .fetch_one(&self.pool)
277 .await?;
278 Ok(NewUserResult {
279 user_id,
280 metrics_id,
281 signup_device_id: None,
282 inviting_user_id: None,
283 })
284 })
285 }
286
    /// Redeems `invite` by creating (or updating) the user described by `user`.
    ///
    /// Everything runs in one transaction:
    /// 1. look up the signup matching the invite's email address and
    ///    confirmation code, failing with HTTP 404 if there is none;
    /// 2. return `Ok(None)` if that signup is already linked to a user;
    /// 3. upsert the user row (keyed on `github_login`) with a fresh invite code;
    /// 4. link the signup to the new user;
    /// 5. if the signup recorded an inviting user, insert an accepted contact
    ///    between the inviter and the new user (ignored on conflict).
    ///
    /// Returns the new user's ids together with the signup's inviter and device id.
    pub async fn create_user_from_invite(
        &self,
        invite: &Invite,
        user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        test_support!(self, {
            let mut tx = self.pool.begin().await?;

            let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
                i32,
                Option<UserId>,
                Option<UserId>,
                Option<String>,
            ) = sqlx::query_as(
                "
                SELECT id, user_id, inviting_user_id, device_id
                FROM signups
                WHERE
                    email_address = $1 AND
                    email_confirmation_code = $2
                ",
            )
            .bind(&invite.email_address)
            .bind(&invite.email_confirmation_code)
            .fetch_optional(&mut tx)
            .await?
            .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;

            // The invite was already redeemed; returning here drops `tx`
            // without committing.
            if existing_user_id.is_some() {
                return Ok(None);
            }

            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
                "
                INSERT INTO users
                (email_address, github_login, github_user_id, admin, invite_count, invite_code)
                VALUES
                ($1, $2, $3, FALSE, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET
                    email_address = excluded.email_address,
                    github_user_id = excluded.github_user_id,
                    admin = excluded.admin
                RETURNING id, metrics_id::text
                ",
            )
            .bind(&invite.email_address)
            .bind(&user.github_login)
            .bind(&user.github_user_id)
            .bind(&user.invite_count)
            .bind(random_invite_code())
            .fetch_one(&mut tx)
            .await?;

            // Mark the signup as consumed by this user.
            sqlx::query(
                "
                UPDATE signups
                SET user_id = $1
                WHERE id = $2
                ",
            )
            .bind(&user_id)
            .bind(&signup_id)
            .execute(&mut tx)
            .await?;

            if let Some(inviting_user_id) = inviting_user_id {
                // NOTE(review): the inviter is always written as `user_id_a`;
                // other methods order the pair so the smaller id is `user_id_a` —
                // confirm the inviter's id is always the smaller one here.
                sqlx::query(
                    "
                    INSERT INTO contacts
                        (user_id_a, user_id_b, a_to_b, should_notify, accepted)
                    VALUES
                        ($1, $2, TRUE, TRUE, TRUE)
                    ON CONFLICT DO NOTHING
                    ",
                )
                .bind(inviting_user_id)
                .bind(user_id)
                .execute(&mut tx)
                .await?;
            }

            tx.commit().await?;
            Ok(Some(NewUserResult {
                user_id,
                metrics_id,
                inviting_user_id,
                signup_device_id,
            }))
        })
    }
377
378 pub async fn create_signup(&self, signup: Signup) -> Result<()> {
379 test_support!(self, {
380 sqlx::query(
381 "
382 INSERT INTO signups
383 (
384 email_address,
385 email_confirmation_code,
386 email_confirmation_sent,
387 platform_linux,
388 platform_mac,
389 platform_windows,
390 platform_unknown,
391 editor_features,
392 programming_languages,
393 device_id
394 )
395 VALUES
396 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
397 RETURNING id
398 ",
399 )
400 .bind(&signup.email_address)
401 .bind(&random_email_confirmation_code())
402 .bind(&signup.platform_linux)
403 .bind(&signup.platform_mac)
404 .bind(&signup.platform_windows)
405 .bind(&signup.editor_features)
406 .bind(&signup.programming_languages)
407 .bind(&signup.device_id)
408 .execute(&self.pool)
409 .await?;
410 Ok(())
411 })
412 }
413
    /// Creates (or re-targets) a signup for `email_address` using another
    /// user's invite `code`, consuming one of that user's invites.
    ///
    /// Runs in a single transaction:
    /// 1. fail if a user with `email_address` already exists;
    /// 2. decrement `invite_count` of the user owning `code`, provided it is
    ///    still positive — otherwise fail with HTTP 401;
    /// 3. insert a signup for the address, or on conflict update its
    ///    `inviting_user_id`, and read back the stored confirmation code.
    ///
    /// Returns the `Invite` (address plus confirmation code) for the signup.
    pub async fn create_invite_from_code(
        &self,
        code: &str,
        email_address: &str,
        device_id: Option<&str>,
    ) -> Result<Invite> {
        test_support!(self, {
            let mut tx = self.pool.begin().await?;

            let existing_user: Option<UserId> = sqlx::query_scalar(
                "
                SELECT id
                FROM users
                WHERE email_address = $1
                ",
            )
            .bind(email_address)
            .fetch_optional(&mut tx)
            .await?;
            if existing_user.is_some() {
                Err(anyhow!("email address is already in use"))?;
            }

            // Decrement and look up in one statement: no row comes back when
            // the code is unknown or the inviter has no invites left.
            let inviting_user_id_with_invites: Option<UserId> = sqlx::query_scalar(
                "
                UPDATE users
                SET invite_count = invite_count - 1
                WHERE invite_code = $1 AND invite_count > 0
                RETURNING id
                ",
            )
            .bind(code)
            .fetch_optional(&mut tx)
            .await?;

            let Some(inviter_id) = inviting_user_id_with_invites else {
                return Err(Error::Http(
                    StatusCode::UNAUTHORIZED,
                    "unable to find an invite code with invites remaining".to_string(),
                ));
            };

            // On conflict only `inviting_user_id` changes, so the code read
            // back is the one already stored for an existing signup.
            let email_confirmation_code: String = sqlx::query_scalar(
                "
                INSERT INTO signups
                (
                    email_address,
                    email_confirmation_code,
                    email_confirmation_sent,
                    inviting_user_id,
                    platform_linux,
                    platform_mac,
                    platform_windows,
                    platform_unknown,
                    device_id
                )
                VALUES
                    ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
                ON CONFLICT (email_address)
                DO UPDATE SET
                    inviting_user_id = excluded.inviting_user_id
                RETURNING email_confirmation_code
                ",
            )
            .bind(&email_address)
            .bind(&random_email_confirmation_code())
            .bind(&inviter_id)
            .bind(&device_id)
            .fetch_one(&mut tx)
            .await?;

            tx.commit().await?;

            Ok(Invite {
                email_address: email_address.into(),
                email_confirmation_code,
            })
        })
    }
493
494 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
495 test_support!(self, {
496 let emails = invites
497 .iter()
498 .map(|s| s.email_address.as_str())
499 .collect::<Vec<_>>();
500 sqlx::query(
501 "
502 UPDATE signups
503 SET email_confirmation_sent = TRUE
504 WHERE email_address = ANY ($1)
505 ",
506 )
507 .bind(&emails)
508 .execute(&self.pool)
509 .await?;
510 Ok(())
511 })
512 }
513}
514
515impl<D> Db<D>
516where
517 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
518 D::Connection: sqlx::migrate::Migrate,
519 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
520 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
521 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
522 D::QueryResult: RowsAffected,
523 String: sqlx::Type<D>,
524 i32: sqlx::Type<D>,
525 i64: sqlx::Type<D>,
526 bool: sqlx::Type<D>,
527 str: sqlx::Type<D>,
528 Uuid: sqlx::Type<D>,
529 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
530 OffsetDateTime: sqlx::Type<D>,
531 PrimitiveDateTime: sqlx::Type<D>,
532 usize: sqlx::ColumnIndex<D::Row>,
533 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
534 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
535 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
536 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
537 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
538 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
539 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
540 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
541 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
542 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
543 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
544 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
545{
546 pub async fn migrate(
547 &self,
548 migrations_path: &Path,
549 ignore_checksum_mismatch: bool,
550 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
551 let migrations = MigrationSource::resolve(migrations_path)
552 .await
553 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
554
555 let mut conn = self.pool.acquire().await?;
556
557 conn.ensure_migrations_table().await?;
558 let applied_migrations: HashMap<_, _> = conn
559 .list_applied_migrations()
560 .await?
561 .into_iter()
562 .map(|m| (m.version, m))
563 .collect();
564
565 let mut new_migrations = Vec::new();
566 for migration in migrations {
567 match applied_migrations.get(&migration.version) {
568 Some(applied_migration) => {
569 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
570 {
571 Err(anyhow!(
572 "checksum mismatch for applied migration {}",
573 migration.description
574 ))?;
575 }
576 }
577 None => {
578 let elapsed = conn.apply(&migration).await?;
579 new_migrations.push((migration, elapsed));
580 }
581 }
582 }
583
584 Ok(new_migrations)
585 }
586
587 pub fn fuzzy_like_string(string: &str) -> String {
588 let mut result = String::with_capacity(string.len() * 2 + 1);
589 for c in string.chars() {
590 if c.is_alphanumeric() {
591 result.push('%');
592 result.push(c);
593 }
594 }
595 result.push('%');
596 result
597 }
598
599 // users
600
601 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
602 test_support!(self, {
603 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
604 Ok(sqlx::query_as(query)
605 .bind(limit as i32)
606 .bind((page * limit) as i32)
607 .fetch_all(&self.pool)
608 .await?)
609 })
610 }
611
612 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
613 test_support!(self, {
614 let query = "
615 SELECT users.*
616 FROM users
617 WHERE id = $1
618 LIMIT 1
619 ";
620 Ok(sqlx::query_as(query)
621 .bind(&id)
622 .fetch_optional(&self.pool)
623 .await?)
624 })
625 }
626
627 pub async fn get_users_with_no_invites(
628 &self,
629 invited_by_another_user: bool,
630 ) -> Result<Vec<User>> {
631 test_support!(self, {
632 let query = format!(
633 "
634 SELECT users.*
635 FROM users
636 WHERE invite_count = 0
637 AND inviter_id IS{} NULL
638 ",
639 if invited_by_another_user { " NOT" } else { "" }
640 );
641
642 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
643 })
644 }
645
646 pub async fn get_user_by_github_account(
647 &self,
648 github_login: &str,
649 github_user_id: Option<i32>,
650 ) -> Result<Option<User>> {
651 test_support!(self, {
652 if let Some(github_user_id) = github_user_id {
653 let mut user = sqlx::query_as::<_, User>(
654 "
655 UPDATE users
656 SET github_login = $1
657 WHERE github_user_id = $2
658 RETURNING *
659 ",
660 )
661 .bind(github_login)
662 .bind(github_user_id)
663 .fetch_optional(&self.pool)
664 .await?;
665
666 if user.is_none() {
667 user = sqlx::query_as::<_, User>(
668 "
669 UPDATE users
670 SET github_user_id = $1
671 WHERE github_login = $2
672 RETURNING *
673 ",
674 )
675 .bind(github_user_id)
676 .bind(github_login)
677 .fetch_optional(&self.pool)
678 .await?;
679 }
680
681 Ok(user)
682 } else {
683 let user = sqlx::query_as(
684 "
685 SELECT * FROM users
686 WHERE github_login = $1
687 LIMIT 1
688 ",
689 )
690 .bind(github_login)
691 .fetch_optional(&self.pool)
692 .await?;
693 Ok(user)
694 }
695 })
696 }
697
698 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
699 test_support!(self, {
700 let query = "UPDATE users SET admin = $1 WHERE id = $2";
701 Ok(sqlx::query(query)
702 .bind(is_admin)
703 .bind(id.0)
704 .execute(&self.pool)
705 .await
706 .map(drop)?)
707 })
708 }
709
710 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
711 test_support!(self, {
712 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
713 Ok(sqlx::query(query)
714 .bind(connected_once)
715 .bind(id.0)
716 .execute(&self.pool)
717 .await
718 .map(drop)?)
719 })
720 }
721
722 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
723 test_support!(self, {
724 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
725 sqlx::query(query)
726 .bind(id.0)
727 .execute(&self.pool)
728 .await
729 .map(drop)?;
730 let query = "DELETE FROM users WHERE id = $1;";
731 Ok(sqlx::query(query)
732 .bind(id.0)
733 .execute(&self.pool)
734 .await
735 .map(drop)?)
736 })
737 }
738
739 // signups
740
741 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
742 test_support!(self, {
743 Ok(sqlx::query_as(
744 "
745 SELECT
746 COUNT(*) as count,
747 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
748 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
749 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
750 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
751 FROM (
752 SELECT *
753 FROM signups
754 WHERE
755 NOT email_confirmation_sent
756 ) AS unsent
757 ",
758 )
759 .fetch_one(&self.pool)
760 .await?)
761 })
762 }
763
764 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
765 test_support!(self, {
766 Ok(sqlx::query_as(
767 "
768 SELECT
769 email_address, email_confirmation_code
770 FROM signups
771 WHERE
772 NOT email_confirmation_sent AND
773 (platform_mac OR platform_unknown)
774 LIMIT $1
775 ",
776 )
777 .bind(count as i32)
778 .fetch_all(&self.pool)
779 .await?)
780 })
781 }
782
783 // invite codes
784
    /// Sets the number of invites user `id` may send.
    ///
    /// When `count` is positive and the user has no invite code yet, a new
    /// random code is assigned first. Both updates run in one transaction.
    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
        test_support!(self, {
            let mut tx = self.pool.begin().await?;
            if count > 0 {
                // `invite_code IS NULL` keeps an existing code untouched.
                sqlx::query(
                    "
                    UPDATE users
                    SET invite_code = $1
                    WHERE id = $2 AND invite_code IS NULL
                    ",
                )
                .bind(random_invite_code())
                .bind(id)
                .execute(&mut tx)
                .await?;
            }

            // NOTE(review): `count as i32` wraps for values above `i32::MAX`.
            sqlx::query(
                "
                UPDATE users
                SET invite_count = $1
                WHERE id = $2
                ",
            )
            .bind(count as i32)
            .bind(id)
            .execute(&mut tx)
            .await?;
            tx.commit().await?;
            Ok(())
        })
    }
817
818 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
819 test_support!(self, {
820 let result: Option<(String, i32)> = sqlx::query_as(
821 "
822 SELECT invite_code, invite_count
823 FROM users
824 WHERE id = $1 AND invite_code IS NOT NULL
825 ",
826 )
827 .bind(id)
828 .fetch_optional(&self.pool)
829 .await?;
830 if let Some((code, count)) = result {
831 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
832 } else {
833 Ok(None)
834 }
835 })
836 }
837
838 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
839 test_support!(self, {
840 sqlx::query_as(
841 "
842 SELECT *
843 FROM users
844 WHERE invite_code = $1
845 ",
846 )
847 .bind(code)
848 .fetch_optional(&self.pool)
849 .await?
850 .ok_or_else(|| {
851 Error::Http(
852 StatusCode::NOT_FOUND,
853 "that invite code does not exist".to_string(),
854 )
855 })
856 })
857 }
858
859 // projects
860
861 /// Registers a new project for the given user.
862 pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
863 test_support!(self, {
864 Ok(sqlx::query_scalar(
865 "
866 INSERT INTO projects(host_user_id)
867 VALUES ($1)
868 RETURNING id
869 ",
870 )
871 .bind(host_user_id)
872 .fetch_one(&self.pool)
873 .await
874 .map(ProjectId)?)
875 })
876 }
877
878 /// Unregisters a project for the given project id.
879 pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
880 test_support!(self, {
881 sqlx::query(
882 "
883 UPDATE projects
884 SET unregistered = TRUE
885 WHERE id = $1
886 ",
887 )
888 .bind(project_id)
889 .execute(&self.pool)
890 .await?;
891 Ok(())
892 })
893 }
894
895 // contacts
896
897 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
898 test_support!(self, {
899 let query = "
900 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
901 FROM contacts
902 WHERE user_id_a = $1 OR user_id_b = $1;
903 ";
904
905 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
906 .bind(user_id)
907 .fetch(&self.pool);
908
909 let mut contacts = Vec::new();
910 while let Some(row) = rows.next().await {
911 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
912
913 if user_id_a == user_id {
914 if accepted {
915 contacts.push(Contact::Accepted {
916 user_id: user_id_b,
917 should_notify: should_notify && a_to_b,
918 });
919 } else if a_to_b {
920 contacts.push(Contact::Outgoing { user_id: user_id_b })
921 } else {
922 contacts.push(Contact::Incoming {
923 user_id: user_id_b,
924 should_notify,
925 });
926 }
927 } else if accepted {
928 contacts.push(Contact::Accepted {
929 user_id: user_id_a,
930 should_notify: should_notify && !a_to_b,
931 });
932 } else if a_to_b {
933 contacts.push(Contact::Incoming {
934 user_id: user_id_a,
935 should_notify,
936 });
937 } else {
938 contacts.push(Contact::Outgoing { user_id: user_id_a });
939 }
940 }
941
942 contacts.sort_unstable_by_key(|contact| contact.user_id());
943
944 Ok(contacts)
945 })
946 }
947
948 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
949 test_support!(self, {
950 let (id_a, id_b) = if user_id_1 < user_id_2 {
951 (user_id_1, user_id_2)
952 } else {
953 (user_id_2, user_id_1)
954 };
955
956 let query = "
957 SELECT 1 FROM contacts
958 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
959 LIMIT 1
960 ";
961 Ok(sqlx::query_scalar::<_, i32>(query)
962 .bind(id_a.0)
963 .bind(id_b.0)
964 .fetch_optional(&self.pool)
965 .await?
966 .is_some())
967 })
968 }
969
970 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
971 test_support!(self, {
972 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
973 (sender_id, receiver_id, true)
974 } else {
975 (receiver_id, sender_id, false)
976 };
977 let query = "
978 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
979 VALUES ($1, $2, $3, FALSE, TRUE)
980 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
981 SET
982 accepted = TRUE,
983 should_notify = FALSE
984 WHERE
985 NOT contacts.accepted AND
986 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
987 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
988 ";
989 let result = sqlx::query(query)
990 .bind(id_a.0)
991 .bind(id_b.0)
992 .bind(a_to_b)
993 .execute(&self.pool)
994 .await?;
995
996 if result.rows_affected() == 1 {
997 Ok(())
998 } else {
999 Err(anyhow!("contact already requested"))?
1000 }
1001 })
1002 }
1003
1004 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1005 test_support!(self, {
1006 let (id_a, id_b) = if responder_id < requester_id {
1007 (responder_id, requester_id)
1008 } else {
1009 (requester_id, responder_id)
1010 };
1011 let query = "
1012 DELETE FROM contacts
1013 WHERE user_id_a = $1 AND user_id_b = $2;
1014 ";
1015 let result = sqlx::query(query)
1016 .bind(id_a.0)
1017 .bind(id_b.0)
1018 .execute(&self.pool)
1019 .await?;
1020
1021 if result.rows_affected() == 1 {
1022 Ok(())
1023 } else {
1024 Err(anyhow!("no such contact"))?
1025 }
1026 })
1027 }
1028
1029 pub async fn dismiss_contact_notification(
1030 &self,
1031 user_id: UserId,
1032 contact_user_id: UserId,
1033 ) -> Result<()> {
1034 test_support!(self, {
1035 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1036 (user_id, contact_user_id, true)
1037 } else {
1038 (contact_user_id, user_id, false)
1039 };
1040
1041 let query = "
1042 UPDATE contacts
1043 SET should_notify = FALSE
1044 WHERE
1045 user_id_a = $1 AND user_id_b = $2 AND
1046 (
1047 (a_to_b = $3 AND accepted) OR
1048 (a_to_b != $3 AND NOT accepted)
1049 );
1050 ";
1051
1052 let result = sqlx::query(query)
1053 .bind(id_a.0)
1054 .bind(id_b.0)
1055 .bind(a_to_b)
1056 .execute(&self.pool)
1057 .await?;
1058
1059 if result.rows_affected() == 0 {
1060 Err(anyhow!("no such contact request"))?;
1061 }
1062
1063 Ok(())
1064 })
1065 }
1066
1067 pub async fn respond_to_contact_request(
1068 &self,
1069 responder_id: UserId,
1070 requester_id: UserId,
1071 accept: bool,
1072 ) -> Result<()> {
1073 test_support!(self, {
1074 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1075 (responder_id, requester_id, false)
1076 } else {
1077 (requester_id, responder_id, true)
1078 };
1079 let result = if accept {
1080 let query = "
1081 UPDATE contacts
1082 SET accepted = TRUE, should_notify = TRUE
1083 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1084 ";
1085 sqlx::query(query)
1086 .bind(id_a.0)
1087 .bind(id_b.0)
1088 .bind(a_to_b)
1089 .execute(&self.pool)
1090 .await?
1091 } else {
1092 let query = "
1093 DELETE FROM contacts
1094 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1095 ";
1096 sqlx::query(query)
1097 .bind(id_a.0)
1098 .bind(id_b.0)
1099 .bind(a_to_b)
1100 .execute(&self.pool)
1101 .await?
1102 };
1103 if result.rows_affected() == 1 {
1104 Ok(())
1105 } else {
1106 Err(anyhow!("no such contact request"))?
1107 }
1108 })
1109 }
1110
1111 // access tokens
1112
    /// Stores a new access-token hash for `user_id` and prunes old ones.
    ///
    /// In one transaction the hash is inserted, then the user's tokens are
    /// ordered by descending id and everything past the first
    /// `max_access_token_count` is deleted (up to 10000 rows per call).
    pub async fn create_access_token_hash(
        &self,
        user_id: UserId,
        access_token_hash: &str,
        max_access_token_count: usize,
    ) -> Result<()> {
        test_support!(self, {
            let insert_query = "
                INSERT INTO access_tokens (user_id, hash)
                VALUES ($1, $2);
            ";
            // `OFFSET $3` skips the tokens to keep; the sub-select yields the
            // ids of the older ones to delete.
            let cleanup_query = "
                DELETE FROM access_tokens
                WHERE id IN (
                    SELECT id from access_tokens
                    WHERE user_id = $1
                    ORDER BY id DESC
                    LIMIT 10000
                    OFFSET $3
                )
            ";

            let mut tx = self.pool.begin().await?;
            sqlx::query(insert_query)
                .bind(user_id.0)
                .bind(access_token_hash)
                .execute(&mut tx)
                .await?;
            // The hash is bound as `$2` only so the count lands on `$3`; the
            // cleanup statement itself never references `$2`.
            sqlx::query(cleanup_query)
                .bind(user_id.0)
                .bind(access_token_hash)
                .bind(max_access_token_count as i32)
                .execute(&mut tx)
                .await?;
            Ok(tx.commit().await?)
        })
    }
1150
1151 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1152 test_support!(self, {
1153 let query = "
1154 SELECT hash
1155 FROM access_tokens
1156 WHERE user_id = $1
1157 ORDER BY id DESC
1158 ";
1159 Ok(sqlx::query_scalar(query)
1160 .bind(user_id.0)
1161 .fetch_all(&self.pool)
1162 .await?)
1163 })
1164 }
1165}
1166
/// Declares a newtype id named `$name` wrapping an `i32`.
///
/// The generated type is `Copy`, ordered, hashable, transparent for both sqlx
/// and serde, displays as its inner integer, and gets `MAX`, `from_proto`
/// and `to_proto` helpers.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Converts a `u64` wire value into an id.
            ///
            /// NOTE(review): `as i32` truncates values outside the `i32` range.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts the id into a `u64` wire value.
            ///
            /// NOTE(review): `as u64` sign-extends, so a negative id becomes a
            /// very large `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            /// Formats the id as its inner integer.
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1209
// Newtype id for rows of the `users` table.
id_type!(UserId);
/// A row of the `users` table, as loaded through `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    /// Primary key.
    pub id: UserId,
    /// GitHub login; used as the conflict key when users are upserted.
    pub github_login: String,
    /// Numeric GitHub user id, when known.
    pub github_user_id: Option<i32>,
    /// Email address, when known.
    pub email_address: Option<String>,
    /// Whether the user is an admin.
    pub admin: bool,
    /// Code other people can use to sign up via this user, once assigned.
    pub invite_code: Option<String>,
    /// Number of invites the user has left.
    pub invite_count: i32,
    /// Flag written by `set_user_connected_once`.
    pub connected_once: bool,
}
1222
// Newtype id for rows of the `projects` table.
id_type!(ProjectId);
/// A row of the `projects` table, as loaded through `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    /// Primary key.
    pub id: ProjectId,
    /// The user hosting the project.
    pub host_user_id: UserId,
    /// Set to true by `unregister_project`.
    pub unregistered: bool,
}
1230
/// A contact relationship as seen from one particular user; `user_id` in
/// every variant is the other party.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The request has been accepted.
    Accepted {
        user_id: UserId,
        /// Whether the viewing user still has a notification pending for it.
        should_notify: bool,
    },
    /// A pending request sent by the viewing user.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request received by the viewing user.
    Incoming {
        user_id: UserId,
        /// The row's `should_notify` flag, passed through unchanged.
        should_notify: bool,
    },
}
1245
1246impl Contact {
1247 pub fn user_id(&self) -> UserId {
1248 match self {
1249 Contact::Accepted { user_id, .. } => *user_id,
1250 Contact::Outgoing { user_id } => *user_id,
1251 Contact::Incoming { user_id, .. } => *user_id,
1252 }
1253 }
1254}
1255
/// A contact request received from another user.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    /// The user who sent the request.
    pub requester_id: UserId,
    /// Whether a notification for the request is still pending.
    pub should_notify: bool,
}
1261
/// Waitlist signup data, as deserialized from a request and stored by
/// `create_signup`.
#[derive(Clone, Deserialize)]
pub struct Signup {
    /// Address the confirmation email is associated with.
    pub email_address: String,
    /// Platform flags stored in the matching `platform_*` columns.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    /// Stored in the `editor_features` column.
    pub editor_features: Vec<String>,
    /// Stored in the `programming_languages` column.
    pub programming_languages: Vec<String>,
    /// Optional device identifier stored with the signup.
    pub device_id: Option<String>,
}
1272
/// Counts of signups still waiting for their confirmation email, as produced
/// by `get_waitlist_summary`. Each field falls back to its default when the
/// column is absent from the row (`#[sqlx(default)]`).
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    /// Total number of unsent signups.
    #[sqlx(default)]
    pub count: i64,
    /// Unsent signups with `platform_linux` set.
    #[sqlx(default)]
    pub linux_count: i64,
    /// Unsent signups with `platform_mac` set.
    #[sqlx(default)]
    pub mac_count: i64,
    /// Unsent signups with `platform_windows` set.
    #[sqlx(default)]
    pub windows_count: i64,
    /// Unsent signups with `platform_unknown` set.
    #[sqlx(default)]
    pub unknown_count: i64,
}
1286
/// An invitation to sign up: the invited address plus the confirmation code
/// stored on the matching `signups` row.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    /// The invited email address.
    pub email_address: String,
    /// Code that must match the signup's `email_confirmation_code` to redeem.
    pub email_confirmation_code: String,
}
1292
/// GitHub identity and initial invite allowance for a user being created.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    /// GitHub login of the new user.
    pub github_login: String,
    /// Numeric GitHub user id of the new user.
    pub github_user_id: i32,
    /// Initial `invite_count`; written by `create_user_from_invite` only.
    pub invite_count: i32,
}
1299
/// What the user-creation methods report back about the created (or
/// pre-existing) user.
#[derive(Debug)]
pub struct NewUserResult {
    /// Id of the user row.
    pub user_id: UserId,
    /// The row's `metrics_id`, as text.
    pub metrics_id: String,
    /// The inviter recorded on the redeemed signup; `None` from `create_user`.
    pub inviting_user_id: Option<UserId>,
    /// The device id recorded on the redeemed signup; `None` from `create_user`.
    pub signup_device_id: Option<String>,
}
1307
/// Generates a random invite code via `nanoid!(16)`.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}
1311
/// Generates a random email confirmation code via `nanoid!(64)`.
fn random_email_confirmation_code() -> String {
    nanoid::nanoid!(64)
}
1315
1316#[cfg(test)]
1317pub use test::*;
1318
1319#[cfg(test)]
1320mod test {
1321 use super::*;
1322 use gpui::executor::Background;
1323 use lazy_static::lazy_static;
1324 use parking_lot::Mutex;
1325 use rand::prelude::*;
1326 use sqlx::migrate::MigrateDatabase;
1327 use std::sync::Arc;
1328
    /// An in-memory SQLite database for tests.
    pub struct SqliteTestDb {
        /// The database handle; always `Some` after `new`.
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// A connection detached from the pool during setup and kept for the
        /// lifetime of this value.
        pub conn: sqlx::sqlite::SqliteConnection,
    }
1333
    /// A throwaway Postgres database for tests; torn down on drop.
    pub struct PostgresTestDb {
        /// The database handle; `Some` until `drop` takes it.
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        /// URL of the randomly named test database, needed to drop it again.
        pub url: String,
    }
1338
    impl SqliteTestDb {
        /// Creates a uniquely named in-memory SQLite database, runs the
        /// migrations from `migrations.sqlite`, and wires in `background` and
        /// a dedicated current-thread Tokio runtime.
        ///
        /// # Panics
        /// Panics if the runtime, the connection, or the migrations fail.
        pub fn new(background: Arc<Background>) -> Self {
            let mut rng = StdRng::from_entropy();
            // A random name keeps the databases of different tests apart.
            let url = format!("file:zed-test-{}?mode=memory", rng.gen::<u128>());
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(migrations_path.as_ref(), false).await.unwrap();
                // Detach one connection so it stays open independently of the
                // pool (presumably to keep the in-memory database alive —
                // confirm).
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            // Set after construction because `Db::new` leaves both as `None`.
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the database handle.
        ///
        /// # Panics
        /// Panics if `db` is `None`.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }
1370
    impl PostgresTestDb {
        /// Creates a randomly named Postgres database on `localhost`, runs the
        /// migrations from `migrations`, and wires in `background` and a
        /// dedicated current-thread Tokio runtime.
        ///
        /// # Panics
        /// Panics if the runtime cannot be built, the database cannot be
        /// created or connected to, or the migrations fail.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Held until this function returns, so construction of test
            // databases is serialized across threads.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            // Set after construction because `Db::new` leaves both as `None`.
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the database handle.
        ///
        /// # Panics
        /// Panics if `db` is `None`.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }
1412
    impl Drop for PostgresTestDb {
        /// Takes the database handle and tears the test database down via
        /// `Db::teardown`.
        ///
        /// # Panics
        /// Panics if `db` has already been taken.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            db.teardown(&self.url);
        }
    }
1419}