1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::HashMap;
5use futures::StreamExt;
6use serde::{Deserialize, Serialize};
7use sqlx::{
8 migrate::{Migrate as _, Migration, MigrationSource},
9 types::Uuid,
10 FromRow,
11};
12use std::{path::Path, time::Duration};
13use time::{OffsetDateTime, PrimitiveDateTime};
14
15#[cfg(test)]
16pub type DefaultDb = Db<sqlx::Sqlite>;
17
18#[cfg(not(test))]
19pub type DefaultDb = Db<sqlx::Postgres>;
20
21pub struct Db<D: sqlx::Database> {
22 pool: sqlx::Pool<D>,
23 #[cfg(test)]
24 background: Option<std::sync::Arc<gpui::executor::Background>>,
25 #[cfg(test)]
26 runtime: Option<tokio::runtime::Runtime>,
27}
28
/// Wraps a query body so it can run both in production and under tests.
///
/// Outside of tests the body is simply awaited. Under `cfg(test)` the macro
/// first awaits a random delay on `$self.background` (if one is set), and then
/// synchronously drives the body on `$self.runtime`. The runtime is unwrapped,
/// so it must be populated.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let future = async {
            $($token)*
        };

        if !cfg!(test) {
            future.await
        } else {
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(future)
        }
    }};
}
51
/// Backend-agnostic access to the number of rows touched by a statement, so
/// generic code can inspect `D::QueryResult`.
pub trait RowsAffected {
    /// Number of rows inserted, updated, or deleted by the statement.
    fn rows_affected(&self) -> u64;
}
55
#[cfg(test)]
impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    fn rows_affected(&self) -> u64 {
        // Inherent methods take precedence over trait methods in path
        // resolution, so this forwards to sqlx's own accessor (no recursion).
        sqlx::sqlite::SqliteQueryResult::rows_affected(self)
    }
}
62
63impl RowsAffected for sqlx::postgres::PgQueryResult {
64 fn rows_affected(&self) -> u64 {
65 self.rows_affected()
66 }
67}
68
#[cfg(test)]
impl Db<sqlx::Sqlite> {
    /// Opens a SQLite pool for `url`, creating the database if it is missing
    /// and enabling the shared cache.
    ///
    /// # Errors
    /// Returns an error if `url` is not a valid SQLite connection string or if
    /// the pool cannot connect.
    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
        use std::str::FromStr as _;
        // A malformed URL is reported through the function's `Result` rather
        // than panicking.
        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
            .create_if_missing(true)
            .shared_cache(true);
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(2)
            .max_connections(max_connections)
            .connect_with(options)
            .await?;
        Ok(Self {
            pool,
            background: None,
            runtime: None,
        })
    }

    /// Fetches every user whose id is in `ids`.
    ///
    /// SQLite has no array parameters, so the ids are sent as a JSON array and
    /// expanded with `json_each`.
    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
        test_support!(self, {
            let query = "
                SELECT users.*
                FROM users
                WHERE users.id IN (SELECT value from json_each($1))
            ";
            Ok(sqlx::query_as(query)
                .bind(&serde_json::json!(ids))
                .fetch_all(&self.pool)
                .await?)
        })
    }

    /// Returns the stored `metrics_id` of user `id`.
    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
        test_support!(self, {
            let query = "
                SELECT metrics_id
                FROM users
                WHERE id = $1
            ";
            Ok(sqlx::query_scalar(query)
                .bind(id)
                .fetch_one(&self.pool)
                .await?)
        })
    }

    /// Inserts a user, or returns the existing row on a `github_login`
    /// conflict. A fresh v4 UUID string is supplied as `metrics_id`, because
    /// SQLite cannot generate one itself.
    pub async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult> {
        test_support!(self, {
            let query = "
                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id, metrics_id
            ";

            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
                .bind(email_address)
                .bind(params.github_login)
                .bind(params.github_user_id)
                .bind(admin)
                .bind(Uuid::new_v4().to_string())
                .fetch_one(&self.pool)
                .await?;
            Ok(NewUserResult {
                user_id,
                metrics_id,
                signup_device_id: None,
                inviting_user_id: None,
            })
        })
    }

    /// Not supported on the SQLite test backend.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
}
177
178impl Db<sqlx::Postgres> {
179 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
180 let pool = sqlx::postgres::PgPoolOptions::new()
181 .max_connections(max_connections)
182 .connect(url)
183 .await?;
184 Ok(Self {
185 pool,
186 #[cfg(test)]
187 background: None,
188 #[cfg(test)]
189 runtime: None,
190 })
191 }
192
193 #[cfg(test)]
194 pub fn teardown(&self, url: &str) {
195 self.runtime.as_ref().unwrap().block_on(async {
196 use util::ResultExt;
197 let query = "
198 SELECT pg_terminate_backend(pg_stat_activity.pid)
199 FROM pg_stat_activity
200 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
201 ";
202 sqlx::query(query).execute(&self.pool).await.log_err();
203 self.pool.close().await;
204 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
205 .await
206 .log_err();
207 })
208 }
209
210 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
211 test_support!(self, {
212 let like_string = Self::fuzzy_like_string(name_query);
213 let query = "
214 SELECT users.*
215 FROM users
216 WHERE github_login ILIKE $1
217 ORDER BY github_login <-> $2
218 LIMIT $3
219 ";
220 Ok(sqlx::query_as(query)
221 .bind(like_string)
222 .bind(name_query)
223 .bind(limit as i32)
224 .fetch_all(&self.pool)
225 .await?)
226 })
227 }
228
229 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
230 test_support!(self, {
231 let query = "
232 SELECT users.*
233 FROM users
234 WHERE users.id = ANY ($1)
235 ";
236 Ok(sqlx::query_as(query)
237 .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
238 .fetch_all(&self.pool)
239 .await?)
240 })
241 }
242
243 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
244 test_support!(self, {
245 let query = "
246 SELECT metrics_id::text
247 FROM users
248 WHERE id = $1
249 ";
250 Ok(sqlx::query_scalar(query)
251 .bind(id)
252 .fetch_one(&self.pool)
253 .await?)
254 })
255 }
256
257 pub async fn create_user(
258 &self,
259 email_address: &str,
260 admin: bool,
261 params: NewUserParams,
262 ) -> Result<NewUserResult> {
263 test_support!(self, {
264 let query = "
265 INSERT INTO users (email_address, github_login, github_user_id, admin)
266 VALUES ($1, $2, $3, $4)
267 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
268 RETURNING id, metrics_id::text
269 ";
270
271 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
272 .bind(email_address)
273 .bind(params.github_login)
274 .bind(params.github_user_id)
275 .bind(admin)
276 .fetch_one(&self.pool)
277 .await?;
278 Ok(NewUserResult {
279 user_id,
280 metrics_id,
281 signup_device_id: None,
282 inviting_user_id: None,
283 })
284 })
285 }
286
287 pub async fn create_user_from_invite(
288 &self,
289 invite: &Invite,
290 user: NewUserParams,
291 ) -> Result<Option<NewUserResult>> {
292 test_support!(self, {
293 let mut tx = self.pool.begin().await?;
294
295 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
296 i32,
297 Option<UserId>,
298 Option<UserId>,
299 Option<String>,
300 ) = sqlx::query_as(
301 "
302 SELECT id, user_id, inviting_user_id, device_id
303 FROM signups
304 WHERE
305 email_address = $1 AND
306 email_confirmation_code = $2
307 ",
308 )
309 .bind(&invite.email_address)
310 .bind(&invite.email_confirmation_code)
311 .fetch_optional(&mut tx)
312 .await?
313 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
314
315 if existing_user_id.is_some() {
316 return Ok(None);
317 }
318
319 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
320 "
321 INSERT INTO users
322 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
323 VALUES
324 ($1, $2, $3, FALSE, $4, $5)
325 ON CONFLICT (github_login) DO UPDATE SET
326 email_address = excluded.email_address,
327 github_user_id = excluded.github_user_id,
328 admin = excluded.admin
329 RETURNING id, metrics_id::text
330 ",
331 )
332 .bind(&invite.email_address)
333 .bind(&user.github_login)
334 .bind(&user.github_user_id)
335 .bind(&user.invite_count)
336 .bind(random_invite_code())
337 .fetch_one(&mut tx)
338 .await?;
339
340 sqlx::query(
341 "
342 UPDATE signups
343 SET user_id = $1
344 WHERE id = $2
345 ",
346 )
347 .bind(&user_id)
348 .bind(&signup_id)
349 .execute(&mut tx)
350 .await?;
351
352 if let Some(inviting_user_id) = inviting_user_id {
353 let id: Option<UserId> = sqlx::query_scalar(
354 "
355 UPDATE users
356 SET invite_count = invite_count - 1
357 WHERE id = $1 AND invite_count > 0
358 RETURNING id
359 ",
360 )
361 .bind(&inviting_user_id)
362 .fetch_optional(&mut tx)
363 .await?;
364
365 if id.is_none() {
366 Err(Error::Http(
367 StatusCode::UNAUTHORIZED,
368 "no invites remaining".to_string(),
369 ))?;
370 }
371
372 sqlx::query(
373 "
374 INSERT INTO contacts
375 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
376 VALUES
377 ($1, $2, TRUE, TRUE, TRUE)
378 ON CONFLICT DO NOTHING
379 ",
380 )
381 .bind(inviting_user_id)
382 .bind(user_id)
383 .execute(&mut tx)
384 .await?;
385 }
386
387 tx.commit().await?;
388 Ok(Some(NewUserResult {
389 user_id,
390 metrics_id,
391 inviting_user_id,
392 signup_device_id,
393 }))
394 })
395 }
396
397 pub async fn create_signup(&self, signup: Signup) -> Result<()> {
398 test_support!(self, {
399 sqlx::query(
400 "
401 INSERT INTO signups
402 (
403 email_address,
404 email_confirmation_code,
405 email_confirmation_sent,
406 platform_linux,
407 platform_mac,
408 platform_windows,
409 platform_unknown,
410 editor_features,
411 programming_languages,
412 device_id
413 )
414 VALUES
415 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
416 RETURNING id
417 ",
418 )
419 .bind(&signup.email_address)
420 .bind(&random_email_confirmation_code())
421 .bind(&signup.platform_linux)
422 .bind(&signup.platform_mac)
423 .bind(&signup.platform_windows)
424 .bind(&signup.editor_features)
425 .bind(&signup.programming_languages)
426 .bind(&signup.device_id)
427 .execute(&self.pool)
428 .await?;
429 Ok(())
430 })
431 }
432
433 pub async fn create_invite_from_code(
434 &self,
435 code: &str,
436 email_address: &str,
437 device_id: Option<&str>,
438 ) -> Result<Invite> {
439 test_support!(self, {
440 let mut tx = self.pool.begin().await?;
441
442 let existing_user: Option<UserId> = sqlx::query_scalar(
443 "
444 SELECT id
445 FROM users
446 WHERE email_address = $1
447 ",
448 )
449 .bind(email_address)
450 .fetch_optional(&mut tx)
451 .await?;
452 if existing_user.is_some() {
453 Err(anyhow!("email address is already in use"))?;
454 }
455
456 let row: Option<(UserId, i32)> = sqlx::query_as(
457 "
458 SELECT id, invite_count
459 FROM users
460 WHERE invite_code = $1
461 ",
462 )
463 .bind(code)
464 .fetch_optional(&mut tx)
465 .await?;
466
467 let (inviter_id, invite_count) = match row {
468 Some(row) => row,
469 None => Err(Error::Http(
470 StatusCode::NOT_FOUND,
471 "invite code not found".to_string(),
472 ))?,
473 };
474
475 if invite_count == 0 {
476 Err(Error::Http(
477 StatusCode::UNAUTHORIZED,
478 "no invites remaining".to_string(),
479 ))?;
480 }
481
482 let email_confirmation_code: String = sqlx::query_scalar(
483 "
484 INSERT INTO signups
485 (
486 email_address,
487 email_confirmation_code,
488 email_confirmation_sent,
489 inviting_user_id,
490 platform_linux,
491 platform_mac,
492 platform_windows,
493 platform_unknown,
494 device_id
495 )
496 VALUES
497 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
498 ON CONFLICT (email_address)
499 DO UPDATE SET
500 inviting_user_id = excluded.inviting_user_id
501 RETURNING email_confirmation_code
502 ",
503 )
504 .bind(&email_address)
505 .bind(&random_email_confirmation_code())
506 .bind(&inviter_id)
507 .bind(&device_id)
508 .fetch_one(&mut tx)
509 .await?;
510
511 tx.commit().await?;
512
513 Ok(Invite {
514 email_address: email_address.into(),
515 email_confirmation_code,
516 })
517 })
518 }
519
520 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
521 test_support!(self, {
522 let emails = invites
523 .iter()
524 .map(|s| s.email_address.as_str())
525 .collect::<Vec<_>>();
526 sqlx::query(
527 "
528 UPDATE signups
529 SET email_confirmation_sent = TRUE
530 WHERE email_address = ANY ($1)
531 ",
532 )
533 .bind(&emails)
534 .execute(&self.pool)
535 .await?;
536 Ok(())
537 })
538 }
539}
540
541impl<D> Db<D>
542where
543 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
544 D::Connection: sqlx::migrate::Migrate,
545 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
546 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
547 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
548 D::QueryResult: RowsAffected,
549 String: sqlx::Type<D>,
550 i32: sqlx::Type<D>,
551 i64: sqlx::Type<D>,
552 bool: sqlx::Type<D>,
553 str: sqlx::Type<D>,
554 Uuid: sqlx::Type<D>,
555 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
556 OffsetDateTime: sqlx::Type<D>,
557 PrimitiveDateTime: sqlx::Type<D>,
558 usize: sqlx::ColumnIndex<D::Row>,
559 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
560 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
561 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
562 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
563 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
564 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
565 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
566 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
567 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
568 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
569 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
570 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
571{
572 pub async fn migrate(
573 &self,
574 migrations_path: &Path,
575 ignore_checksum_mismatch: bool,
576 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
577 let migrations = MigrationSource::resolve(migrations_path)
578 .await
579 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
580
581 let mut conn = self.pool.acquire().await?;
582
583 conn.ensure_migrations_table().await?;
584 let applied_migrations: HashMap<_, _> = conn
585 .list_applied_migrations()
586 .await?
587 .into_iter()
588 .map(|m| (m.version, m))
589 .collect();
590
591 let mut new_migrations = Vec::new();
592 for migration in migrations {
593 match applied_migrations.get(&migration.version) {
594 Some(applied_migration) => {
595 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
596 {
597 Err(anyhow!(
598 "checksum mismatch for applied migration {}",
599 migration.description
600 ))?;
601 }
602 }
603 None => {
604 let elapsed = conn.apply(&migration).await?;
605 new_migrations.push((migration, elapsed));
606 }
607 }
608 }
609
610 Ok(new_migrations)
611 }
612
613 pub fn fuzzy_like_string(string: &str) -> String {
614 let mut result = String::with_capacity(string.len() * 2 + 1);
615 for c in string.chars() {
616 if c.is_alphanumeric() {
617 result.push('%');
618 result.push(c);
619 }
620 }
621 result.push('%');
622 result
623 }
624
625 // users
626
627 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
628 test_support!(self, {
629 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
630 Ok(sqlx::query_as(query)
631 .bind(limit as i32)
632 .bind((page * limit) as i32)
633 .fetch_all(&self.pool)
634 .await?)
635 })
636 }
637
638 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
639 test_support!(self, {
640 let query = "
641 SELECT users.*
642 FROM users
643 WHERE id = $1
644 LIMIT 1
645 ";
646 Ok(sqlx::query_as(query)
647 .bind(&id)
648 .fetch_optional(&self.pool)
649 .await?)
650 })
651 }
652
653 pub async fn get_users_with_no_invites(
654 &self,
655 invited_by_another_user: bool,
656 ) -> Result<Vec<User>> {
657 test_support!(self, {
658 let query = format!(
659 "
660 SELECT users.*
661 FROM users
662 WHERE invite_count = 0
663 AND inviter_id IS{} NULL
664 ",
665 if invited_by_another_user { " NOT" } else { "" }
666 );
667
668 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
669 })
670 }
671
672 pub async fn get_user_by_github_account(
673 &self,
674 github_login: &str,
675 github_user_id: Option<i32>,
676 ) -> Result<Option<User>> {
677 test_support!(self, {
678 if let Some(github_user_id) = github_user_id {
679 let mut user = sqlx::query_as::<_, User>(
680 "
681 UPDATE users
682 SET github_login = $1
683 WHERE github_user_id = $2
684 RETURNING *
685 ",
686 )
687 .bind(github_login)
688 .bind(github_user_id)
689 .fetch_optional(&self.pool)
690 .await?;
691
692 if user.is_none() {
693 user = sqlx::query_as::<_, User>(
694 "
695 UPDATE users
696 SET github_user_id = $1
697 WHERE github_login = $2
698 RETURNING *
699 ",
700 )
701 .bind(github_user_id)
702 .bind(github_login)
703 .fetch_optional(&self.pool)
704 .await?;
705 }
706
707 Ok(user)
708 } else {
709 let user = sqlx::query_as(
710 "
711 SELECT * FROM users
712 WHERE github_login = $1
713 LIMIT 1
714 ",
715 )
716 .bind(github_login)
717 .fetch_optional(&self.pool)
718 .await?;
719 Ok(user)
720 }
721 })
722 }
723
724 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
725 test_support!(self, {
726 let query = "UPDATE users SET admin = $1 WHERE id = $2";
727 Ok(sqlx::query(query)
728 .bind(is_admin)
729 .bind(id.0)
730 .execute(&self.pool)
731 .await
732 .map(drop)?)
733 })
734 }
735
736 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
737 test_support!(self, {
738 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
739 Ok(sqlx::query(query)
740 .bind(connected_once)
741 .bind(id.0)
742 .execute(&self.pool)
743 .await
744 .map(drop)?)
745 })
746 }
747
748 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
749 test_support!(self, {
750 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
751 sqlx::query(query)
752 .bind(id.0)
753 .execute(&self.pool)
754 .await
755 .map(drop)?;
756 let query = "DELETE FROM users WHERE id = $1;";
757 Ok(sqlx::query(query)
758 .bind(id.0)
759 .execute(&self.pool)
760 .await
761 .map(drop)?)
762 })
763 }
764
765 // signups
766
767 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
768 test_support!(self, {
769 Ok(sqlx::query_as(
770 "
771 SELECT
772 COUNT(*) as count,
773 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
774 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
775 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
776 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
777 FROM (
778 SELECT *
779 FROM signups
780 WHERE
781 NOT email_confirmation_sent
782 ) AS unsent
783 ",
784 )
785 .fetch_one(&self.pool)
786 .await?)
787 })
788 }
789
790 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
791 test_support!(self, {
792 Ok(sqlx::query_as(
793 "
794 SELECT
795 email_address, email_confirmation_code
796 FROM signups
797 WHERE
798 NOT email_confirmation_sent AND
799 (platform_mac OR platform_unknown)
800 LIMIT $1
801 ",
802 )
803 .bind(count as i32)
804 .fetch_all(&self.pool)
805 .await?)
806 })
807 }
808
809 // invite codes
810
811 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
812 test_support!(self, {
813 let mut tx = self.pool.begin().await?;
814 if count > 0 {
815 sqlx::query(
816 "
817 UPDATE users
818 SET invite_code = $1
819 WHERE id = $2 AND invite_code IS NULL
820 ",
821 )
822 .bind(random_invite_code())
823 .bind(id)
824 .execute(&mut tx)
825 .await?;
826 }
827
828 sqlx::query(
829 "
830 UPDATE users
831 SET invite_count = $1
832 WHERE id = $2
833 ",
834 )
835 .bind(count as i32)
836 .bind(id)
837 .execute(&mut tx)
838 .await?;
839 tx.commit().await?;
840 Ok(())
841 })
842 }
843
844 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
845 test_support!(self, {
846 let result: Option<(String, i32)> = sqlx::query_as(
847 "
848 SELECT invite_code, invite_count
849 FROM users
850 WHERE id = $1 AND invite_code IS NOT NULL
851 ",
852 )
853 .bind(id)
854 .fetch_optional(&self.pool)
855 .await?;
856 if let Some((code, count)) = result {
857 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
858 } else {
859 Ok(None)
860 }
861 })
862 }
863
864 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
865 test_support!(self, {
866 sqlx::query_as(
867 "
868 SELECT *
869 FROM users
870 WHERE invite_code = $1
871 ",
872 )
873 .bind(code)
874 .fetch_optional(&self.pool)
875 .await?
876 .ok_or_else(|| {
877 Error::Http(
878 StatusCode::NOT_FOUND,
879 "that invite code does not exist".to_string(),
880 )
881 })
882 })
883 }
884
885 // projects
886
887 /// Registers a new project for the given user.
888 pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
889 test_support!(self, {
890 Ok(sqlx::query_scalar(
891 "
892 INSERT INTO projects(host_user_id)
893 VALUES ($1)
894 RETURNING id
895 ",
896 )
897 .bind(host_user_id)
898 .fetch_one(&self.pool)
899 .await
900 .map(ProjectId)?)
901 })
902 }
903
904 /// Unregisters a project for the given project id.
905 pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
906 test_support!(self, {
907 sqlx::query(
908 "
909 UPDATE projects
910 SET unregistered = TRUE
911 WHERE id = $1
912 ",
913 )
914 .bind(project_id)
915 .execute(&self.pool)
916 .await?;
917 Ok(())
918 })
919 }
920
921 // contacts
922
923 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
924 test_support!(self, {
925 let query = "
926 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
927 FROM contacts
928 WHERE user_id_a = $1 OR user_id_b = $1;
929 ";
930
931 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
932 .bind(user_id)
933 .fetch(&self.pool);
934
935 let mut contacts = Vec::new();
936 while let Some(row) = rows.next().await {
937 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
938
939 if user_id_a == user_id {
940 if accepted {
941 contacts.push(Contact::Accepted {
942 user_id: user_id_b,
943 should_notify: should_notify && a_to_b,
944 });
945 } else if a_to_b {
946 contacts.push(Contact::Outgoing { user_id: user_id_b })
947 } else {
948 contacts.push(Contact::Incoming {
949 user_id: user_id_b,
950 should_notify,
951 });
952 }
953 } else if accepted {
954 contacts.push(Contact::Accepted {
955 user_id: user_id_a,
956 should_notify: should_notify && !a_to_b,
957 });
958 } else if a_to_b {
959 contacts.push(Contact::Incoming {
960 user_id: user_id_a,
961 should_notify,
962 });
963 } else {
964 contacts.push(Contact::Outgoing { user_id: user_id_a });
965 }
966 }
967
968 contacts.sort_unstable_by_key(|contact| contact.user_id());
969
970 Ok(contacts)
971 })
972 }
973
974 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
975 test_support!(self, {
976 let (id_a, id_b) = if user_id_1 < user_id_2 {
977 (user_id_1, user_id_2)
978 } else {
979 (user_id_2, user_id_1)
980 };
981
982 let query = "
983 SELECT 1 FROM contacts
984 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
985 LIMIT 1
986 ";
987 Ok(sqlx::query_scalar::<_, i32>(query)
988 .bind(id_a.0)
989 .bind(id_b.0)
990 .fetch_optional(&self.pool)
991 .await?
992 .is_some())
993 })
994 }
995
996 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
997 test_support!(self, {
998 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
999 (sender_id, receiver_id, true)
1000 } else {
1001 (receiver_id, sender_id, false)
1002 };
1003 let query = "
1004 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1005 VALUES ($1, $2, $3, FALSE, TRUE)
1006 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1007 SET
1008 accepted = TRUE,
1009 should_notify = FALSE
1010 WHERE
1011 NOT contacts.accepted AND
1012 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1013 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1014 ";
1015 let result = sqlx::query(query)
1016 .bind(id_a.0)
1017 .bind(id_b.0)
1018 .bind(a_to_b)
1019 .execute(&self.pool)
1020 .await?;
1021
1022 if result.rows_affected() == 1 {
1023 Ok(())
1024 } else {
1025 Err(anyhow!("contact already requested"))?
1026 }
1027 })
1028 }
1029
1030 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1031 test_support!(self, {
1032 let (id_a, id_b) = if responder_id < requester_id {
1033 (responder_id, requester_id)
1034 } else {
1035 (requester_id, responder_id)
1036 };
1037 let query = "
1038 DELETE FROM contacts
1039 WHERE user_id_a = $1 AND user_id_b = $2;
1040 ";
1041 let result = sqlx::query(query)
1042 .bind(id_a.0)
1043 .bind(id_b.0)
1044 .execute(&self.pool)
1045 .await?;
1046
1047 if result.rows_affected() == 1 {
1048 Ok(())
1049 } else {
1050 Err(anyhow!("no such contact"))?
1051 }
1052 })
1053 }
1054
1055 pub async fn dismiss_contact_notification(
1056 &self,
1057 user_id: UserId,
1058 contact_user_id: UserId,
1059 ) -> Result<()> {
1060 test_support!(self, {
1061 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1062 (user_id, contact_user_id, true)
1063 } else {
1064 (contact_user_id, user_id, false)
1065 };
1066
1067 let query = "
1068 UPDATE contacts
1069 SET should_notify = FALSE
1070 WHERE
1071 user_id_a = $1 AND user_id_b = $2 AND
1072 (
1073 (a_to_b = $3 AND accepted) OR
1074 (a_to_b != $3 AND NOT accepted)
1075 );
1076 ";
1077
1078 let result = sqlx::query(query)
1079 .bind(id_a.0)
1080 .bind(id_b.0)
1081 .bind(a_to_b)
1082 .execute(&self.pool)
1083 .await?;
1084
1085 if result.rows_affected() == 0 {
1086 Err(anyhow!("no such contact request"))?;
1087 }
1088
1089 Ok(())
1090 })
1091 }
1092
1093 pub async fn respond_to_contact_request(
1094 &self,
1095 responder_id: UserId,
1096 requester_id: UserId,
1097 accept: bool,
1098 ) -> Result<()> {
1099 test_support!(self, {
1100 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1101 (responder_id, requester_id, false)
1102 } else {
1103 (requester_id, responder_id, true)
1104 };
1105 let result = if accept {
1106 let query = "
1107 UPDATE contacts
1108 SET accepted = TRUE, should_notify = TRUE
1109 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1110 ";
1111 sqlx::query(query)
1112 .bind(id_a.0)
1113 .bind(id_b.0)
1114 .bind(a_to_b)
1115 .execute(&self.pool)
1116 .await?
1117 } else {
1118 let query = "
1119 DELETE FROM contacts
1120 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1121 ";
1122 sqlx::query(query)
1123 .bind(id_a.0)
1124 .bind(id_b.0)
1125 .bind(a_to_b)
1126 .execute(&self.pool)
1127 .await?
1128 };
1129 if result.rows_affected() == 1 {
1130 Ok(())
1131 } else {
1132 Err(anyhow!("no such contact request"))?
1133 }
1134 })
1135 }
1136
1137 // access tokens
1138
1139 pub async fn create_access_token_hash(
1140 &self,
1141 user_id: UserId,
1142 access_token_hash: &str,
1143 max_access_token_count: usize,
1144 ) -> Result<()> {
1145 test_support!(self, {
1146 let insert_query = "
1147 INSERT INTO access_tokens (user_id, hash)
1148 VALUES ($1, $2);
1149 ";
1150 let cleanup_query = "
1151 DELETE FROM access_tokens
1152 WHERE id IN (
1153 SELECT id from access_tokens
1154 WHERE user_id = $1
1155 ORDER BY id DESC
1156 LIMIT 10000
1157 OFFSET $3
1158 )
1159 ";
1160
1161 let mut tx = self.pool.begin().await?;
1162 sqlx::query(insert_query)
1163 .bind(user_id.0)
1164 .bind(access_token_hash)
1165 .execute(&mut tx)
1166 .await?;
1167 sqlx::query(cleanup_query)
1168 .bind(user_id.0)
1169 .bind(access_token_hash)
1170 .bind(max_access_token_count as i32)
1171 .execute(&mut tx)
1172 .await?;
1173 Ok(tx.commit().await?)
1174 })
1175 }
1176
1177 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1178 test_support!(self, {
1179 let query = "
1180 SELECT hash
1181 FROM access_tokens
1182 WHERE user_id = $1
1183 ORDER BY id DESC
1184 ";
1185 Ok(sqlx::query_scalar(query)
1186 .bind(user_id.0)
1187 .fetch_all(&self.pool)
1188 .await?)
1189 })
1190 }
1191}
1192
/// Declares an `i32`-backed id newtype that is transparent to both sqlx and
/// serde. It also gets protocol conversions, a `MAX` constant, and a `Display`
/// impl that prints the raw number.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            Hash,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Serialize,
            Deserialize,
            sqlx::Type,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Converts a wire id. Uses an `as` cast, so out-of-range values
            /// are truncated.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                Self(raw)
            }

            /// Converts to a wire id with an `as` cast (negative ids wrap).
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                let Self(raw) = self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1235
1236id_type!(UserId);
1237#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1238pub struct User {
1239 pub id: UserId,
1240 pub github_login: String,
1241 pub github_user_id: Option<i32>,
1242 pub email_address: Option<String>,
1243 pub admin: bool,
1244 pub invite_code: Option<String>,
1245 pub invite_count: i32,
1246 pub connected_once: bool,
1247}
1248
1249id_type!(ProjectId);
1250#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1251pub struct Project {
1252 pub id: ProjectId,
1253 pub host_user_id: UserId,
1254 pub unregistered: bool,
1255}
1256
1257#[derive(Clone, Debug, PartialEq, Eq)]
1258pub enum Contact {
1259 Accepted {
1260 user_id: UserId,
1261 should_notify: bool,
1262 },
1263 Outgoing {
1264 user_id: UserId,
1265 },
1266 Incoming {
1267 user_id: UserId,
1268 should_notify: bool,
1269 },
1270}
1271
1272impl Contact {
1273 pub fn user_id(&self) -> UserId {
1274 match self {
1275 Contact::Accepted { user_id, .. } => *user_id,
1276 Contact::Outgoing { user_id } => *user_id,
1277 Contact::Incoming { user_id, .. } => *user_id,
1278 }
1279 }
1280}
1281
1282#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
1283pub struct IncomingContactRequest {
1284 pub requester_id: UserId,
1285 pub should_notify: bool,
1286}
1287
1288#[derive(Clone, Deserialize)]
1289pub struct Signup {
1290 pub email_address: String,
1291 pub platform_mac: bool,
1292 pub platform_windows: bool,
1293 pub platform_linux: bool,
1294 pub editor_features: Vec<String>,
1295 pub programming_languages: Vec<String>,
1296 pub device_id: Option<String>,
1297}
1298
1299#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
1300pub struct WaitlistSummary {
1301 #[sqlx(default)]
1302 pub count: i64,
1303 #[sqlx(default)]
1304 pub linux_count: i64,
1305 #[sqlx(default)]
1306 pub mac_count: i64,
1307 #[sqlx(default)]
1308 pub windows_count: i64,
1309 #[sqlx(default)]
1310 pub unknown_count: i64,
1311}
1312
1313#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
1314pub struct Invite {
1315 pub email_address: String,
1316 pub email_confirmation_code: String,
1317}
1318
1319#[derive(Debug, Serialize, Deserialize)]
1320pub struct NewUserParams {
1321 pub github_login: String,
1322 pub github_user_id: i32,
1323 pub invite_count: i32,
1324}
1325
1326#[derive(Debug)]
1327pub struct NewUserResult {
1328 pub user_id: UserId,
1329 pub metrics_id: String,
1330 pub inviting_user_id: Option<UserId>,
1331 pub signup_device_id: Option<String>,
1332}
1333
1334fn random_invite_code() -> String {
1335 nanoid::nanoid!(16)
1336}
1337
1338fn random_email_confirmation_code() -> String {
1339 nanoid::nanoid!(64)
1340}
1341
// Expose the test database fixtures at the module root in test builds.
#[cfg(test)]
pub use self::test::*;
1344
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// Builds the current-thread Tokio runtime that a test database drives its
    /// queries on (see `test_support!`).
    fn build_test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    /// An isolated, migrated in-memory SQLite database.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// Detached connection held for the fixture's lifetime. Presumably it
        /// keeps the shared-cache in-memory database alive — TODO confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// An isolated, migrated Postgres database, dropped again on `Drop`.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        pub url: String,
    }

    impl SqliteTestDb {
        /// Creates a randomly named in-memory database and applies the SQLite
        /// migrations.
        pub fn new(background: Arc<Background>) -> Self {
            let suffix: u128 = StdRng::from_entropy().gen();
            let url = format!("file:zed-test-{}?mode=memory", suffix);
            let runtime = build_test_runtime();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            let db = Some(Arc::new(db));
            Self { db, conn }
        }

        /// The wrapped database handle.
        ///
        /// # Panics
        /// Panics if the handle has already been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates a randomly named Postgres database on localhost and applies
        /// the migrations. Creation is serialized across tests by a global lock.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let suffix: u128 = StdRng::from_entropy().gen();
            let url = format!("postgres://postgres@localhost/zed-test-{}", suffix);
            let runtime = build_test_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            let db = Some(Arc::new(db));
            Self { db, url }
        }

        /// The wrapped database handle.
        ///
        /// # Panics
        /// Panics if the handle has already been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        fn drop(&mut self) {
            if let Some(db) = self.db.take() {
                db.teardown(&self.url);
            } else {
                panic!("called `Option::unwrap()` on a `None` value");
            }
        }
    }
}