1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::HashMap;
5use futures::StreamExt;
6use serde::{Deserialize, Serialize};
7use sqlx::{
8 migrate::{Migrate as _, Migration, MigrationSource},
9 types::Uuid,
10 FromRow,
11};
12use std::{path::Path, time::Duration};
13use time::{OffsetDateTime, PrimitiveDateTime};
14
/// The concrete database type used by the rest of the crate.
///
/// In unit-test builds (`cfg(test)`) this is backed by SQLite.
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// The concrete database type used by the rest of the crate.
///
/// In non-test builds this is backed by Postgres.
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;
20
/// A connection pool for database backend `D`, plus test-only hooks consumed by
/// the `test_support!` macro.
pub struct Db<D: sqlx::Database> {
    /// Pool that every query in this module is executed against.
    pool: sqlx::Pool<D>,
    /// Test-only: when set, `test_support!` awaits `simulate_random_delay` on it
    /// before running each query body.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only: runtime that `test_support!` uses to `block_on` each query body.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
28
/// Runs a block of async database code, driving it differently in test builds.
///
/// The tokens are wrapped in an `async` block whose value becomes the value of
/// the macro invocation:
///
/// * Non-test builds simply `.await` the block.
/// * Test builds first await `simulate_random_delay()` on `$self.background` (if
///   set), then run the block to completion with `block_on` on `$self.runtime`.
///
/// NOTE(review): the separate runtime is presumably there so sqlx's tokio-based
/// I/O has a tokio runtime while tests run on gpui's executor — confirm.
///
/// # Panics
///
/// In test builds, panics if `$self.runtime` is `None`.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        // The caller's statements; not polled until awaited/blocked on below.
        let body = async {
            $($token)*
        };

