1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::HashMap;
5use futures::StreamExt;
6use serde::{Deserialize, Serialize};
7use sqlx::{
8 migrate::{Migrate as _, Migration, MigrationSource},
9 types::Uuid,
10 FromRow,
11};
12use std::{path::Path, time::Duration};
13use time::{OffsetDateTime, PrimitiveDateTime};
14
/// Database handle used in test builds: backed by SQLite.
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// Database handle used in non-test builds: backed by Postgres.
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;
20
/// Database handle, generic over the sqlx backend `D`.
pub struct Db<D: sqlx::Database> {
    // Connection pool that every query in this module runs against.
    pool: sqlx::Pool<D>,
    // Test-only: when set, `test_support!` awaits `simulate_random_delay()` on it
    // before running each query body.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    // Test-only: runtime that `test_support!` uses to `block_on` query bodies;
    // it must be set before any `test_support!`-wrapped method is called.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
28
/// Wraps an async query body so the same method code works in tests and production.
///
/// In non-test builds the body is simply awaited. In test builds it first awaits
/// `simulate_random_delay()` on `$self.background` (if set), then drives the body to
/// completion with `block_on` on `$self.runtime`. That call panics if `runtime` is `None`.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };

        // `cfg!(test)` is an ordinary boolean, so both branches are compiled in
        // every build. The test-only statements are additionally gated with
        // `#[cfg]` so they disappear outside tests, where `unreachable!()` keeps
        // the (never-taken) branch well-typed.
        if cfg!(test) {
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(body)
        } else {
            body.await
        }
    }};
}
51
/// Backend-agnostic access to a statement's affected-row count, so the generic
/// `Db<D>` code can read it from `D::QueryResult`.
pub trait RowsAffected {
    /// Number of rows changed by the executed statement.
    fn rows_affected(&self) -> u64;
}
55
#[cfg(test)]
impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    /// Forwards to the inherent `SqliteQueryResult::rows_affected`.
    fn rows_affected(&self) -> u64 {
        // Path call on the concrete type resolves to the inherent method first,
        // which makes it explicit that this does not recurse into the trait.
        sqlx::sqlite::SqliteQueryResult::rows_affected(self)
    }
}
62
63impl RowsAffected for sqlx::postgres::PgQueryResult {
64 fn rows_affected(&self) -> u64 {
65 self.rows_affected()
66 }
67}
68
#[cfg(test)]
impl Db<sqlx::Sqlite> {
    /// Opens a shared-cache SQLite database at `url`, creating it if missing, with a
    /// pool of at least 2 and at most `max_connections` connections.
    ///
    /// # Errors
    /// Returns an error if `url` is not a valid SQLite connection string or if the
    /// pool fails to connect.
    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
        use std::str::FromStr as _;
        // A malformed URL is reported as an error, like a connection failure,
        // instead of panicking.
        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
            .create_if_missing(true)
            .shared_cache(true);
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(2)
            .max_connections(max_connections)
            .connect_with(options)
            .await?;
        Ok(Self {
            pool,
            background: None,
            runtime: None,
        })
    }

    /// Returns the users whose ids are in `ids`.
    ///
    /// The ids are passed as a single JSON array and expanded with `json_each`.
    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
        test_support!(self, {
            let query = "
                SELECT users.*
                FROM users
                WHERE users.id IN (SELECT value from json_each($1))
            ";
            Ok(sqlx::query_as(query)
                .bind(&serde_json::json!(ids))
                .fetch_all(&self.pool)
                .await?)
        })
    }

    /// Returns the `metrics_id` of user `id`; errors if no such user exists.
    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
        test_support!(self, {
            let query = "
                SELECT metrics_id
                FROM users
                WHERE id = $1
            ";
            Ok(sqlx::query_scalar(query)
                .bind(id)
                .fetch_one(&self.pool)
                .await?)
        })
    }

    /// Inserts a user and returns its id and metrics id.
    ///
    /// The metrics id is generated client-side as a random UUID string. If a user
    /// with the same `github_login` already exists, that row is kept and its
    /// existing id and metrics id are returned.
    pub async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult> {
        test_support!(self, {
            let query = "
                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id, metrics_id
            ";

            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
                .bind(email_address)
                .bind(params.github_login)
                .bind(params.github_user_id)
                .bind(admin)
                .bind(Uuid::new_v4().to_string())
                .fetch_one(&self.pool)
                .await?;
            Ok(NewUserResult {
                user_id,
                metrics_id,
                signup_device_id: None,
                inviting_user_id: None,
            })
        })
    }

    /// Not supported by the SQLite test backend; panics if called.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics if called.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics if called.
    pub async fn create_signup(&self, _signup: &Signup) -> Result<()> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics if called.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics if called.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
}
177
178impl Db<sqlx::Postgres> {
179 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
180 let pool = sqlx::postgres::PgPoolOptions::new()
181 .max_connections(max_connections)
182 .connect(url)
183 .await?;
184 Ok(Self {
185 pool,
186 #[cfg(test)]
187 background: None,
188 #[cfg(test)]
189 runtime: None,
190 })
191 }
192
193 #[cfg(test)]
194 pub fn teardown(&self, url: &str) {
195 self.runtime.as_ref().unwrap().block_on(async {
196 use util::ResultExt;
197 let query = "
198 SELECT pg_terminate_backend(pg_stat_activity.pid)
199 FROM pg_stat_activity
200 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
201 ";
202 sqlx::query(query).execute(&self.pool).await.log_err();
203 self.pool.close().await;
204 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
205 .await
206 .log_err();
207 })
208 }
209
210 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
211 test_support!(self, {
212 let like_string = Self::fuzzy_like_string(name_query);
213 let query = "
214 SELECT users.*
215 FROM users
216 WHERE github_login ILIKE $1
217 ORDER BY github_login <-> $2
218 LIMIT $3
219 ";
220 Ok(sqlx::query_as(query)
221 .bind(like_string)
222 .bind(name_query)
223 .bind(limit as i32)
224 .fetch_all(&self.pool)
225 .await?)
226 })
227 }
228
229 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
230 test_support!(self, {
231 let query = "
232 SELECT users.*
233 FROM users
234 WHERE users.id = ANY ($1)
235 ";
236 Ok(sqlx::query_as(query)
237 .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
238 .fetch_all(&self.pool)
239 .await?)
240 })
241 }
242
243 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
244 test_support!(self, {
245 let query = "
246 SELECT metrics_id::text
247 FROM users
248 WHERE id = $1
249 ";
250 Ok(sqlx::query_scalar(query)
251 .bind(id)
252 .fetch_one(&self.pool)
253 .await?)
254 })
255 }
256
257 pub async fn create_user(
258 &self,
259 email_address: &str,
260 admin: bool,
261 params: NewUserParams,
262 ) -> Result<NewUserResult> {
263 test_support!(self, {
264 let query = "
265 INSERT INTO users (email_address, github_login, github_user_id, admin)
266 VALUES ($1, $2, $3, $4)
267 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
268 RETURNING id, metrics_id::text
269 ";
270
271 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
272 .bind(email_address)
273 .bind(params.github_login)
274 .bind(params.github_user_id)
275 .bind(admin)
276 .fetch_one(&self.pool)
277 .await?;
278 Ok(NewUserResult {
279 user_id,
280 metrics_id,
281 signup_device_id: None,
282 inviting_user_id: None,
283 })
284 })
285 }
286
287 pub async fn create_user_from_invite(
288 &self,
289 invite: &Invite,
290 user: NewUserParams,
291 ) -> Result<Option<NewUserResult>> {
292 test_support!(self, {
293 let mut tx = self.pool.begin().await?;
294
295 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
296 i32,
297 Option<UserId>,
298 Option<UserId>,
299 Option<String>,
300 ) = sqlx::query_as(
301 "
302 SELECT id, user_id, inviting_user_id, device_id
303 FROM signups
304 WHERE
305 email_address = $1 AND
306 email_confirmation_code = $2
307 ",
308 )
309 .bind(&invite.email_address)
310 .bind(&invite.email_confirmation_code)
311 .fetch_optional(&mut tx)
312 .await?
313 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
314
315 if existing_user_id.is_some() {
316 return Ok(None);
317 }
318
319 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
320 "
321 INSERT INTO users
322 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
323 VALUES
324 ($1, $2, $3, FALSE, $4, $5)
325 ON CONFLICT (github_login) DO UPDATE SET
326 email_address = excluded.email_address,
327 github_user_id = excluded.github_user_id,
328 admin = excluded.admin
329 RETURNING id, metrics_id::text
330 ",
331 )
332 .bind(&invite.email_address)
333 .bind(&user.github_login)
334 .bind(&user.github_user_id)
335 .bind(&user.invite_count)
336 .bind(random_invite_code())
337 .fetch_one(&mut tx)
338 .await?;
339
340 sqlx::query(
341 "
342 UPDATE signups
343 SET user_id = $1
344 WHERE id = $2
345 ",
346 )
347 .bind(&user_id)
348 .bind(&signup_id)
349 .execute(&mut tx)
350 .await?;
351
352 if let Some(inviting_user_id) = inviting_user_id {
353 sqlx::query(
354 "
355 INSERT INTO contacts
356 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
357 VALUES
358 ($1, $2, TRUE, TRUE, TRUE)
359 ON CONFLICT DO NOTHING
360 ",
361 )
362 .bind(inviting_user_id)
363 .bind(user_id)
364 .execute(&mut tx)
365 .await?;
366 }
367
368 tx.commit().await?;
369 Ok(Some(NewUserResult {
370 user_id,
371 metrics_id,
372 inviting_user_id,
373 signup_device_id,
374 }))
375 })
376 }
377
378 pub async fn create_signup(&self, signup: &Signup) -> Result<()> {
379 test_support!(self, {
380 sqlx::query(
381 "
382 INSERT INTO signups
383 (
384 email_address,
385 email_confirmation_code,
386 email_confirmation_sent,
387 platform_linux,
388 platform_mac,
389 platform_windows,
390 platform_unknown,
391 editor_features,
392 programming_languages,
393 device_id
394 )
395 VALUES
396 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
397 ON CONFLICT (email_address) DO UPDATE SET
398 email_address = excluded.email_address
399 RETURNING id
400 ",
401 )
402 .bind(&signup.email_address)
403 .bind(&random_email_confirmation_code())
404 .bind(&signup.platform_linux)
405 .bind(&signup.platform_mac)
406 .bind(&signup.platform_windows)
407 .bind(&signup.editor_features)
408 .bind(&signup.programming_languages)
409 .bind(&signup.device_id)
410 .execute(&self.pool)
411 .await?;
412 Ok(())
413 })
414 }
415
416 pub async fn create_invite_from_code(
417 &self,
418 code: &str,
419 email_address: &str,
420 device_id: Option<&str>,
421 ) -> Result<Invite> {
422 test_support!(self, {
423 let mut tx = self.pool.begin().await?;
424
425 let existing_user: Option<UserId> = sqlx::query_scalar(
426 "
427 SELECT id
428 FROM users
429 WHERE email_address = $1
430 ",
431 )
432 .bind(email_address)
433 .fetch_optional(&mut tx)
434 .await?;
435 if existing_user.is_some() {
436 Err(anyhow!("email address is already in use"))?;
437 }
438
439 let inviting_user_id_with_invites: Option<UserId> = sqlx::query_scalar(
440 "
441 UPDATE users
442 SET invite_count = invite_count - 1
443 WHERE invite_code = $1 AND invite_count > 0
444 RETURNING id
445 ",
446 )
447 .bind(code)
448 .fetch_optional(&mut tx)
449 .await?;
450
451 let Some(inviter_id) = inviting_user_id_with_invites else {
452 return Err(Error::Http(
453 StatusCode::UNAUTHORIZED,
454 "unable to find an invite code with invites remaining".to_string(),
455 ));
456 };
457
458 let email_confirmation_code: String = sqlx::query_scalar(
459 "
460 INSERT INTO signups
461 (
462 email_address,
463 email_confirmation_code,
464 email_confirmation_sent,
465 inviting_user_id,
466 platform_linux,
467 platform_mac,
468 platform_windows,
469 platform_unknown,
470 device_id
471 )
472 VALUES
473 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
474 ON CONFLICT (email_address)
475 DO UPDATE SET
476 inviting_user_id = excluded.inviting_user_id
477 RETURNING email_confirmation_code
478 ",
479 )
480 .bind(&email_address)
481 .bind(&random_email_confirmation_code())
482 .bind(&inviter_id)
483 .bind(&device_id)
484 .fetch_one(&mut tx)
485 .await?;
486
487 tx.commit().await?;
488
489 Ok(Invite {
490 email_address: email_address.into(),
491 email_confirmation_code,
492 })
493 })
494 }
495
496 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
497 test_support!(self, {
498 let emails = invites
499 .iter()
500 .map(|s| s.email_address.as_str())
501 .collect::<Vec<_>>();
502 sqlx::query(
503 "
504 UPDATE signups
505 SET email_confirmation_sent = TRUE
506 WHERE email_address = ANY ($1)
507 ",
508 )
509 .bind(&emails)
510 .execute(&self.pool)
511 .await?;
512 Ok(())
513 })
514 }
515}
516
517impl<D> Db<D>
518where
519 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
520 D::Connection: sqlx::migrate::Migrate,
521 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
522 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
523 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
524 D::QueryResult: RowsAffected,
525 String: sqlx::Type<D>,
526 i32: sqlx::Type<D>,
527 i64: sqlx::Type<D>,
528 bool: sqlx::Type<D>,
529 str: sqlx::Type<D>,
530 Uuid: sqlx::Type<D>,
531 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
532 OffsetDateTime: sqlx::Type<D>,
533 PrimitiveDateTime: sqlx::Type<D>,
534 usize: sqlx::ColumnIndex<D::Row>,
535 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
536 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
537 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
538 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
539 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
540 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
541 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
542 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
543 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
544 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
545 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
546 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
547{
548 pub async fn migrate(
549 &self,
550 migrations_path: &Path,
551 ignore_checksum_mismatch: bool,
552 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
553 let migrations = MigrationSource::resolve(migrations_path)
554 .await
555 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
556
557 let mut conn = self.pool.acquire().await?;
558
559 conn.ensure_migrations_table().await?;
560 let applied_migrations: HashMap<_, _> = conn
561 .list_applied_migrations()
562 .await?
563 .into_iter()
564 .map(|m| (m.version, m))
565 .collect();
566
567 let mut new_migrations = Vec::new();
568 for migration in migrations {
569 match applied_migrations.get(&migration.version) {
570 Some(applied_migration) => {
571 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
572 {
573 Err(anyhow!(
574 "checksum mismatch for applied migration {}",
575 migration.description
576 ))?;
577 }
578 }
579 None => {
580 let elapsed = conn.apply(&migration).await?;
581 new_migrations.push((migration, elapsed));
582 }
583 }
584 }
585
586 Ok(new_migrations)
587 }
588
589 pub fn fuzzy_like_string(string: &str) -> String {
590 let mut result = String::with_capacity(string.len() * 2 + 1);
591 for c in string.chars() {
592 if c.is_alphanumeric() {
593 result.push('%');
594 result.push(c);
595 }
596 }
597 result.push('%');
598 result
599 }
600
601 // users
602
603 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
604 test_support!(self, {
605 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
606 Ok(sqlx::query_as(query)
607 .bind(limit as i32)
608 .bind((page * limit) as i32)
609 .fetch_all(&self.pool)
610 .await?)
611 })
612 }
613
614 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
615 test_support!(self, {
616 let query = "
617 SELECT users.*
618 FROM users
619 WHERE id = $1
620 LIMIT 1
621 ";
622 Ok(sqlx::query_as(query)
623 .bind(&id)
624 .fetch_optional(&self.pool)
625 .await?)
626 })
627 }
628
629 pub async fn get_users_with_no_invites(
630 &self,
631 invited_by_another_user: bool,
632 ) -> Result<Vec<User>> {
633 test_support!(self, {
634 let query = format!(
635 "
636 SELECT users.*
637 FROM users
638 WHERE invite_count = 0
639 AND inviter_id IS{} NULL
640 ",
641 if invited_by_another_user { " NOT" } else { "" }
642 );
643
644 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
645 })
646 }
647
648 pub async fn get_user_by_github_account(
649 &self,
650 github_login: &str,
651 github_user_id: Option<i32>,
652 ) -> Result<Option<User>> {
653 test_support!(self, {
654 if let Some(github_user_id) = github_user_id {
655 let mut user = sqlx::query_as::<_, User>(
656 "
657 UPDATE users
658 SET github_login = $1
659 WHERE github_user_id = $2
660 RETURNING *
661 ",
662 )
663 .bind(github_login)
664 .bind(github_user_id)
665 .fetch_optional(&self.pool)
666 .await?;
667
668 if user.is_none() {
669 user = sqlx::query_as::<_, User>(
670 "
671 UPDATE users
672 SET github_user_id = $1
673 WHERE github_login = $2
674 RETURNING *
675 ",
676 )
677 .bind(github_user_id)
678 .bind(github_login)
679 .fetch_optional(&self.pool)
680 .await?;
681 }
682
683 Ok(user)
684 } else {
685 let user = sqlx::query_as(
686 "
687 SELECT * FROM users
688 WHERE github_login = $1
689 LIMIT 1
690 ",
691 )
692 .bind(github_login)
693 .fetch_optional(&self.pool)
694 .await?;
695 Ok(user)
696 }
697 })
698 }
699
700 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
701 test_support!(self, {
702 let query = "UPDATE users SET admin = $1 WHERE id = $2";
703 Ok(sqlx::query(query)
704 .bind(is_admin)
705 .bind(id.0)
706 .execute(&self.pool)
707 .await
708 .map(drop)?)
709 })
710 }
711
712 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
713 test_support!(self, {
714 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
715 Ok(sqlx::query(query)
716 .bind(connected_once)
717 .bind(id.0)
718 .execute(&self.pool)
719 .await
720 .map(drop)?)
721 })
722 }
723
724 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
725 test_support!(self, {
726 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
727 sqlx::query(query)
728 .bind(id.0)
729 .execute(&self.pool)
730 .await
731 .map(drop)?;
732 let query = "DELETE FROM users WHERE id = $1;";
733 Ok(sqlx::query(query)
734 .bind(id.0)
735 .execute(&self.pool)
736 .await
737 .map(drop)?)
738 })
739 }
740
741 // signups
742
743 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
744 test_support!(self, {
745 Ok(sqlx::query_as(
746 "
747 SELECT
748 COUNT(*) as count,
749 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
750 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
751 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
752 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
753 FROM (
754 SELECT *
755 FROM signups
756 WHERE
757 NOT email_confirmation_sent
758 ) AS unsent
759 ",
760 )
761 .fetch_one(&self.pool)
762 .await?)
763 })
764 }
765
766 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
767 test_support!(self, {
768 Ok(sqlx::query_as(
769 "
770 SELECT
771 email_address, email_confirmation_code
772 FROM signups
773 WHERE
774 NOT email_confirmation_sent AND
775 (platform_mac OR platform_unknown)
776 LIMIT $1
777 ",
778 )
779 .bind(count as i32)
780 .fetch_all(&self.pool)
781 .await?)
782 })
783 }
784
785 // invite codes
786
787 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
788 test_support!(self, {
789 let mut tx = self.pool.begin().await?;
790 if count > 0 {
791 sqlx::query(
792 "
793 UPDATE users
794 SET invite_code = $1
795 WHERE id = $2 AND invite_code IS NULL
796 ",
797 )
798 .bind(random_invite_code())
799 .bind(id)
800 .execute(&mut tx)
801 .await?;
802 }
803
804 sqlx::query(
805 "
806 UPDATE users
807 SET invite_count = $1
808 WHERE id = $2
809 ",
810 )
811 .bind(count as i32)
812 .bind(id)
813 .execute(&mut tx)
814 .await?;
815 tx.commit().await?;
816 Ok(())
817 })
818 }
819
820 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
821 test_support!(self, {
822 let result: Option<(String, i32)> = sqlx::query_as(
823 "
824 SELECT invite_code, invite_count
825 FROM users
826 WHERE id = $1 AND invite_code IS NOT NULL
827 ",
828 )
829 .bind(id)
830 .fetch_optional(&self.pool)
831 .await?;
832 if let Some((code, count)) = result {
833 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
834 } else {
835 Ok(None)
836 }
837 })
838 }
839
840 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
841 test_support!(self, {
842 sqlx::query_as(
843 "
844 SELECT *
845 FROM users
846 WHERE invite_code = $1
847 ",
848 )
849 .bind(code)
850 .fetch_optional(&self.pool)
851 .await?
852 .ok_or_else(|| {
853 Error::Http(
854 StatusCode::NOT_FOUND,
855 "that invite code does not exist".to_string(),
856 )
857 })
858 })
859 }
860
861 // projects
862
863 /// Registers a new project for the given user.
864 pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
865 test_support!(self, {
866 Ok(sqlx::query_scalar(
867 "
868 INSERT INTO projects(host_user_id)
869 VALUES ($1)
870 RETURNING id
871 ",
872 )
873 .bind(host_user_id)
874 .fetch_one(&self.pool)
875 .await
876 .map(ProjectId)?)
877 })
878 }
879
880 /// Unregisters a project for the given project id.
881 pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
882 test_support!(self, {
883 sqlx::query(
884 "
885 UPDATE projects
886 SET unregistered = TRUE
887 WHERE id = $1
888 ",
889 )
890 .bind(project_id)
891 .execute(&self.pool)
892 .await?;
893 Ok(())
894 })
895 }
896
897 // contacts
898
899 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
900 test_support!(self, {
901 let query = "
902 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
903 FROM contacts
904 WHERE user_id_a = $1 OR user_id_b = $1;
905 ";
906
907 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
908 .bind(user_id)
909 .fetch(&self.pool);
910
911 let mut contacts = Vec::new();
912 while let Some(row) = rows.next().await {
913 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
914
915 if user_id_a == user_id {
916 if accepted {
917 contacts.push(Contact::Accepted {
918 user_id: user_id_b,
919 should_notify: should_notify && a_to_b,
920 });
921 } else if a_to_b {
922 contacts.push(Contact::Outgoing { user_id: user_id_b })
923 } else {
924 contacts.push(Contact::Incoming {
925 user_id: user_id_b,
926 should_notify,
927 });
928 }
929 } else if accepted {
930 contacts.push(Contact::Accepted {
931 user_id: user_id_a,
932 should_notify: should_notify && !a_to_b,
933 });
934 } else if a_to_b {
935 contacts.push(Contact::Incoming {
936 user_id: user_id_a,
937 should_notify,
938 });
939 } else {
940 contacts.push(Contact::Outgoing { user_id: user_id_a });
941 }
942 }
943
944 contacts.sort_unstable_by_key(|contact| contact.user_id());
945
946 Ok(contacts)
947 })
948 }
949
950 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
951 test_support!(self, {
952 let (id_a, id_b) = if user_id_1 < user_id_2 {
953 (user_id_1, user_id_2)
954 } else {
955 (user_id_2, user_id_1)
956 };
957
958 let query = "
959 SELECT 1 FROM contacts
960 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
961 LIMIT 1
962 ";
963 Ok(sqlx::query_scalar::<_, i32>(query)
964 .bind(id_a.0)
965 .bind(id_b.0)
966 .fetch_optional(&self.pool)
967 .await?
968 .is_some())
969 })
970 }
971
972 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
973 test_support!(self, {
974 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
975 (sender_id, receiver_id, true)
976 } else {
977 (receiver_id, sender_id, false)
978 };
979 let query = "
980 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
981 VALUES ($1, $2, $3, FALSE, TRUE)
982 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
983 SET
984 accepted = TRUE,
985 should_notify = FALSE
986 WHERE
987 NOT contacts.accepted AND
988 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
989 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
990 ";
991 let result = sqlx::query(query)
992 .bind(id_a.0)
993 .bind(id_b.0)
994 .bind(a_to_b)
995 .execute(&self.pool)
996 .await?;
997
998 if result.rows_affected() == 1 {
999 Ok(())
1000 } else {
1001 Err(anyhow!("contact already requested"))?
1002 }
1003 })
1004 }
1005
1006 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1007 test_support!(self, {
1008 let (id_a, id_b) = if responder_id < requester_id {
1009 (responder_id, requester_id)
1010 } else {
1011 (requester_id, responder_id)
1012 };
1013 let query = "
1014 DELETE FROM contacts
1015 WHERE user_id_a = $1 AND user_id_b = $2;
1016 ";
1017 let result = sqlx::query(query)
1018 .bind(id_a.0)
1019 .bind(id_b.0)
1020 .execute(&self.pool)
1021 .await?;
1022
1023 if result.rows_affected() == 1 {
1024 Ok(())
1025 } else {
1026 Err(anyhow!("no such contact"))?
1027 }
1028 })
1029 }
1030
1031 pub async fn dismiss_contact_notification(
1032 &self,
1033 user_id: UserId,
1034 contact_user_id: UserId,
1035 ) -> Result<()> {
1036 test_support!(self, {
1037 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1038 (user_id, contact_user_id, true)
1039 } else {
1040 (contact_user_id, user_id, false)
1041 };
1042
1043 let query = "
1044 UPDATE contacts
1045 SET should_notify = FALSE
1046 WHERE
1047 user_id_a = $1 AND user_id_b = $2 AND
1048 (
1049 (a_to_b = $3 AND accepted) OR
1050 (a_to_b != $3 AND NOT accepted)
1051 );
1052 ";
1053
1054 let result = sqlx::query(query)
1055 .bind(id_a.0)
1056 .bind(id_b.0)
1057 .bind(a_to_b)
1058 .execute(&self.pool)
1059 .await?;
1060
1061 if result.rows_affected() == 0 {
1062 Err(anyhow!("no such contact request"))?;
1063 }
1064
1065 Ok(())
1066 })
1067 }
1068
1069 pub async fn respond_to_contact_request(
1070 &self,
1071 responder_id: UserId,
1072 requester_id: UserId,
1073 accept: bool,
1074 ) -> Result<()> {
1075 test_support!(self, {
1076 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1077 (responder_id, requester_id, false)
1078 } else {
1079 (requester_id, responder_id, true)
1080 };
1081 let result = if accept {
1082 let query = "
1083 UPDATE contacts
1084 SET accepted = TRUE, should_notify = TRUE
1085 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1086 ";
1087 sqlx::query(query)
1088 .bind(id_a.0)
1089 .bind(id_b.0)
1090 .bind(a_to_b)
1091 .execute(&self.pool)
1092 .await?
1093 } else {
1094 let query = "
1095 DELETE FROM contacts
1096 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1097 ";
1098 sqlx::query(query)
1099 .bind(id_a.0)
1100 .bind(id_b.0)
1101 .bind(a_to_b)
1102 .execute(&self.pool)
1103 .await?
1104 };
1105 if result.rows_affected() == 1 {
1106 Ok(())
1107 } else {
1108 Err(anyhow!("no such contact request"))?
1109 }
1110 })
1111 }
1112
1113 // access tokens
1114
1115 pub async fn create_access_token_hash(
1116 &self,
1117 user_id: UserId,
1118 access_token_hash: &str,
1119 max_access_token_count: usize,
1120 ) -> Result<()> {
1121 test_support!(self, {
1122 let insert_query = "
1123 INSERT INTO access_tokens (user_id, hash)
1124 VALUES ($1, $2);
1125 ";
1126 let cleanup_query = "
1127 DELETE FROM access_tokens
1128 WHERE id IN (
1129 SELECT id from access_tokens
1130 WHERE user_id = $1
1131 ORDER BY id DESC
1132 LIMIT 10000
1133 OFFSET $3
1134 )
1135 ";
1136
1137 let mut tx = self.pool.begin().await?;
1138 sqlx::query(insert_query)
1139 .bind(user_id.0)
1140 .bind(access_token_hash)
1141 .execute(&mut tx)
1142 .await?;
1143 sqlx::query(cleanup_query)
1144 .bind(user_id.0)
1145 .bind(access_token_hash)
1146 .bind(max_access_token_count as i32)
1147 .execute(&mut tx)
1148 .await?;
1149 Ok(tx.commit().await?)
1150 })
1151 }
1152
1153 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1154 test_support!(self, {
1155 let query = "
1156 SELECT hash
1157 FROM access_tokens
1158 WHERE user_id = $1
1159 ORDER BY id DESC
1160 ";
1161 Ok(sqlx::query_scalar(query)
1162 .bind(user_id.0)
1163 .fetch_all(&self.pool)
1164 .await?)
1165 })
1166 }
1167}
1168
/// Declares an `i32`-backed id newtype named `$name`.
///
/// The type is stored transparently as `i32` in SQL and serialized as the bare
/// number.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash,
            sqlx::Type, Serialize, Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from a `u64` (presumably a protobuf field).
            /// Values above `i32::MAX` are truncated by the `as` cast.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                $name(value as i32)
            }

            /// Converts the id to `u64` via an `as` cast.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                let Self(raw) = self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Forward so width/alignment flags apply to the inner integer.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1211
