1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::HashMap;
5use futures::StreamExt;
6use serde::{Deserialize, Serialize};
7use sqlx::{
8 migrate::{Migrate as _, Migration, MigrationSource},
9 types::Uuid,
10 FromRow,
11};
12use std::{path::Path, time::Duration};
13use time::{OffsetDateTime, PrimitiveDateTime};
14
15#[cfg(test)]
16pub type DefaultDb = Db<sqlx::Sqlite>;
17
18#[cfg(not(test))]
19pub type DefaultDb = Db<sqlx::Postgres>;
20
21pub struct Db<D: sqlx::Database> {
22 pool: sqlx::Pool<D>,
23 #[cfg(test)]
24 background: Option<std::sync::Arc<gpui::executor::Background>>,
25 #[cfg(test)]
26 runtime: Option<tokio::runtime::Runtime>,
27}
28
/// Wraps the body of a database method so that it also works under the test harness.
///
/// Outside of test builds the body is simply awaited. In test builds the macro first awaits
/// `simulate_random_delay()` on `$self.background` (when one is set) and then drives the body
/// to completion with `block_on` on `$self.runtime`; it panics if no runtime is set.
macro_rules! test_support {
    ($self:ident, { $($statements:tt)* }) => {{
        let query_future = async {
            $($statements)*
        };

        if !cfg!(test) {
            query_future.await
        } else {
            // `cfg!(test)` is only a runtime-looking constant, so this branch is still
            // type-checked in non-test builds, where the test-only fields do not exist.
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(executor) = &$self.background {
                executor.simulate_random_delay().await;
            }

            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(query_future)
        }
    }};
}
51
/// Backend-independent access to the number of rows changed by an executed statement.
///
/// The generic `Db<D>` methods require `D::QueryResult: RowsAffected` so they can inspect the
/// outcome of `execute` without knowing the concrete driver.
pub trait RowsAffected {
    /// Returns how many rows the executed statement affected.
    fn rows_affected(&self) -> u64;
}
55
56impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
57 fn rows_affected(&self) -> u64 {
58 self.rows_affected()
59 }
60}
61
62impl RowsAffected for sqlx::postgres::PgQueryResult {
63 fn rows_affected(&self) -> u64 {
64 self.rows_affected()
65 }
66}
67
68#[cfg(test)]
69impl Db<sqlx::Sqlite> {
70 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
71 use std::str::FromStr as _;
72 let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)
73 .unwrap()
74 .create_if_missing(true)
75 .shared_cache(true);
76 let pool = sqlx::sqlite::SqlitePoolOptions::new()
77 .min_connections(2)
78 .max_connections(max_connections)
79 .connect_with(options)
80 .await?;
81 Ok(Self {
82 pool,
83 background: None,
84 runtime: None,
85 })
86 }
87
88 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
89 test_support!(self, {
90 let query = "
91 SELECT users.*
92 FROM users
93 WHERE users.id IN (SELECT value from json_each($1))
94 ";
95 Ok(sqlx::query_as(query)
96 .bind(&serde_json::json!(ids))
97 .fetch_all(&self.pool)
98 .await?)
99 })
100 }
101
102 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
103 test_support!(self, {
104 let query = "
105 SELECT metrics_id
106 FROM users
107 WHERE id = $1
108 ";
109 Ok(sqlx::query_scalar(query)
110 .bind(id)
111 .fetch_one(&self.pool)
112 .await?)
113 })
114 }
115
116 pub async fn create_user(
117 &self,
118 email_address: &str,
119 admin: bool,
120 params: NewUserParams,
121 ) -> Result<NewUserResult> {
122 test_support!(self, {
123 let query = "
124 INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
125 VALUES ($1, $2, $3, $4, $5)
126 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
127 RETURNING id, metrics_id
128 ";
129
130 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
131 .bind(email_address)
132 .bind(params.github_login)
133 .bind(params.github_user_id)
134 .bind(admin)
135 .bind(Uuid::new_v4().to_string())
136 .fetch_one(&self.pool)
137 .await?;
138 Ok(NewUserResult {
139 user_id,
140 metrics_id,
141 signup_device_id: None,
142 inviting_user_id: None,
143 })
144 })
145 }
146
147 pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
148 unimplemented!()
149 }
150
151 pub async fn create_user_from_invite(
152 &self,
153 _invite: &Invite,
154 _user: NewUserParams,
155 ) -> Result<Option<NewUserResult>> {
156 unimplemented!()
157 }
158
159 pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
160 unimplemented!()
161 }
162
163 pub async fn create_invite_from_code(
164 &self,
165 _code: &str,
166 _email_address: &str,
167 _device_id: Option<&str>,
168 ) -> Result<Invite> {
169 unimplemented!()
170 }
171
172 pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
173 unimplemented!()
174 }
175}
176
177impl Db<sqlx::Postgres> {
178 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
179 let pool = sqlx::postgres::PgPoolOptions::new()
180 .max_connections(max_connections)
181 .connect(url)
182 .await?;
183 Ok(Self {
184 pool,
185 #[cfg(test)]
186 background: None,
187 #[cfg(test)]
188 runtime: None,
189 })
190 }
191
/// Test-only cleanup: terminates the other sessions connected to this database, closes the
/// pool and then drops the database at `url`.
///
/// Errors from the terminate query and from dropping the database are passed to `log_err`
/// instead of being returned, so teardown is best-effort.
///
/// # Panics
/// Panics if no runtime has been set on this `Db`.
#[cfg(test)]
pub fn teardown(&self, url: &str) {
    self.runtime.as_ref().unwrap().block_on(async {
        use util::ResultExt;

        // Other sessions must be disconnected before the database can be dropped.
        let terminate_other_sessions = "
            SELECT pg_terminate_backend(pg_stat_activity.pid)
            FROM pg_stat_activity
            WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
        ";
        sqlx::query(terminate_other_sessions)
            .execute(&self.pool)
            .await
            .log_err();
        self.pool.close().await;

        // This is the Postgres backend, so the database has to be dropped through the
        // Postgres driver; the SQLite implementation cannot remove a Postgres database.
        <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
            .await
            .log_err();
    })
}
208
209 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
210 test_support!(self, {
211 let like_string = Self::fuzzy_like_string(name_query);
212 let query = "
213 SELECT users.*
214 FROM users
215 WHERE github_login ILIKE $1
216 ORDER BY github_login <-> $2
217 LIMIT $3
218 ";
219 Ok(sqlx::query_as(query)
220 .bind(like_string)
221 .bind(name_query)
222 .bind(limit as i32)
223 .fetch_all(&self.pool)
224 .await?)
225 })
226 }
227
228 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
229 test_support!(self, {
230 let query = "
231 SELECT users.*
232 FROM users
233 WHERE users.id = ANY ($1)
234 ";
235 Ok(sqlx::query_as(query)
236 .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
237 .fetch_all(&self.pool)
238 .await?)
239 })
240 }
241
242 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
243 test_support!(self, {
244 let query = "
245 SELECT metrics_id::text
246 FROM users
247 WHERE id = $1
248 ";
249 Ok(sqlx::query_scalar(query)
250 .bind(id)
251 .fetch_one(&self.pool)
252 .await?)
253 })
254 }
255
256 pub async fn create_user(
257 &self,
258 email_address: &str,
259 admin: bool,
260 params: NewUserParams,
261 ) -> Result<NewUserResult> {
262 test_support!(self, {
263 let query = "
264 INSERT INTO users (email_address, github_login, github_user_id, admin)
265 VALUES ($1, $2, $3, $4)
266 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
267 RETURNING id, metrics_id::text
268 ";
269
270 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
271 .bind(email_address)
272 .bind(params.github_login)
273 .bind(params.github_user_id)
274 .bind(admin)
275 .fetch_one(&self.pool)
276 .await?;
277 Ok(NewUserResult {
278 user_id,
279 metrics_id,
280 signup_device_id: None,
281 inviting_user_id: None,
282 })
283 })
284 }
285
286 pub async fn create_user_from_invite(
287 &self,
288 invite: &Invite,
289 user: NewUserParams,
290 ) -> Result<Option<NewUserResult>> {
291 test_support!(self, {
292 let mut tx = self.pool.begin().await?;
293
294 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
295 i32,
296 Option<UserId>,
297 Option<UserId>,
298 Option<String>,
299 ) = sqlx::query_as(
300 "
301 SELECT id, user_id, inviting_user_id, device_id
302 FROM signups
303 WHERE
304 email_address = $1 AND
305 email_confirmation_code = $2
306 ",
307 )
308 .bind(&invite.email_address)
309 .bind(&invite.email_confirmation_code)
310 .fetch_optional(&mut tx)
311 .await?
312 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
313
314 if existing_user_id.is_some() {
315 return Ok(None);
316 }
317
318 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
319 "
320 INSERT INTO users
321 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
322 VALUES
323 ($1, $2, $3, FALSE, $4, $5)
324 ON CONFLICT (github_login) DO UPDATE SET
325 email_address = excluded.email_address,
326 github_user_id = excluded.github_user_id,
327 admin = excluded.admin
328 RETURNING id, metrics_id::text
329 ",
330 )
331 .bind(&invite.email_address)
332 .bind(&user.github_login)
333 .bind(&user.github_user_id)
334 .bind(&user.invite_count)
335 .bind(random_invite_code())
336 .fetch_one(&mut tx)
337 .await?;
338
339 sqlx::query(
340 "
341 UPDATE signups
342 SET user_id = $1
343 WHERE id = $2
344 ",
345 )
346 .bind(&user_id)
347 .bind(&signup_id)
348 .execute(&mut tx)
349 .await?;
350
351 if let Some(inviting_user_id) = inviting_user_id {
352 let id: Option<UserId> = sqlx::query_scalar(
353 "
354 UPDATE users
355 SET invite_count = invite_count - 1
356 WHERE id = $1 AND invite_count > 0
357 RETURNING id
358 ",
359 )
360 .bind(&inviting_user_id)
361 .fetch_optional(&mut tx)
362 .await?;
363
364 if id.is_none() {
365 Err(Error::Http(
366 StatusCode::UNAUTHORIZED,
367 "no invites remaining".to_string(),
368 ))?;
369 }
370
371 sqlx::query(
372 "
373 INSERT INTO contacts
374 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
375 VALUES
376 ($1, $2, TRUE, TRUE, TRUE)
377 ON CONFLICT DO NOTHING
378 ",
379 )
380 .bind(inviting_user_id)
381 .bind(user_id)
382 .execute(&mut tx)
383 .await?;
384 }
385
386 tx.commit().await?;
387 Ok(Some(NewUserResult {
388 user_id,
389 metrics_id,
390 inviting_user_id,
391 signup_device_id,
392 }))
393 })
394 }
395
396 pub async fn create_signup(&self, signup: Signup) -> Result<()> {
397 test_support!(self, {
398 sqlx::query(
399 "
400 INSERT INTO signups
401 (
402 email_address,
403 email_confirmation_code,
404 email_confirmation_sent,
405 platform_linux,
406 platform_mac,
407 platform_windows,
408 platform_unknown,
409 editor_features,
410 programming_languages,
411 device_id
412 )
413 VALUES
414 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
415 RETURNING id
416 ",
417 )
418 .bind(&signup.email_address)
419 .bind(&random_email_confirmation_code())
420 .bind(&signup.platform_linux)
421 .bind(&signup.platform_mac)
422 .bind(&signup.platform_windows)
423 .bind(&signup.editor_features)
424 .bind(&signup.programming_languages)
425 .bind(&signup.device_id)
426 .execute(&self.pool)
427 .await?;
428 Ok(())
429 })
430 }
431
432 pub async fn create_invite_from_code(
433 &self,
434 code: &str,
435 email_address: &str,
436 device_id: Option<&str>,
437 ) -> Result<Invite> {
438 test_support!(self, {
439 let mut tx = self.pool.begin().await?;
440
441 let existing_user: Option<UserId> = sqlx::query_scalar(
442 "
443 SELECT id
444 FROM users
445 WHERE email_address = $1
446 ",
447 )
448 .bind(email_address)
449 .fetch_optional(&mut tx)
450 .await?;
451 if existing_user.is_some() {
452 Err(anyhow!("email address is already in use"))?;
453 }
454
455 let row: Option<(UserId, i32)> = sqlx::query_as(
456 "
457 SELECT id, invite_count
458 FROM users
459 WHERE invite_code = $1
460 ",
461 )
462 .bind(code)
463 .fetch_optional(&mut tx)
464 .await?;
465
466 let (inviter_id, invite_count) = match row {
467 Some(row) => row,
468 None => Err(Error::Http(
469 StatusCode::NOT_FOUND,
470 "invite code not found".to_string(),
471 ))?,
472 };
473
474 if invite_count == 0 {
475 Err(Error::Http(
476 StatusCode::UNAUTHORIZED,
477 "no invites remaining".to_string(),
478 ))?;
479 }
480
481 let email_confirmation_code: String = sqlx::query_scalar(
482 "
483 INSERT INTO signups
484 (
485 email_address,
486 email_confirmation_code,
487 email_confirmation_sent,
488 inviting_user_id,
489 platform_linux,
490 platform_mac,
491 platform_windows,
492 platform_unknown,
493 device_id
494 )
495 VALUES
496 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
497 ON CONFLICT (email_address)
498 DO UPDATE SET
499 inviting_user_id = excluded.inviting_user_id
500 RETURNING email_confirmation_code
501 ",
502 )
503 .bind(&email_address)
504 .bind(&random_email_confirmation_code())
505 .bind(&inviter_id)
506 .bind(&device_id)
507 .fetch_one(&mut tx)
508 .await?;
509
510 tx.commit().await?;
511
512 Ok(Invite {
513 email_address: email_address.into(),
514 email_confirmation_code,
515 })
516 })
517 }
518
519 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
520 test_support!(self, {
521 let emails = invites
522 .iter()
523 .map(|s| s.email_address.as_str())
524 .collect::<Vec<_>>();
525 sqlx::query(
526 "
527 UPDATE signups
528 SET email_confirmation_sent = TRUE
529 WHERE email_address = ANY ($1)
530 ",
531 )
532 .bind(&emails)
533 .execute(&self.pool)
534 .await?;
535 Ok(())
536 })
537 }
538}
539
540impl<D> Db<D>
541where
542 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
543 D::Connection: sqlx::migrate::Migrate,
544 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
545 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
546 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
547 D::QueryResult: RowsAffected,
548 String: sqlx::Type<D>,
549 i32: sqlx::Type<D>,
550 i64: sqlx::Type<D>,
551 bool: sqlx::Type<D>,
552 str: sqlx::Type<D>,
553 Uuid: sqlx::Type<D>,
554 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
555 OffsetDateTime: sqlx::Type<D>,
556 PrimitiveDateTime: sqlx::Type<D>,
557 usize: sqlx::ColumnIndex<D::Row>,
558 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
559 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
560 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
561 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
562 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
563 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
564 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
565 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
566 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
567 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
568 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
569 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
570{
571 pub async fn migrate(
572 &self,
573 migrations_path: &Path,
574 ignore_checksum_mismatch: bool,
575 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
576 let migrations = MigrationSource::resolve(migrations_path)
577 .await
578 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
579
580 let mut conn = self.pool.acquire().await?;
581
582 conn.ensure_migrations_table().await?;
583 let applied_migrations: HashMap<_, _> = conn
584 .list_applied_migrations()
585 .await?
586 .into_iter()
587 .map(|m| (m.version, m))
588 .collect();
589
590 let mut new_migrations = Vec::new();
591 for migration in migrations {
592 match applied_migrations.get(&migration.version) {
593 Some(applied_migration) => {
594 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
595 {
596 Err(anyhow!(
597 "checksum mismatch for applied migration {}",
598 migration.description
599 ))?;
600 }
601 }
602 None => {
603 let elapsed = conn.apply(&migration).await?;
604 new_migrations.push((migration, elapsed));
605 }
606 }
607 }
608
609 Ok(new_migrations)
610 }
611
/// Builds a `LIKE`-style pattern that matches the alphanumeric characters of `string` in
/// order, with anything in between.
///
/// Every alphanumeric character is prefixed with `%` and a final `%` is appended; all other
/// characters are dropped. For example `"a-b"` becomes `"%a%b%"` and `""` becomes `"%"`.
pub fn fuzzy_like_string(string: &str) -> String {
    let mut pattern = String::with_capacity(string.len() * 2 + 1);
    string
        .chars()
        .filter(|ch| ch.is_alphanumeric())
        .for_each(|ch| {
            pattern.push('%');
            pattern.push(ch);
        });
    pattern.push('%');
    pattern
}
623
624 // users
625
626 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
627 test_support!(self, {
628 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
629 Ok(sqlx::query_as(query)
630 .bind(limit as i32)
631 .bind((page * limit) as i32)
632 .fetch_all(&self.pool)
633 .await?)
634 })
635 }
636
637 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
638 test_support!(self, {
639 let query = "
640 SELECT users.*
641 FROM users
642 WHERE id = $1
643 LIMIT 1
644 ";
645 Ok(sqlx::query_as(query)
646 .bind(&id)
647 .fetch_optional(&self.pool)
648 .await?)
649 })
650 }
651
652 pub async fn get_users_with_no_invites(
653 &self,
654 invited_by_another_user: bool,
655 ) -> Result<Vec<User>> {
656 test_support!(self, {
657 let query = format!(
658 "
659 SELECT users.*
660 FROM users
661 WHERE invite_count = 0
662 AND inviter_id IS{} NULL
663 ",
664 if invited_by_another_user { " NOT" } else { "" }
665 );
666
667 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
668 })
669 }
670
671 pub async fn get_user_by_github_account(
672 &self,
673 github_login: &str,
674 github_user_id: Option<i32>,
675 ) -> Result<Option<User>> {
676 test_support!(self, {
677 if let Some(github_user_id) = github_user_id {
678 let mut user = sqlx::query_as::<_, User>(
679 "
680 UPDATE users
681 SET github_login = $1
682 WHERE github_user_id = $2
683 RETURNING *
684 ",
685 )
686 .bind(github_login)
687 .bind(github_user_id)
688 .fetch_optional(&self.pool)
689 .await?;
690
691 if user.is_none() {
692 user = sqlx::query_as::<_, User>(
693 "
694 UPDATE users
695 SET github_user_id = $1
696 WHERE github_login = $2
697 RETURNING *
698 ",
699 )
700 .bind(github_user_id)
701 .bind(github_login)
702 .fetch_optional(&self.pool)
703 .await?;
704 }
705
706 Ok(user)
707 } else {
708 let user = sqlx::query_as(
709 "
710 SELECT * FROM users
711 WHERE github_login = $1
712 LIMIT 1
713 ",
714 )
715 .bind(github_login)
716 .fetch_optional(&self.pool)
717 .await?;
718 Ok(user)
719 }
720 })
721 }
722
723 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
724 test_support!(self, {
725 let query = "UPDATE users SET admin = $1 WHERE id = $2";
726 Ok(sqlx::query(query)
727 .bind(is_admin)
728 .bind(id.0)
729 .execute(&self.pool)
730 .await
731 .map(drop)?)
732 })
733 }
734
735 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
736 test_support!(self, {
737 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
738 Ok(sqlx::query(query)
739 .bind(connected_once)
740 .bind(id.0)
741 .execute(&self.pool)
742 .await
743 .map(drop)?)
744 })
745 }
746
747 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
748 test_support!(self, {
749 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
750 sqlx::query(query)
751 .bind(id.0)
752 .execute(&self.pool)
753 .await
754 .map(drop)?;
755 let query = "DELETE FROM users WHERE id = $1;";
756 Ok(sqlx::query(query)
757 .bind(id.0)
758 .execute(&self.pool)
759 .await
760 .map(drop)?)
761 })
762 }
763
764 // signups
765
766 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
767 test_support!(self, {
768 Ok(sqlx::query_as(
769 "
770 SELECT
771 COUNT(*) as count,
772 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
773 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
774 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
775 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
776 FROM (
777 SELECT *
778 FROM signups
779 WHERE
780 NOT email_confirmation_sent
781 ) AS unsent
782 ",
783 )
784 .fetch_one(&self.pool)
785 .await?)
786 })
787 }
788
789 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
790 test_support!(self, {
791 Ok(sqlx::query_as(
792 "
793 SELECT
794 email_address, email_confirmation_code
795 FROM signups
796 WHERE
797 NOT email_confirmation_sent AND
798 (platform_mac OR platform_unknown)
799 LIMIT $1
800 ",
801 )
802 .bind(count as i32)
803 .fetch_all(&self.pool)
804 .await?)
805 })
806 }
807
808 // invite codes
809
810 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
811 test_support!(self, {
812 let mut tx = self.pool.begin().await?;
813 if count > 0 {
814 sqlx::query(
815 "
816 UPDATE users
817 SET invite_code = $1
818 WHERE id = $2 AND invite_code IS NULL
819 ",
820 )
821 .bind(random_invite_code())
822 .bind(id)
823 .execute(&mut tx)
824 .await?;
825 }
826
827 sqlx::query(
828 "
829 UPDATE users
830 SET invite_count = $1
831 WHERE id = $2
832 ",
833 )
834 .bind(count as i32)
835 .bind(id)
836 .execute(&mut tx)
837 .await?;
838 tx.commit().await?;
839 Ok(())
840 })
841 }
842
843 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
844 test_support!(self, {
845 let result: Option<(String, i32)> = sqlx::query_as(
846 "
847 SELECT invite_code, invite_count
848 FROM users
849 WHERE id = $1 AND invite_code IS NOT NULL
850 ",
851 )
852 .bind(id)
853 .fetch_optional(&self.pool)
854 .await?;
855 if let Some((code, count)) = result {
856 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
857 } else {
858 Ok(None)
859 }
860 })
861 }
862
863 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
864 test_support!(self, {
865 sqlx::query_as(
866 "
867 SELECT *
868 FROM users
869 WHERE invite_code = $1
870 ",
871 )
872 .bind(code)
873 .fetch_optional(&self.pool)
874 .await?
875 .ok_or_else(|| {
876 Error::Http(
877 StatusCode::NOT_FOUND,
878 "that invite code does not exist".to_string(),
879 )
880 })
881 })
882 }
883
884 // projects
885
886 /// Registers a new project for the given user.
887 pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
888 test_support!(self, {
889 Ok(sqlx::query_scalar(
890 "
891 INSERT INTO projects(host_user_id)
892 VALUES ($1)
893 RETURNING id
894 ",
895 )
896 .bind(host_user_id)
897 .fetch_one(&self.pool)
898 .await
899 .map(ProjectId)?)
900 })
901 }
902
903 /// Unregisters a project for the given project id.
904 pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
905 test_support!(self, {
906 sqlx::query(
907 "
908 UPDATE projects
909 SET unregistered = TRUE
910 WHERE id = $1
911 ",
912 )
913 .bind(project_id)
914 .execute(&self.pool)
915 .await?;
916 Ok(())
917 })
918 }
919
920 // contacts
921
922 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
923 test_support!(self, {
924 let query = "
925 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
926 FROM contacts
927 WHERE user_id_a = $1 OR user_id_b = $1;
928 ";
929
930 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
931 .bind(user_id)
932 .fetch(&self.pool);
933
934 let mut contacts = Vec::new();
935 while let Some(row) = rows.next().await {
936 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
937
938 if user_id_a == user_id {
939 if accepted {
940 contacts.push(Contact::Accepted {
941 user_id: user_id_b,
942 should_notify: should_notify && a_to_b,
943 });
944 } else if a_to_b {
945 contacts.push(Contact::Outgoing { user_id: user_id_b })
946 } else {
947 contacts.push(Contact::Incoming {
948 user_id: user_id_b,
949 should_notify,
950 });
951 }
952 } else if accepted {
953 contacts.push(Contact::Accepted {
954 user_id: user_id_a,
955 should_notify: should_notify && !a_to_b,
956 });
957 } else if a_to_b {
958 contacts.push(Contact::Incoming {
959 user_id: user_id_a,
960 should_notify,
961 });
962 } else {
963 contacts.push(Contact::Outgoing { user_id: user_id_a });
964 }
965 }
966
967 contacts.sort_unstable_by_key(|contact| contact.user_id());
968
969 Ok(contacts)
970 })
971 }
972
973 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
974 test_support!(self, {
975 let (id_a, id_b) = if user_id_1 < user_id_2 {
976 (user_id_1, user_id_2)
977 } else {
978 (user_id_2, user_id_1)
979 };
980
981 let query = "
982 SELECT 1 FROM contacts
983 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
984 LIMIT 1
985 ";
986 Ok(sqlx::query_scalar::<_, i32>(query)
987 .bind(id_a.0)
988 .bind(id_b.0)
989 .fetch_optional(&self.pool)
990 .await?
991 .is_some())
992 })
993 }
994
995 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
996 test_support!(self, {
997 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
998 (sender_id, receiver_id, true)
999 } else {
1000 (receiver_id, sender_id, false)
1001 };
1002 let query = "
1003 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1004 VALUES ($1, $2, $3, FALSE, TRUE)
1005 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1006 SET
1007 accepted = TRUE,
1008 should_notify = FALSE
1009 WHERE
1010 NOT contacts.accepted AND
1011 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1012 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1013 ";
1014 let result = sqlx::query(query)
1015 .bind(id_a.0)
1016 .bind(id_b.0)
1017 .bind(a_to_b)
1018 .execute(&self.pool)
1019 .await?;
1020
1021 if result.rows_affected() == 1 {
1022 Ok(())
1023 } else {
1024 Err(anyhow!("contact already requested"))?
1025 }
1026 })
1027 }
1028
1029 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1030 test_support!(self, {
1031 let (id_a, id_b) = if responder_id < requester_id {
1032 (responder_id, requester_id)
1033 } else {
1034 (requester_id, responder_id)
1035 };
1036 let query = "
1037 DELETE FROM contacts
1038 WHERE user_id_a = $1 AND user_id_b = $2;
1039 ";
1040 let result = sqlx::query(query)
1041 .bind(id_a.0)
1042 .bind(id_b.0)
1043 .execute(&self.pool)
1044 .await?;
1045
1046 if result.rows_affected() == 1 {
1047 Ok(())
1048 } else {
1049 Err(anyhow!("no such contact"))?
1050 }
1051 })
1052 }
1053
1054 pub async fn dismiss_contact_notification(
1055 &self,
1056 user_id: UserId,
1057 contact_user_id: UserId,
1058 ) -> Result<()> {
1059 test_support!(self, {
1060 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1061 (user_id, contact_user_id, true)
1062 } else {
1063 (contact_user_id, user_id, false)
1064 };
1065
1066 let query = "
1067 UPDATE contacts
1068 SET should_notify = FALSE
1069 WHERE
1070 user_id_a = $1 AND user_id_b = $2 AND
1071 (
1072 (a_to_b = $3 AND accepted) OR
1073 (a_to_b != $3 AND NOT accepted)
1074 );
1075 ";
1076
1077 let result = sqlx::query(query)
1078 .bind(id_a.0)
1079 .bind(id_b.0)
1080 .bind(a_to_b)
1081 .execute(&self.pool)
1082 .await?;
1083
1084 if result.rows_affected() == 0 {
1085 Err(anyhow!("no such contact request"))?;
1086 }
1087
1088 Ok(())
1089 })
1090 }
1091
1092 pub async fn respond_to_contact_request(
1093 &self,
1094 responder_id: UserId,
1095 requester_id: UserId,
1096 accept: bool,
1097 ) -> Result<()> {
1098 test_support!(self, {
1099 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1100 (responder_id, requester_id, false)
1101 } else {
1102 (requester_id, responder_id, true)
1103 };
1104 let result = if accept {
1105 let query = "
1106 UPDATE contacts
1107 SET accepted = TRUE, should_notify = TRUE
1108 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1109 ";
1110 sqlx::query(query)
1111 .bind(id_a.0)
1112 .bind(id_b.0)
1113 .bind(a_to_b)
1114 .execute(&self.pool)
1115 .await?
1116 } else {
1117 let query = "
1118 DELETE FROM contacts
1119 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1120 ";
1121 sqlx::query(query)
1122 .bind(id_a.0)
1123 .bind(id_b.0)
1124 .bind(a_to_b)
1125 .execute(&self.pool)
1126 .await?
1127 };
1128 if result.rows_affected() == 1 {
1129 Ok(())
1130 } else {
1131 Err(anyhow!("no such contact request"))?
1132 }
1133 })
1134 }
1135
1136 // access tokens
1137
1138 pub async fn create_access_token_hash(
1139 &self,
1140 user_id: UserId,
1141 access_token_hash: &str,
1142 max_access_token_count: usize,
1143 ) -> Result<()> {
1144 test_support!(self, {
1145 let insert_query = "
1146 INSERT INTO access_tokens (user_id, hash)
1147 VALUES ($1, $2);
1148 ";
1149 let cleanup_query = "
1150 DELETE FROM access_tokens
1151 WHERE id IN (
1152 SELECT id from access_tokens
1153 WHERE user_id = $1
1154 ORDER BY id DESC
1155 LIMIT 10000
1156 OFFSET $3
1157 )
1158 ";
1159
1160 let mut tx = self.pool.begin().await?;
1161 sqlx::query(insert_query)
1162 .bind(user_id.0)
1163 .bind(access_token_hash)
1164 .execute(&mut tx)
1165 .await?;
1166 sqlx::query(cleanup_query)
1167 .bind(user_id.0)
1168 .bind(access_token_hash)
1169 .bind(max_access_token_count as i32)
1170 .execute(&mut tx)
1171 .await?;
1172 Ok(tx.commit().await?)
1173 })
1174 }
1175
1176 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1177 test_support!(self, {
1178 let query = "
1179 SELECT hash
1180 FROM access_tokens
1181 WHERE user_id = $1
1182 ORDER BY id DESC
1183 ";
1184 Ok(sqlx::query_scalar(query)
1185 .bind(user_id.0)
1186 .fetch_all(&self.pool)
1187 .await?)
1188 })
1189 }
1190}
1191
/// Declares a public newtype around `i32` for use as a typed database id.
///
/// The generated type is transparent to both sqlx and serde (it is stored and serialized as
/// a bare integer), provides a `MAX` constant plus `from_proto` / `to_proto` conversions to
/// `u64`, and displays as the inner integer.
macro_rules! id_type {
    ($id:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $id(pub i32);

        impl $id {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Converts a protocol id into this type with a plain `as` cast, so values that
            /// do not fit in an `i32` are truncated.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts this id into its protocol representation with a plain `as` cast.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $id {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1234
1235id_type!(UserId);
1236#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1237pub struct User {
1238 pub id: UserId,
1239 pub github_login: String,
1240 pub github_user_id: Option<i32>,
1241 pub email_address: Option<String>,
1242 pub admin: bool,
1243 pub invite_code: Option<String>,
1244 pub invite_count: i32,
1245 pub connected_once: bool,
1246}
1247
1248id_type!(ProjectId);
1249#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1250pub struct Project {
1251 pub id: ProjectId,
1252 pub host_user_id: UserId,
1253 pub unregistered: bool,
1254}
1255
1256#[derive(Clone, Debug, PartialEq, Eq)]
1257pub enum Contact {
1258 Accepted {
1259 user_id: UserId,
1260 should_notify: bool,
1261 },
1262 Outgoing {
1263 user_id: UserId,
1264 },
1265 Incoming {
1266 user_id: UserId,
1267 should_notify: bool,
1268 },
1269}
1270
1271impl Contact {
1272 pub fn user_id(&self) -> UserId {
1273 match self {
1274 Contact::Accepted { user_id, .. } => *user_id,
1275 Contact::Outgoing { user_id } => *user_id,
1276 Contact::Incoming { user_id, .. } => *user_id,
1277 }
1278 }
1279}
1280
1281#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
1282pub struct IncomingContactRequest {
1283 pub requester_id: UserId,
1284 pub should_notify: bool,
1285}
1286
1287#[derive(Clone, Deserialize)]
1288pub struct Signup {
1289 pub email_address: String,
1290 pub platform_mac: bool,
1291 pub platform_windows: bool,
1292 pub platform_linux: bool,
1293 pub editor_features: Vec<String>,
1294 pub programming_languages: Vec<String>,
1295 pub device_id: Option<String>,
1296}
1297
1298#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
1299pub struct WaitlistSummary {
1300 #[sqlx(default)]
1301 pub count: i64,
1302 #[sqlx(default)]
1303 pub linux_count: i64,
1304 #[sqlx(default)]
1305 pub mac_count: i64,
1306 #[sqlx(default)]
1307 pub windows_count: i64,
1308 #[sqlx(default)]
1309 pub unknown_count: i64,
1310}
1311
1312#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
1313pub struct Invite {
1314 pub email_address: String,
1315 pub email_confirmation_code: String,
1316}
1317
1318#[derive(Debug, Serialize, Deserialize)]
1319pub struct NewUserParams {
1320 pub github_login: String,
1321 pub github_user_id: i32,
1322 pub invite_count: i32,
1323}
1324
1325#[derive(Debug)]
1326pub struct NewUserResult {
1327 pub user_id: UserId,
1328 pub metrics_id: String,
1329 pub inviting_user_id: Option<UserId>,
1330 pub signup_device_id: Option<String>,
1331}
1332
1333fn random_invite_code() -> String {
1334 nanoid::nanoid!(16)
1335}
1336
1337fn random_email_confirmation_code() -> String {
1338 nanoid::nanoid!(64)
1339}
1340
1341#[cfg(test)]
1342pub use test::*;
1343
/// Test-only database fixtures: a SQLite database and a throwaway Postgres database, each
/// wired up with the executor and runtime that `test_support!` relies on.
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A migrated SQLite database for tests.
    pub struct SqliteTestDb {
        /// The database handle; always `Some` after construction.
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// A connection detached from the pool and owned by the fixture.
        // NOTE(review): presumably held so the in-memory database outlives pooled
        // connections — confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A migrated Postgres database with a random name, dropped again when the fixture is.
    pub struct PostgresTestDb {
        /// The database handle; taken out when the fixture is dropped.
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        /// Connection URL of the temporary database.
        pub url: String,
    }

    /// Builds the single-threaded Tokio runtime (I/O and time enabled) that a test database
    /// blocks on.
    fn new_test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    impl SqliteTestDb {
        /// Creates a uniquely named in-memory SQLite database and applies the migrations
        /// from `migrations.sqlite`.
        ///
        /// Panics if connecting or migrating fails.
        pub fn new(background: Arc<Background>) -> Self {
            let mut rng = StdRng::from_entropy();
            let url = format!("file:zed-test-{}?mode=memory", rng.gen::<u128>());
            let runtime = new_test_runtime();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_dir = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(Path::new(migrations_dir), false).await.unwrap();
                let detached = db.pool.acquire().await.unwrap().detach();
                (db, detached)
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the database handle.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates a randomly named database on the local Postgres server and applies the
        /// migrations from `migrations`.
        ///
        /// Construction is serialized through a process-wide lock. Panics if the database
        /// cannot be created, connected to or migrated.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = new_test_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_dir = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_dir), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the database handle.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        /// Tears the temporary database down when the fixture goes out of scope.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            db.teardown(&self.url);
        }
    }
}