        if cfg!(test) {
            // In non-test builds every statement below is compiled out; the
            // `unreachable!()` gives this branch the never type so it still
            // type-checks against the `else` arm.
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            // Block the current thread on the Db's own runtime until the body finishes.
            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(body)
        } else {
            body.await
        }
    }};
}
51
/// Abstracts over backend-specific query-result types so that code generic over
/// the database can read how many rows a statement affected.
pub trait RowsAffected {
    /// Number of rows affected by the executed statement.
    fn rows_affected(&self) -> u64;
}
55
#[cfg(test)]
impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    /// Forwards to SQLite's own `rows_affected`.
    fn rows_affected(&self) -> u64 {
        // Inherent associated functions take precedence over trait items, so this
        // path resolves to sqlx's inherent method rather than recursing.
        sqlx::sqlite::SqliteQueryResult::rows_affected(self)
    }
}
62
63impl RowsAffected for sqlx::postgres::PgQueryResult {
64 fn rows_affected(&self) -> u64 {
65 self.rows_affected()
66 }
67}
68
#[cfg(test)]
impl Db<sqlx::Sqlite> {
    /// Opens the SQLite database at `url`, creating it if missing, with a shared
    /// cache and a pool configured with a minimum of 2 and a maximum of
    /// `max_connections` connections.
    ///
    /// # Errors
    ///
    /// Returns an error if `url` cannot be parsed as SQLite connect options or if
    /// the pool fails to connect.
    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
        use std::str::FromStr as _;
        // Report a malformed URL as an error instead of panicking.
        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
            .create_if_missing(true)
            .shared_cache(true);
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(2)
            .max_connections(max_connections)
            .connect_with(options)
            .await?;
        Ok(Self {
            pool,
            background: None,
            runtime: None,
        })
    }

    /// Fetches the users whose ids appear in `ids`.
    ///
    /// The id list is bound as a single JSON array and expanded with `json_each`.
    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
        test_support!(self, {
            let query = "
                SELECT users.*
                FROM users
                WHERE users.id IN (SELECT value from json_each($1))
            ";
            Ok(sqlx::query_as(query)
                .bind(&serde_json::json!(ids))
                .fetch_all(&self.pool)
                .await?)
        })
    }

    /// Returns the `metrics_id` of the user with the given id.
    ///
    /// # Errors
    ///
    /// Errors if no such user exists (`fetch_one`).
    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
        test_support!(self, {
            let query = "
                SELECT metrics_id
                FROM users
                WHERE id = $1
            ";
            Ok(sqlx::query_scalar(query)
                .bind(id)
                .fetch_one(&self.pool)
                .await?)
        })
    }

    /// Inserts a user, or returns the existing one if `github_login` is taken.
    ///
    /// On SQLite the `metrics_id` is generated here as a random UUID string. On a
    /// login conflict the existing row is kept (its login is rewritten to itself)
    /// and its `id`/`metrics_id` are returned.
    pub async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult> {
        test_support!(self, {
            let query = "
                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id, metrics_id
            ";

            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
                .bind(email_address)
                .bind(params.github_login)
                .bind(params.github_user_id)
                .bind(admin)
                .bind(Uuid::new_v4().to_string())
                .fetch_one(&self.pool)
                .await?;
            Ok(NewUserResult {
                user_id,
                metrics_id,
                signup_device_id: None,
                inviting_user_id: None,
            })
        })
    }

    /// Not supported on the SQLite test backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn create_signup(&self, _signup: &Signup) -> Result<()> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!()`.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
}
177
178impl Db<sqlx::Postgres> {
179 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
180 let pool = sqlx::postgres::PgPoolOptions::new()
181 .max_connections(max_connections)
182 .connect(url)
183 .await?;
184 Ok(Self {
185 pool,
186 #[cfg(test)]
187 background: None,
188 #[cfg(test)]
189 runtime: None,
190 })
191 }
192
193 #[cfg(test)]
194 pub fn teardown(&self, url: &str) {
195 self.runtime.as_ref().unwrap().block_on(async {
196 use util::ResultExt;
197 let query = "
198 SELECT pg_terminate_backend(pg_stat_activity.pid)
199 FROM pg_stat_activity
200 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
201 ";
202 sqlx::query(query).execute(&self.pool).await.log_err();
203 self.pool.close().await;
204 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
205 .await
206 .log_err();
207 })
208 }
209
210 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
211 test_support!(self, {
212 let like_string = Self::fuzzy_like_string(name_query);
213 let query = "
214 SELECT users.*
215 FROM users
216 WHERE github_login ILIKE $1
217 ORDER BY github_login <-> $2
218 LIMIT $3
219 ";
220 Ok(sqlx::query_as(query)
221 .bind(like_string)
222 .bind(name_query)
223 .bind(limit as i32)
224 .fetch_all(&self.pool)
225 .await?)
226 })
227 }
228
229 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
230 test_support!(self, {
231 let query = "
232 SELECT users.*
233 FROM users
234 WHERE users.id = ANY ($1)
235 ";
236 Ok(sqlx::query_as(query)
237 .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
238 .fetch_all(&self.pool)
239 .await?)
240 })
241 }
242
243 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
244 test_support!(self, {
245 let query = "
246 SELECT metrics_id::text
247 FROM users
248 WHERE id = $1
249 ";
250 Ok(sqlx::query_scalar(query)
251 .bind(id)
252 .fetch_one(&self.pool)
253 .await?)
254 })
255 }
256
257 pub async fn create_user(
258 &self,
259 email_address: &str,
260 admin: bool,
261 params: NewUserParams,
262 ) -> Result<NewUserResult> {
263 test_support!(self, {
264 let query = "
265 INSERT INTO users (email_address, github_login, github_user_id, admin)
266 VALUES ($1, $2, $3, $4)
267 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
268 RETURNING id, metrics_id::text
269 ";
270
271 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
272 .bind(email_address)
273 .bind(params.github_login)
274 .bind(params.github_user_id)
275 .bind(admin)
276 .fetch_one(&self.pool)
277 .await?;
278 Ok(NewUserResult {
279 user_id,
280 metrics_id,
281 signup_device_id: None,
282 inviting_user_id: None,
283 })
284 })
285 }
286
287 pub async fn create_user_from_invite(
288 &self,
289 invite: &Invite,
290 user: NewUserParams,
291 ) -> Result<Option<NewUserResult>> {
292 test_support!(self, {
293 let mut tx = self.pool.begin().await?;
294
295 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
296 i32,
297 Option<UserId>,
298 Option<UserId>,
299 Option<String>,
300 ) = sqlx::query_as(
301 "
302 SELECT id, user_id, inviting_user_id, device_id
303 FROM signups
304 WHERE
305 email_address = $1 AND
306 email_confirmation_code = $2
307 ",
308 )
309 .bind(&invite.email_address)
310 .bind(&invite.email_confirmation_code)
311 .fetch_optional(&mut tx)
312 .await?
313 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
314
315 if existing_user_id.is_some() {
316 return Ok(None);
317 }
318
319 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
320 "
321 INSERT INTO users
322 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
323 VALUES
324 ($1, $2, $3, FALSE, $4, $5)
325 ON CONFLICT (github_login) DO UPDATE SET
326 email_address = excluded.email_address,
327 github_user_id = excluded.github_user_id,
328 admin = excluded.admin
329 RETURNING id, metrics_id::text
330 ",
331 )
332 .bind(&invite.email_address)
333 .bind(&user.github_login)
334 .bind(&user.github_user_id)
335 .bind(&user.invite_count)
336 .bind(random_invite_code())
337 .fetch_one(&mut tx)
338 .await?;
339
340 sqlx::query(
341 "
342 UPDATE signups
343 SET user_id = $1
344 WHERE id = $2
345 ",
346 )
347 .bind(&user_id)
348 .bind(&signup_id)
349 .execute(&mut tx)
350 .await?;
351
352 if let Some(inviting_user_id) = inviting_user_id {
353 let (user_id_a, user_id_b, a_to_b) = if inviting_user_id < user_id {
354 (inviting_user_id, user_id, true)
355 } else {
356 (user_id, inviting_user_id, false)
357 };
358
359 sqlx::query(
360 "
361 INSERT INTO contacts
362 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
363 VALUES
364 ($1, $2, $3, TRUE, TRUE)
365 ON CONFLICT DO NOTHING
366 ",
367 )
368 .bind(user_id_a)
369 .bind(user_id_b)
370 .bind(a_to_b)
371 .execute(&mut tx)
372 .await?;
373 }
374
375 tx.commit().await?;
376 Ok(Some(NewUserResult {
377 user_id,
378 metrics_id,
379 inviting_user_id,
380 signup_device_id,
381 }))
382 })
383 }
384
385 pub async fn create_signup(&self, signup: &Signup) -> Result<()> {
386 test_support!(self, {
387 sqlx::query(
388 "
389 INSERT INTO signups
390 (
391 email_address,
392 email_confirmation_code,
393 email_confirmation_sent,
394 platform_linux,
395 platform_mac,
396 platform_windows,
397 platform_unknown,
398 editor_features,
399 programming_languages,
400 device_id,
401 added_to_mailing_list
402 )
403 VALUES
404 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8, $9)
405 ON CONFLICT (email_address) DO UPDATE SET
406 email_address = excluded.email_address
407 RETURNING id
408 ",
409 )
410 .bind(&signup.email_address)
411 .bind(&random_email_confirmation_code())
412 .bind(&signup.platform_linux)
413 .bind(&signup.platform_mac)
414 .bind(&signup.platform_windows)
415 .bind(&signup.editor_features)
416 .bind(&signup.programming_languages)
417 .bind(&signup.device_id)
418 .bind(&signup.added_to_mailing_list)
419 .execute(&self.pool)
420 .await?;
421 Ok(())
422 })
423 }
424
425 pub async fn create_invite_from_code(
426 &self,
427 code: &str,
428 email_address: &str,
429 device_id: Option<&str>,
430 ) -> Result<Invite> {
431 test_support!(self, {
432 let mut tx = self.pool.begin().await?;
433
434 let existing_user: Option<UserId> = sqlx::query_scalar(
435 "
436 SELECT id
437 FROM users
438 WHERE email_address = $1
439 ",
440 )
441 .bind(email_address)
442 .fetch_optional(&mut tx)
443 .await?;
444 if existing_user.is_some() {
445 Err(anyhow!("email address is already in use"))?;
446 }
447
448 let inviting_user_id_with_invites: Option<UserId> = sqlx::query_scalar(
449 "
450 UPDATE users
451 SET invite_count = invite_count - 1
452 WHERE invite_code = $1 AND invite_count > 0
453 RETURNING id
454 ",
455 )
456 .bind(code)
457 .fetch_optional(&mut tx)
458 .await?;
459
460 let Some(inviter_id) = inviting_user_id_with_invites else {
461 return Err(Error::Http(
462 StatusCode::UNAUTHORIZED,
463 "unable to find an invite code with invites remaining".to_string(),
464 ));
465 };
466
467 let email_confirmation_code: String = sqlx::query_scalar(
468 "
469 INSERT INTO signups
470 (
471 email_address,
472 email_confirmation_code,
473 email_confirmation_sent,
474 inviting_user_id,
475 platform_linux,
476 platform_mac,
477 platform_windows,
478 platform_unknown,
479 device_id
480 )
481 VALUES
482 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
483 ON CONFLICT (email_address)
484 DO UPDATE SET
485 inviting_user_id = excluded.inviting_user_id
486 RETURNING email_confirmation_code
487 ",
488 )
489 .bind(&email_address)
490 .bind(&random_email_confirmation_code())
491 .bind(&inviter_id)
492 .bind(&device_id)
493 .fetch_one(&mut tx)
494 .await?;
495
496 tx.commit().await?;
497
498 Ok(Invite {
499 email_address: email_address.into(),
500 email_confirmation_code,
501 })
502 })
503 }
504
505 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
506 test_support!(self, {
507 let emails = invites
508 .iter()
509 .map(|s| s.email_address.as_str())
510 .collect::<Vec<_>>();
511 sqlx::query(
512 "
513 UPDATE signups
514 SET email_confirmation_sent = TRUE
515 WHERE email_address = ANY ($1)
516 ",
517 )
518 .bind(&emails)
519 .execute(&self.pool)
520 .await?;
521 Ok(())
522 })
523 }
524}
525
526impl<D> Db<D>
527where
528 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
529 D::Connection: sqlx::migrate::Migrate,
530 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
531 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
532 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
533 D::QueryResult: RowsAffected,
534 String: sqlx::Type<D>,
535 i32: sqlx::Type<D>,
536 i64: sqlx::Type<D>,
537 bool: sqlx::Type<D>,
538 str: sqlx::Type<D>,
539 Uuid: sqlx::Type<D>,
540 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
541 OffsetDateTime: sqlx::Type<D>,
542 PrimitiveDateTime: sqlx::Type<D>,
543 usize: sqlx::ColumnIndex<D::Row>,
544 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
545 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
546 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
547 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
548 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
549 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
550 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
551 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
552 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
553 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
554 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
555 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
556{
557 pub async fn migrate(
558 &self,
559 migrations_path: &Path,
560 ignore_checksum_mismatch: bool,
561 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
562 let migrations = MigrationSource::resolve(migrations_path)
563 .await
564 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
565
566 let mut conn = self.pool.acquire().await?;
567
568 conn.ensure_migrations_table().await?;
569 let applied_migrations: HashMap<_, _> = conn
570 .list_applied_migrations()
571 .await?
572 .into_iter()
573 .map(|m| (m.version, m))
574 .collect();
575
576 let mut new_migrations = Vec::new();
577 for migration in migrations {
578 match applied_migrations.get(&migration.version) {
579 Some(applied_migration) => {
580 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
581 {
582 Err(anyhow!(
583 "checksum mismatch for applied migration {}",
584 migration.description
585 ))?;
586 }
587 }
588 None => {
589 let elapsed = conn.apply(&migration).await?;
590 new_migrations.push((migration, elapsed));
591 }
592 }
593 }
594
595 Ok(new_migrations)
596 }
597
598 pub fn fuzzy_like_string(string: &str) -> String {
599 let mut result = String::with_capacity(string.len() * 2 + 1);
600 for c in string.chars() {
601 if c.is_alphanumeric() {
602 result.push('%');
603 result.push(c);
604 }
605 }
606 result.push('%');
607 result
608 }
609
610 // users
611
612 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
613 test_support!(self, {
614 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
615 Ok(sqlx::query_as(query)
616 .bind(limit as i32)
617 .bind((page * limit) as i32)
618 .fetch_all(&self.pool)
619 .await?)
620 })
621 }
622
623 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
624 test_support!(self, {
625 let query = "
626 SELECT users.*
627 FROM users
628 WHERE id = $1
629 LIMIT 1
630 ";
631 Ok(sqlx::query_as(query)
632 .bind(&id)
633 .fetch_optional(&self.pool)
634 .await?)
635 })
636 }
637
638 pub async fn get_users_with_no_invites(
639 &self,
640 invited_by_another_user: bool,
641 ) -> Result<Vec<User>> {
642 test_support!(self, {
643 let query = format!(
644 "
645 SELECT users.*
646 FROM users
647 WHERE invite_count = 0
648 AND inviter_id IS{} NULL
649 ",
650 if invited_by_another_user { " NOT" } else { "" }
651 );
652
653 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
654 })
655 }
656
657 pub async fn get_user_by_github_account(
658 &self,
659 github_login: &str,
660 github_user_id: Option<i32>,
661 ) -> Result<Option<User>> {
662 test_support!(self, {
663 if let Some(github_user_id) = github_user_id {
664 let mut user = sqlx::query_as::<_, User>(
665 "
666 UPDATE users
667 SET github_login = $1
668 WHERE github_user_id = $2
669 RETURNING *
670 ",
671 )
672 .bind(github_login)
673 .bind(github_user_id)
674 .fetch_optional(&self.pool)
675 .await?;
676
677 if user.is_none() {
678 user = sqlx::query_as::<_, User>(
679 "
680 UPDATE users
681 SET github_user_id = $1
682 WHERE github_login = $2
683 RETURNING *
684 ",
685 )
686 .bind(github_user_id)
687 .bind(github_login)
688 .fetch_optional(&self.pool)
689 .await?;
690 }
691
692 Ok(user)
693 } else {
694 let user = sqlx::query_as(
695 "
696 SELECT * FROM users
697 WHERE github_login = $1
698 LIMIT 1
699 ",
700 )
701 .bind(github_login)
702 .fetch_optional(&self.pool)
703 .await?;
704 Ok(user)
705 }
706 })
707 }
708
709 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
710 test_support!(self, {
711 let query = "UPDATE users SET admin = $1 WHERE id = $2";
712 Ok(sqlx::query(query)
713 .bind(is_admin)
714 .bind(id.0)
715 .execute(&self.pool)
716 .await
717 .map(drop)?)
718 })
719 }
720
721 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
722 test_support!(self, {
723 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
724 Ok(sqlx::query(query)
725 .bind(connected_once)
726 .bind(id.0)
727 .execute(&self.pool)
728 .await
729 .map(drop)?)
730 })
731 }
732
733 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
734 test_support!(self, {
735 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
736 sqlx::query(query)
737 .bind(id.0)
738 .execute(&self.pool)
739 .await
740 .map(drop)?;
741 let query = "DELETE FROM users WHERE id = $1;";
742 Ok(sqlx::query(query)
743 .bind(id.0)
744 .execute(&self.pool)
745 .await
746 .map(drop)?)
747 })
748 }
749
750 // signups
751
752 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
753 test_support!(self, {
754 Ok(sqlx::query_as(
755 "
756 SELECT
757 COUNT(*) as count,
758 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
759 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
760 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
761 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
762 FROM (
763 SELECT *
764 FROM signups
765 WHERE
766 NOT email_confirmation_sent
767 ) AS unsent
768 ",
769 )
770 .fetch_one(&self.pool)
771 .await?)
772 })
773 }
774
775 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
776 test_support!(self, {
777 Ok(sqlx::query_as(
778 "
779 SELECT
780 email_address, email_confirmation_code
781 FROM signups
782 WHERE
783 NOT email_confirmation_sent AND
784 (platform_mac OR platform_unknown)
785 ORDER BY
786 created_at
787 LIMIT $1
788 ",
789 )
790 .bind(count as i32)
791 .fetch_all(&self.pool)
792 .await?)
793 })
794 }
795
796 // invite codes
797
798 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
799 test_support!(self, {
800 let mut tx = self.pool.begin().await?;
801 if count > 0 {
802 sqlx::query(
803 "
804 UPDATE users
805 SET invite_code = $1
806 WHERE id = $2 AND invite_code IS NULL
807 ",
808 )
809 .bind(random_invite_code())
810 .bind(id)
811 .execute(&mut tx)
812 .await?;
813 }
814
815 sqlx::query(
816 "
817 UPDATE users
818 SET invite_count = $1
819 WHERE id = $2
820 ",
821 )
822 .bind(count as i32)
823 .bind(id)
824 .execute(&mut tx)
825 .await?;
826 tx.commit().await?;
827 Ok(())
828 })
829 }
830
831 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
832 test_support!(self, {
833 let result: Option<(String, i32)> = sqlx::query_as(
834 "
835 SELECT invite_code, invite_count
836 FROM users
837 WHERE id = $1 AND invite_code IS NOT NULL
838 ",
839 )
840 .bind(id)
841 .fetch_optional(&self.pool)
842 .await?;
843 if let Some((code, count)) = result {
844 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
845 } else {
846 Ok(None)
847 }
848 })
849 }
850
851 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
852 test_support!(self, {
853 sqlx::query_as(
854 "
855 SELECT *
856 FROM users
857 WHERE invite_code = $1
858 ",
859 )
860 .bind(code)
861 .fetch_optional(&self.pool)
862 .await?
863 .ok_or_else(|| {
864 Error::Http(
865 StatusCode::NOT_FOUND,
866 "that invite code does not exist".to_string(),
867 )
868 })
869 })
870 }
871
872 // projects
873
874 /// Registers a new project for the given user.
875 pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
876 test_support!(self, {
877 Ok(sqlx::query_scalar(
878 "
879 INSERT INTO projects(host_user_id)
880 VALUES ($1)
881 RETURNING id
882 ",
883 )
884 .bind(host_user_id)
885 .fetch_one(&self.pool)
886 .await
887 .map(ProjectId)?)
888 })
889 }
890
891 /// Unregisters a project for the given project id.
892 pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
893 test_support!(self, {
894 sqlx::query(
895 "
896 UPDATE projects
897 SET unregistered = TRUE
898 WHERE id = $1
899 ",
900 )
901 .bind(project_id)
902 .execute(&self.pool)
903 .await?;
904 Ok(())
905 })
906 }
907
908 // contacts
909
910 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
911 test_support!(self, {
912 let query = "
913 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
914 FROM contacts
915 WHERE user_id_a = $1 OR user_id_b = $1;
916 ";
917
918 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
919 .bind(user_id)
920 .fetch(&self.pool);
921
922 let mut contacts = Vec::new();
923 while let Some(row) = rows.next().await {
924 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
925
926 if user_id_a == user_id {
927 if accepted {
928 contacts.push(Contact::Accepted {
929 user_id: user_id_b,
930 should_notify: should_notify && a_to_b,
931 });
932 } else if a_to_b {
933 contacts.push(Contact::Outgoing { user_id: user_id_b })
934 } else {
935 contacts.push(Contact::Incoming {
936 user_id: user_id_b,
937 should_notify,
938 });
939 }
940 } else if accepted {
941 contacts.push(Contact::Accepted {
942 user_id: user_id_a,
943 should_notify: should_notify && !a_to_b,
944 });
945 } else if a_to_b {
946 contacts.push(Contact::Incoming {
947 user_id: user_id_a,
948 should_notify,
949 });
950 } else {
951 contacts.push(Contact::Outgoing { user_id: user_id_a });
952 }
953 }
954
955 contacts.sort_unstable_by_key(|contact| contact.user_id());
956
957 Ok(contacts)
958 })
959 }
960
961 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
962 test_support!(self, {
963 let (id_a, id_b) = if user_id_1 < user_id_2 {
964 (user_id_1, user_id_2)
965 } else {
966 (user_id_2, user_id_1)
967 };
968
969 let query = "
970 SELECT 1 FROM contacts
971 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
972 LIMIT 1
973 ";
974 Ok(sqlx::query_scalar::<_, i32>(query)
975 .bind(id_a.0)
976 .bind(id_b.0)
977 .fetch_optional(&self.pool)
978 .await?
979 .is_some())
980 })
981 }
982
983 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
984 test_support!(self, {
985 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
986 (sender_id, receiver_id, true)
987 } else {
988 (receiver_id, sender_id, false)
989 };
990 let query = "
991 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
992 VALUES ($1, $2, $3, FALSE, TRUE)
993 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
994 SET
995 accepted = TRUE,
996 should_notify = FALSE
997 WHERE
998 NOT contacts.accepted AND
999 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1000 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1001 ";
1002 let result = sqlx::query(query)
1003 .bind(id_a.0)
1004 .bind(id_b.0)
1005 .bind(a_to_b)
1006 .execute(&self.pool)
1007 .await?;
1008
1009 if result.rows_affected() == 1 {
1010 Ok(())
1011 } else {
1012 Err(anyhow!("contact already requested"))?
1013 }
1014 })
1015 }
1016
1017 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1018 test_support!(self, {
1019 let (id_a, id_b) = if responder_id < requester_id {
1020 (responder_id, requester_id)
1021 } else {
1022 (requester_id, responder_id)
1023 };
1024 let query = "
1025 DELETE FROM contacts
1026 WHERE user_id_a = $1 AND user_id_b = $2;
1027 ";
1028 let result = sqlx::query(query)
1029 .bind(id_a.0)
1030 .bind(id_b.0)
1031 .execute(&self.pool)
1032 .await?;
1033
1034 if result.rows_affected() == 1 {
1035 Ok(())
1036 } else {
1037 Err(anyhow!("no such contact"))?
1038 }
1039 })
1040 }
1041
1042 pub async fn dismiss_contact_notification(
1043 &self,
1044 user_id: UserId,
1045 contact_user_id: UserId,
1046 ) -> Result<()> {
1047 test_support!(self, {
1048 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1049 (user_id, contact_user_id, true)
1050 } else {
1051 (contact_user_id, user_id, false)
1052 };
1053
1054 let query = "
1055 UPDATE contacts
1056 SET should_notify = FALSE
1057 WHERE
1058 user_id_a = $1 AND user_id_b = $2 AND
1059 (
1060 (a_to_b = $3 AND accepted) OR
1061 (a_to_b != $3 AND NOT accepted)
1062 );
1063 ";
1064
1065 let result = sqlx::query(query)
1066 .bind(id_a.0)
1067 .bind(id_b.0)
1068 .bind(a_to_b)
1069 .execute(&self.pool)
1070 .await?;
1071
1072 if result.rows_affected() == 0 {
1073 Err(anyhow!("no such contact request"))?;
1074 }
1075
1076 Ok(())
1077 })
1078 }
1079
1080 pub async fn respond_to_contact_request(
1081 &self,
1082 responder_id: UserId,
1083 requester_id: UserId,
1084 accept: bool,
1085 ) -> Result<()> {
1086 test_support!(self, {
1087 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1088 (responder_id, requester_id, false)
1089 } else {
1090 (requester_id, responder_id, true)
1091 };
1092 let result = if accept {
1093 let query = "
1094 UPDATE contacts
1095 SET accepted = TRUE, should_notify = TRUE
1096 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1097 ";
1098 sqlx::query(query)
1099 .bind(id_a.0)
1100 .bind(id_b.0)
1101 .bind(a_to_b)
1102 .execute(&self.pool)
1103 .await?
1104 } else {
1105 let query = "
1106 DELETE FROM contacts
1107 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1108 ";
1109 sqlx::query(query)
1110 .bind(id_a.0)
1111 .bind(id_b.0)
1112 .bind(a_to_b)
1113 .execute(&self.pool)
1114 .await?
1115 };
1116 if result.rows_affected() == 1 {
1117 Ok(())
1118 } else {
1119 Err(anyhow!("no such contact request"))?
1120 }
1121 })
1122 }
1123
1124 // access tokens
1125
1126 pub async fn create_access_token_hash(
1127 &self,
1128 user_id: UserId,
1129 access_token_hash: &str,
1130 max_access_token_count: usize,
1131 ) -> Result<()> {
1132 test_support!(self, {
1133 let insert_query = "
1134 INSERT INTO access_tokens (user_id, hash)
1135 VALUES ($1, $2);
1136 ";
1137 let cleanup_query = "
1138 DELETE FROM access_tokens
1139 WHERE id IN (
1140 SELECT id from access_tokens
1141 WHERE user_id = $1
1142 ORDER BY id DESC
1143 LIMIT 10000
1144 OFFSET $3
1145 )
1146 ";
1147
1148 let mut tx = self.pool.begin().await?;
1149 sqlx::query(insert_query)
1150 .bind(user_id.0)
1151 .bind(access_token_hash)
1152 .execute(&mut tx)
1153 .await?;
1154 sqlx::query(cleanup_query)
1155 .bind(user_id.0)
1156 .bind(access_token_hash)
1157 .bind(max_access_token_count as i32)
1158 .execute(&mut tx)
1159 .await?;
1160 Ok(tx.commit().await?)
1161 })
1162 }
1163
1164 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1165 test_support!(self, {
1166 let query = "
1167 SELECT hash
1168 FROM access_tokens
1169 WHERE user_id = $1
1170 ORDER BY id DESC
1171 ";
1172 Ok(sqlx::query_scalar(query)
1173 .bind(user_id.0)
1174 .fetch_all(&self.pool)
1175 .await?)
1176 })
1177 }
1178}
1179
/// Defines `$name` as a `Copy` newtype over `i32` that is transparent to both
/// sqlx (`#[sqlx(transparent)]`) and serde (`#[serde(transparent)]`), with
/// proto conversions and a `Display` impl that prints the inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Converts a proto `u64` id. The `as` cast keeps only the low 32 bits,
            /// so values outside the `i32` range do not round-trip.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts to a proto `u64` id. The `as` cast sign-extends, so a
            /// negative id becomes a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