id_type!(UserId);
/// A row of the `users` table, as read by the `SELECT users.*` queries in this module.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    // `None` until the user is given invites (see `set_invite_count_for_user`).
    pub invite_code: Option<String>,
    pub invite_count: i32,
    pub connected_once: bool,
}
1224
id_type!(ProjectId);
/// A row of the `projects` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    // Set to true by `unregister_project`; rows are not deleted.
    pub unregistered: bool,
}
1232
/// A contact relationship, seen from one user's point of view (see `get_contacts`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The two users are contacts; `should_notify` is only surfaced to the user
    /// who sent the original request.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// This user sent a request to `user_id` that is still pending.
    Outgoing {
        user_id: UserId,
    },
    /// `user_id` sent this user a request that is still pending.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1247
1248impl Contact {
1249 pub fn user_id(&self) -> UserId {
1250 match self {
1251 Contact::Accepted { user_id, .. } => *user_id,
1252 Contact::Outgoing { user_id } => *user_id,
1253 Contact::Incoming { user_id, .. } => *user_id,
1254 }
1255 }
1256}
1257
/// A pending contact request received from `requester_id`.
// Field order matters: the derived ordering compares `requester_id` first.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1263
/// A waitlist signup submission, as stored by `create_signup`.
#[derive(Clone, Deserialize, Default)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
1274
/// Counts of signups whose confirmation email has not been sent (see
/// `get_waitlist_summary`).
///
/// Columns missing from the row fall back to `Default` (0) via `#[sqlx(default)]`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1288
/// A signup's email address together with its email confirmation code.
#[derive(Clone, FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1294
/// GitHub identity and initial invite count for a user being created.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    // Only stored by `create_user_from_invite`; `create_user` does not insert it.
    pub invite_count: i32,
}
1301
/// Result of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    // Filled in from the signup by `create_user_from_invite`; `None` from `create_user`.
    pub inviting_user_id: Option<UserId>,
    // Filled in from the signup by `create_user_from_invite`; `None` from `create_user`.
    pub signup_device_id: Option<String>,
}
1309
/// Generates a random 16-character invite code with `nanoid`.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}
1313
/// Generates a random 64-character email confirmation code with `nanoid`.
fn random_email_confirmation_code() -> String {
    nanoid::nanoid!(64)
}
1317
// Re-export the test database helpers so tests elsewhere can reach them via this module.
#[cfg(test)]
pub use test::*;
1320
/// Test-only helpers for creating throwaway, migrated databases.
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A migrated, randomly-named in-memory SQLite database for a single test.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        // Detached connection owned by the fixture; presumably kept open so the
        // shared-cache in-memory database is not discarded — TODO confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A migrated, randomly-named Postgres database for a single test.
    /// The database is torn down when this value is dropped.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        pub url: String,
    }

    impl SqliteTestDb {
        /// Creates and migrates the database on a fresh current-thread tokio runtime.
        /// That runtime (and `background`) is then installed on the `Db` for
        /// `test_support!` to use.
        pub fn new(background: Arc<Background>) -> Self {
            let mut rng = StdRng::from_entropy();
            let url = format!("file:zed-test-{}?mode=memory", rng.gen::<u128>());
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(migrations_path.as_ref(), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the database handle; panics if it has been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates and migrates a new Postgres database on localhost, then installs
        /// the runtime and `background` on the `Db` for `test_support!` to use.
        /// Panics if the database cannot be created.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Held for the rest of this function, so database creation and
            // migration are serialized across concurrently running tests.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the database handle; panics if it has been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        // Takes the handle and runs `teardown` against this fixture's database URL.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            db.teardown(&self.url);
        }
    }
}