1222
id_type!(UserId);
/// A row of the `users` table (decoded by column name via `FromRow`).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    /// Code others can use to redeem this user's invites, if one was assigned.
    pub invite_code: Option<String>,
    /// Number of invites this user has left.
    pub invite_count: i32,
    pub connected_once: bool,
}
1235
id_type!(ProjectId);
/// A row of the `projects` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    pub host_user_id: UserId,
    /// Set by `unregister_project`; the row itself is not deleted.
    pub unregistered: bool,
}
1243
/// A contact relationship as seen from one user's perspective; `user_id` is
/// always the *other* user.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both users have agreed to be contacts.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A pending request this user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request this user received from `user_id`.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1258
1259impl Contact {
1260 pub fn user_id(&self) -> UserId {
1261 match self {
1262 Contact::Accepted { user_id, .. } => *user_id,
1263 Contact::Outgoing { user_id } => *user_id,
1264 Contact::Incoming { user_id, .. } => *user_id,
1265 }
1266 }
1267}
1268
/// A pending contact request received from `requester_id`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
1274
/// A waitlist signup as submitted by a prospective user; inserted by
/// `create_signup`.
#[derive(Clone, Deserialize, Default)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
}
1286
/// Counts of signups whose confirmation email has not been sent, in total and
/// per platform; produced by `get_waitlist_summary`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1300
/// An invitation identified by the invitee's email address and the signup's
/// email confirmation code.
#[derive(Clone, FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1306
/// GitHub identity and initial invite allowance for a user being created.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    /// Only used by `create_user_from_invite`; `create_user` ignores it.
    pub invite_count: i32,
}
1313
/// Outcome of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// The user whose invite was redeemed, if any.
    pub inviting_user_id: Option<UserId>,
    /// The `device_id` recorded on the redeemed signup, if any.
    pub signup_device_id: Option<String>,
}
1321
1322fn random_invite_code() -> String {
1323 nanoid::nanoid!(16)
1324}
1325
1326fn random_email_confirmation_code() -> String {
1327 nanoid::nanoid!(64)
1328}
1329
// Test builds only: re-export the test database helpers (`SqliteTestDb`,
// `PostgresTestDb`) from this module's namespace.
#[cfg(test)]
pub use test::*;
1332
/// Per-test database fixtures.
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A migrated, in-memory SQLite database with a randomly generated name.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// A connection detached from the pool and owned for the fixture's
        /// lifetime. NOTE(review): presumably this keeps the shared in-memory
        /// database alive; confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A freshly created and migrated Postgres database on localhost with a
    /// randomly generated name; torn down when this value is dropped.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        /// Connection URL of the created database, used for teardown.
        pub url: String,
    }

    impl SqliteTestDb {
        /// Creates and migrates a new in-memory database. Its queries run on a
        /// dedicated current-thread tokio runtime, with random delays injected
        /// through `background`.
        ///
        /// # Panics
        ///
        /// Panics if the runtime, connection or migrations fail.
        pub fn new(background: Arc<Background>) -> Self {
            let mut rng = StdRng::from_entropy();
            let url = format!("file:zed-test-{}?mode=memory", rng.gen::<u128>());
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(migrations_path.as_ref(), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            // Installed only after setup, so setup itself ran with no delays and
            // without `test_support!` re-entering this runtime.
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the database.
        ///
        /// # Panics
        ///
        /// Panics if `db` has been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates and migrates a new database on the local Postgres server. Its
        /// queries run on a dedicated current-thread tokio runtime, with random
        /// delays injected through `background`.
        ///
        /// # Panics
        ///
        /// Panics if the database cannot be created, connected to or migrated.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Held until this function returns, so fixture creation is serialized
            // across concurrently running tests.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            // Installed only after setup, so setup itself ran with no delays and
            // without `test_support!` re-entering this runtime.
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the database.
        ///
        /// # Panics
        ///
        /// Panics if `db` has been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        /// Tears down the test database via `Db::teardown`.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            db.teardown(&self.url);
        }
    }
}