1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::HashMap;
5use futures::StreamExt;
6use serde::{Deserialize, Serialize};
7use sqlx::{
8 migrate::{Migrate as _, Migration, MigrationSource},
9 types::Uuid,
10 FromRow,
11};
12use std::{path::Path, time::Duration};
13use time::{OffsetDateTime, PrimitiveDateTime};
14
/// Concrete database type selected by build configuration: SQLite in test builds.
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// Concrete database type selected by build configuration: Postgres in non-test builds.
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;
20
/// A connection pool for database backend `D`, plus test-only execution hooks.
///
/// In test builds, `test_support!` awaits a simulated random delay from `background`
/// (when set) and then drives each query future to completion on `runtime`.
pub struct Db<D: sqlx::Database> {
    /// The sqlx connection pool every query runs against.
    pool: sqlx::Pool<D>,
    /// Test-only executor used to inject random delays before queries.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only Tokio runtime that query futures are blocked on.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
28
/// Runs a braced query body with build-dependent execution semantics.
///
/// `$self` must name a `Db`; the braced tokens are wrapped in an `async` block.
///
/// - Non-test builds: the async block is simply awaited.
/// - Test builds: if `$self.background` is set, a simulated random delay is awaited
///   first. The async block is then driven to completion with `block_on` on
///   `$self.runtime`, which panics if `runtime` is `None`.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };

        if cfg!(test) {
            // `cfg!(test)` is false in non-test builds, so this branch never runs there;
            // the test-only statements below are compiled out of such builds.
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            // Drive the query on the dedicated Tokio runtime instead of the caller's executor.
            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(body)
        } else {
            body.await
        }
    }};
}
51
/// Abstracts over backend-specific query results so that backend-generic code can
/// read how many rows a statement affected.
pub trait RowsAffected {
    /// Returns the number of rows affected by the executed statement.
    fn rows_affected(&self) -> u64;
}
55
#[cfg(test)]
impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    /// Forwards to the inherent `SqliteQueryResult::rows_affected`.
    fn rows_affected(&self) -> u64 {
        // Path-qualified call: resolves to the inherent method, not this trait method.
        sqlx::sqlite::SqliteQueryResult::rows_affected(self)
    }
}
62
63impl RowsAffected for sqlx::postgres::PgQueryResult {
64 fn rows_affected(&self) -> u64 {
65 self.rows_affected()
66 }
67}
68
#[cfg(test)]
impl Db<sqlx::Sqlite> {
    /// Opens (creating it if missing) an SQLite database at `url` with a shared cache.
    ///
    /// The pool keeps at least two connections and at most `max_connections` open.
    ///
    /// # Errors
    /// Returns an error if `url` is not a valid SQLite connection string or if the
    /// pool fails to connect.
    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
        use std::str::FromStr as _;
        // Propagate a malformed URL as an error instead of panicking.
        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
            .create_if_missing(true)
            .shared_cache(true);
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(2)
            .max_connections(max_connections)
            .connect_with(options)
            .await?;
        Ok(Self {
            pool,
            background: None,
            runtime: None,
        })
    }

    /// Fetches every user whose id is in `ids` (order unspecified).
    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
        test_support!(self, {
            // SQLite has no array binds, so the ids are passed as a JSON array.
            let query = "
                SELECT users.*
                FROM users
                WHERE users.id IN (SELECT value from json_each($1))
            ";
            Ok(sqlx::query_as(query)
                .bind(&serde_json::json!(ids))
                .fetch_all(&self.pool)
                .await?)
        })
    }

    /// Returns the stored `metrics_id` of user `id`; errors if the user does not exist.
    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
        test_support!(self, {
            let query = "
                SELECT metrics_id
                FROM users
                WHERE id = $1
            ";
            Ok(sqlx::query_scalar(query)
                .bind(id)
                .fetch_one(&self.pool)
                .await?)
        })
    }

    /// Inserts a user, or returns the existing row if `github_login` is already taken.
    ///
    /// A fresh random UUID string is supplied as `metrics_id` for new rows.
    pub async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult> {
        test_support!(self, {
            let query = "
                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id, metrics_id
            ";

            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
                .bind(email_address)
                .bind(params.github_login)
                .bind(params.github_user_id)
                .bind(admin)
                .bind(Uuid::new_v4().to_string())
                .fetch_one(&self.pool)
                .await?;
            Ok(NewUserResult {
                user_id,
                metrics_id,
                signup_device_id: None,
                inviting_user_id: None,
            })
        })
    }

    /// Not implemented for SQLite; panics if called.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }

    /// Not implemented for SQLite; panics if called.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }

    /// Not implemented for SQLite; panics if called.
    pub async fn create_signup(&self, _signup: &Signup) -> Result<()> {
        unimplemented!()
    }

    /// Not implemented for SQLite; panics if called.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }

    /// Not implemented for SQLite; panics if called.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
}
177
178impl Db<sqlx::Postgres> {
179 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
180 let pool = sqlx::postgres::PgPoolOptions::new()
181 .max_connections(max_connections)
182 .connect(url)
183 .await?;
184 Ok(Self {
185 pool,
186 #[cfg(test)]
187 background: None,
188 #[cfg(test)]
189 runtime: None,
190 })
191 }
192
193 #[cfg(test)]
194 pub fn teardown(&self, url: &str) {
195 self.runtime.as_ref().unwrap().block_on(async {
196 use util::ResultExt;
197 let query = "
198 SELECT pg_terminate_backend(pg_stat_activity.pid)
199 FROM pg_stat_activity
200 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
201 ";
202 sqlx::query(query).execute(&self.pool).await.log_err();
203 self.pool.close().await;
204 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
205 .await
206 .log_err();
207 })
208 }
209
210 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
211 test_support!(self, {
212 let like_string = Self::fuzzy_like_string(name_query);
213 let query = "
214 SELECT users.*
215 FROM users
216 WHERE github_login ILIKE $1
217 ORDER BY github_login <-> $2
218 LIMIT $3
219 ";
220 Ok(sqlx::query_as(query)
221 .bind(like_string)
222 .bind(name_query)
223 .bind(limit as i32)
224 .fetch_all(&self.pool)
225 .await?)
226 })
227 }
228
229 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
230 test_support!(self, {
231 let query = "
232 SELECT users.*
233 FROM users
234 WHERE users.id = ANY ($1)
235 ";
236 Ok(sqlx::query_as(query)
237 .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
238 .fetch_all(&self.pool)
239 .await?)
240 })
241 }
242
243 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
244 test_support!(self, {
245 let query = "
246 SELECT metrics_id::text
247 FROM users
248 WHERE id = $1
249 ";
250 Ok(sqlx::query_scalar(query)
251 .bind(id)
252 .fetch_one(&self.pool)
253 .await?)
254 })
255 }
256
257 pub async fn create_user(
258 &self,
259 email_address: &str,
260 admin: bool,
261 params: NewUserParams,
262 ) -> Result<NewUserResult> {
263 test_support!(self, {
264 let query = "
265 INSERT INTO users (email_address, github_login, github_user_id, admin)
266 VALUES ($1, $2, $3, $4)
267 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
268 RETURNING id, metrics_id::text
269 ";
270
271 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
272 .bind(email_address)
273 .bind(params.github_login)
274 .bind(params.github_user_id)
275 .bind(admin)
276 .fetch_one(&self.pool)
277 .await?;
278 Ok(NewUserResult {
279 user_id,
280 metrics_id,
281 signup_device_id: None,
282 inviting_user_id: None,
283 })
284 })
285 }
286
287 pub async fn create_user_from_invite(
288 &self,
289 invite: &Invite,
290 user: NewUserParams,
291 ) -> Result<Option<NewUserResult>> {
292 test_support!(self, {
293 let mut tx = self.pool.begin().await?;
294
295 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
296 i32,
297 Option<UserId>,
298 Option<UserId>,
299 Option<String>,
300 ) = sqlx::query_as(
301 "
302 SELECT id, user_id, inviting_user_id, device_id
303 FROM signups
304 WHERE
305 email_address = $1 AND
306 email_confirmation_code = $2
307 ",
308 )
309 .bind(&invite.email_address)
310 .bind(&invite.email_confirmation_code)
311 .fetch_optional(&mut tx)
312 .await?
313 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
314
315 if existing_user_id.is_some() {
316 return Ok(None);
317 }
318
319 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
320 "
321 INSERT INTO users
322 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
323 VALUES
324 ($1, $2, $3, FALSE, $4, $5)
325 ON CONFLICT (github_login) DO UPDATE SET
326 email_address = excluded.email_address,
327 github_user_id = excluded.github_user_id,
328 admin = excluded.admin
329 RETURNING id, metrics_id::text
330 ",
331 )
332 .bind(&invite.email_address)
333 .bind(&user.github_login)
334 .bind(&user.github_user_id)
335 .bind(&user.invite_count)
336 .bind(random_invite_code())
337 .fetch_one(&mut tx)
338 .await?;
339
340 sqlx::query(
341 "
342 UPDATE signups
343 SET user_id = $1
344 WHERE id = $2
345 ",
346 )
347 .bind(&user_id)
348 .bind(&signup_id)
349 .execute(&mut tx)
350 .await?;
351
352 if let Some(inviting_user_id) = inviting_user_id {
353 sqlx::query(
354 "
355 INSERT INTO contacts
356 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
357 VALUES
358 ($1, $2, TRUE, TRUE, TRUE)
359 ON CONFLICT DO NOTHING
360 ",
361 )
362 .bind(inviting_user_id)
363 .bind(user_id)
364 .execute(&mut tx)
365 .await?;
366 }
367
368 tx.commit().await?;
369 Ok(Some(NewUserResult {
370 user_id,
371 metrics_id,
372 inviting_user_id,
373 signup_device_id,
374 }))
375 })
376 }
377
378 pub async fn create_signup(&self, signup: &Signup) -> Result<()> {
379 test_support!(self, {
380 sqlx::query(
381 "
382 INSERT INTO signups
383 (
384 email_address,
385 email_confirmation_code,
386 email_confirmation_sent,
387 platform_linux,
388 platform_mac,
389 platform_windows,
390 platform_unknown,
391 editor_features,
392 programming_languages,
393 device_id,
394 added_to_mailing_list
395 )
396 VALUES
397 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8, $9)
398 ON CONFLICT (email_address) DO UPDATE SET
399 email_address = excluded.email_address
400 RETURNING id
401 ",
402 )
403 .bind(&signup.email_address)
404 .bind(&random_email_confirmation_code())
405 .bind(&signup.platform_linux)
406 .bind(&signup.platform_mac)
407 .bind(&signup.platform_windows)
408 .bind(&signup.editor_features)
409 .bind(&signup.programming_languages)
410 .bind(&signup.device_id)
411 .bind(&signup.added_to_mailing_list)
412 .execute(&self.pool)
413 .await?;
414 Ok(())
415 })
416 }
417
418 pub async fn create_invite_from_code(
419 &self,
420 code: &str,
421 email_address: &str,
422 device_id: Option<&str>,
423 ) -> Result<Invite> {
424 test_support!(self, {
425 let mut tx = self.pool.begin().await?;
426
427 let existing_user: Option<UserId> = sqlx::query_scalar(
428 "
429 SELECT id
430 FROM users
431 WHERE email_address = $1
432 ",
433 )
434 .bind(email_address)
435 .fetch_optional(&mut tx)
436 .await?;
437 if existing_user.is_some() {
438 Err(anyhow!("email address is already in use"))?;
439 }
440
441 let inviting_user_id_with_invites: Option<UserId> = sqlx::query_scalar(
442 "
443 UPDATE users
444 SET invite_count = invite_count - 1
445 WHERE invite_code = $1 AND invite_count > 0
446 RETURNING id
447 ",
448 )
449 .bind(code)
450 .fetch_optional(&mut tx)
451 .await?;
452
453 let Some(inviter_id) = inviting_user_id_with_invites else {
454 return Err(Error::Http(
455 StatusCode::UNAUTHORIZED,
456 "unable to find an invite code with invites remaining".to_string(),
457 ));
458 };
459
460 let email_confirmation_code: String = sqlx::query_scalar(
461 "
462 INSERT INTO signups
463 (
464 email_address,
465 email_confirmation_code,
466 email_confirmation_sent,
467 inviting_user_id,
468 platform_linux,
469 platform_mac,
470 platform_windows,
471 platform_unknown,
472 device_id
473 )
474 VALUES
475 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
476 ON CONFLICT (email_address)
477 DO UPDATE SET
478 inviting_user_id = excluded.inviting_user_id
479 RETURNING email_confirmation_code
480 ",
481 )
482 .bind(&email_address)
483 .bind(&random_email_confirmation_code())
484 .bind(&inviter_id)
485 .bind(&device_id)
486 .fetch_one(&mut tx)
487 .await?;
488
489 tx.commit().await?;
490
491 Ok(Invite {
492 email_address: email_address.into(),
493 email_confirmation_code,
494 })
495 })
496 }
497
498 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
499 test_support!(self, {
500 let emails = invites
501 .iter()
502 .map(|s| s.email_address.as_str())
503 .collect::<Vec<_>>();
504 sqlx::query(
505 "
506 UPDATE signups
507 SET email_confirmation_sent = TRUE
508 WHERE email_address = ANY ($1)
509 ",
510 )
511 .bind(&emails)
512 .execute(&self.pool)
513 .await?;
514 Ok(())
515 })
516 }
517}
518
519impl<D> Db<D>
520where
521 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
522 D::Connection: sqlx::migrate::Migrate,
523 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
524 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
525 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
526 D::QueryResult: RowsAffected,
527 String: sqlx::Type<D>,
528 i32: sqlx::Type<D>,
529 i64: sqlx::Type<D>,
530 bool: sqlx::Type<D>,
531 str: sqlx::Type<D>,
532 Uuid: sqlx::Type<D>,
533 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
534 OffsetDateTime: sqlx::Type<D>,
535 PrimitiveDateTime: sqlx::Type<D>,
536 usize: sqlx::ColumnIndex<D::Row>,
537 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
538 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
539 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
540 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
541 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
542 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
543 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
544 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
545 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
546 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
547 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
548 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
549{
550 pub async fn migrate(
551 &self,
552 migrations_path: &Path,
553 ignore_checksum_mismatch: bool,
554 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
555 let migrations = MigrationSource::resolve(migrations_path)
556 .await
557 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
558
559 let mut conn = self.pool.acquire().await?;
560
561 conn.ensure_migrations_table().await?;
562 let applied_migrations: HashMap<_, _> = conn
563 .list_applied_migrations()
564 .await?
565 .into_iter()
566 .map(|m| (m.version, m))
567 .collect();
568
569 let mut new_migrations = Vec::new();
570 for migration in migrations {
571 match applied_migrations.get(&migration.version) {
572 Some(applied_migration) => {
573 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
574 {
575 Err(anyhow!(
576 "checksum mismatch for applied migration {}",
577 migration.description
578 ))?;
579 }
580 }
581 None => {
582 let elapsed = conn.apply(&migration).await?;
583 new_migrations.push((migration, elapsed));
584 }
585 }
586 }
587
588 Ok(new_migrations)
589 }
590
591 pub fn fuzzy_like_string(string: &str) -> String {
592 let mut result = String::with_capacity(string.len() * 2 + 1);
593 for c in string.chars() {
594 if c.is_alphanumeric() {
595 result.push('%');
596 result.push(c);
597 }
598 }
599 result.push('%');
600 result
601 }
602
603 // users
604
605 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
606 test_support!(self, {
607 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
608 Ok(sqlx::query_as(query)
609 .bind(limit as i32)
610 .bind((page * limit) as i32)
611 .fetch_all(&self.pool)
612 .await?)
613 })
614 }
615
616 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
617 test_support!(self, {
618 let query = "
619 SELECT users.*
620 FROM users
621 WHERE id = $1
622 LIMIT 1
623 ";
624 Ok(sqlx::query_as(query)
625 .bind(&id)
626 .fetch_optional(&self.pool)
627 .await?)
628 })
629 }
630
631 pub async fn get_users_with_no_invites(
632 &self,
633 invited_by_another_user: bool,
634 ) -> Result<Vec<User>> {
635 test_support!(self, {
636 let query = format!(
637 "
638 SELECT users.*
639 FROM users
640 WHERE invite_count = 0
641 AND inviter_id IS{} NULL
642 ",
643 if invited_by_another_user { " NOT" } else { "" }
644 );
645
646 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
647 })
648 }
649
650 pub async fn get_user_by_github_account(
651 &self,
652 github_login: &str,
653 github_user_id: Option<i32>,
654 ) -> Result<Option<User>> {
655 test_support!(self, {
656 if let Some(github_user_id) = github_user_id {
657 let mut user = sqlx::query_as::<_, User>(
658 "
659 UPDATE users
660 SET github_login = $1
661 WHERE github_user_id = $2
662 RETURNING *
663 ",
664 )
665 .bind(github_login)
666 .bind(github_user_id)
667 .fetch_optional(&self.pool)
668 .await?;
669
670 if user.is_none() {
671 user = sqlx::query_as::<_, User>(
672 "
673 UPDATE users
674 SET github_user_id = $1
675 WHERE github_login = $2
676 RETURNING *
677 ",
678 )
679 .bind(github_user_id)
680 .bind(github_login)
681 .fetch_optional(&self.pool)
682 .await?;
683 }
684
685 Ok(user)
686 } else {
687 let user = sqlx::query_as(
688 "
689 SELECT * FROM users
690 WHERE github_login = $1
691 LIMIT 1
692 ",
693 )
694 .bind(github_login)
695 .fetch_optional(&self.pool)
696 .await?;
697 Ok(user)
698 }
699 })
700 }
701
702 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
703 test_support!(self, {
704 let query = "UPDATE users SET admin = $1 WHERE id = $2";
705 Ok(sqlx::query(query)
706 .bind(is_admin)
707 .bind(id.0)
708 .execute(&self.pool)
709 .await
710 .map(drop)?)
711 })
712 }
713
714 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
715 test_support!(self, {
716 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
717 Ok(sqlx::query(query)
718 .bind(connected_once)
719 .bind(id.0)
720 .execute(&self.pool)
721 .await
722 .map(drop)?)
723 })
724 }
725
726 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
727 test_support!(self, {
728 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
729 sqlx::query(query)
730 .bind(id.0)
731 .execute(&self.pool)
732 .await
733 .map(drop)?;
734 let query = "DELETE FROM users WHERE id = $1;";
735 Ok(sqlx::query(query)
736 .bind(id.0)
737 .execute(&self.pool)
738 .await
739 .map(drop)?)
740 })
741 }
742
743 // signups
744
745 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
746 test_support!(self, {
747 Ok(sqlx::query_as(
748 "
749 SELECT
750 COUNT(*) as count,
751 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
752 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
753 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
754 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
755 FROM (
756 SELECT *
757 FROM signups
758 WHERE
759 NOT email_confirmation_sent
760 ) AS unsent
761 ",
762 )
763 .fetch_one(&self.pool)
764 .await?)
765 })
766 }
767
768 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
769 test_support!(self, {
770 Ok(sqlx::query_as(
771 "
772 SELECT
773 email_address, email_confirmation_code
774 FROM signups
775 WHERE
776 NOT email_confirmation_sent AND
777 (platform_mac OR platform_unknown)
778 ORDER BY
779 created_at
780 LIMIT $1
781 ",
782 )
783 .bind(count as i32)
784 .fetch_all(&self.pool)
785 .await?)
786 })
787 }
788
789 // invite codes
790
791 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
792 test_support!(self, {
793 let mut tx = self.pool.begin().await?;
794 if count > 0 {
795 sqlx::query(
796 "
797 UPDATE users
798 SET invite_code = $1
799 WHERE id = $2 AND invite_code IS NULL
800 ",
801 )
802 .bind(random_invite_code())
803 .bind(id)
804 .execute(&mut tx)
805 .await?;
806 }
807
808 sqlx::query(
809 "
810 UPDATE users
811 SET invite_count = $1
812 WHERE id = $2
813 ",
814 )
815 .bind(count as i32)
816 .bind(id)
817 .execute(&mut tx)
818 .await?;
819 tx.commit().await?;
820 Ok(())
821 })
822 }
823
824 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
825 test_support!(self, {
826 let result: Option<(String, i32)> = sqlx::query_as(
827 "
828 SELECT invite_code, invite_count
829 FROM users
830 WHERE id = $1 AND invite_code IS NOT NULL
831 ",
832 )
833 .bind(id)
834 .fetch_optional(&self.pool)
835 .await?;
836 if let Some((code, count)) = result {
837 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
838 } else {
839 Ok(None)
840 }
841 })
842 }
843
844 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
845 test_support!(self, {
846 sqlx::query_as(
847 "
848 SELECT *
849 FROM users
850 WHERE invite_code = $1
851 ",
852 )
853 .bind(code)
854 .fetch_optional(&self.pool)
855 .await?
856 .ok_or_else(|| {
857 Error::Http(
858 StatusCode::NOT_FOUND,
859 "that invite code does not exist".to_string(),
860 )
861 })
862 })
863 }
864
865 // projects
866
867 /// Registers a new project for the given user.
868 pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
869 test_support!(self, {
870 Ok(sqlx::query_scalar(
871 "
872 INSERT INTO projects(host_user_id)
873 VALUES ($1)
874 RETURNING id
875 ",
876 )
877 .bind(host_user_id)
878 .fetch_one(&self.pool)
879 .await
880 .map(ProjectId)?)
881 })
882 }
883
884 /// Unregisters a project for the given project id.
885 pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
886 test_support!(self, {
887 sqlx::query(
888 "
889 UPDATE projects
890 SET unregistered = TRUE
891 WHERE id = $1
892 ",
893 )
894 .bind(project_id)
895 .execute(&self.pool)
896 .await?;
897 Ok(())
898 })
899 }
900
901 // contacts
902
903 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
904 test_support!(self, {
905 let query = "
906 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
907 FROM contacts
908 WHERE user_id_a = $1 OR user_id_b = $1;
909 ";
910
911 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
912 .bind(user_id)
913 .fetch(&self.pool);
914
915 let mut contacts = Vec::new();
916 while let Some(row) = rows.next().await {
917 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
918
919 if user_id_a == user_id {
920 if accepted {
921 contacts.push(Contact::Accepted {
922 user_id: user_id_b,
923 should_notify: should_notify && a_to_b,
924 });
925 } else if a_to_b {
926 contacts.push(Contact::Outgoing { user_id: user_id_b })
927 } else {
928 contacts.push(Contact::Incoming {
929 user_id: user_id_b,
930 should_notify,
931 });
932 }
933 } else if accepted {
934 contacts.push(Contact::Accepted {
935 user_id: user_id_a,
936 should_notify: should_notify && !a_to_b,
937 });
938 } else if a_to_b {
939 contacts.push(Contact::Incoming {
940 user_id: user_id_a,
941 should_notify,
942 });
943 } else {
944 contacts.push(Contact::Outgoing { user_id: user_id_a });
945 }
946 }
947
948 contacts.sort_unstable_by_key(|contact| contact.user_id());
949
950 Ok(contacts)
951 })
952 }
953
954 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
955 test_support!(self, {
956 let (id_a, id_b) = if user_id_1 < user_id_2 {
957 (user_id_1, user_id_2)
958 } else {
959 (user_id_2, user_id_1)
960 };
961
962 let query = "
963 SELECT 1 FROM contacts
964 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
965 LIMIT 1
966 ";
967 Ok(sqlx::query_scalar::<_, i32>(query)
968 .bind(id_a.0)
969 .bind(id_b.0)
970 .fetch_optional(&self.pool)
971 .await?
972 .is_some())
973 })
974 }
975
976 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
977 test_support!(self, {
978 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
979 (sender_id, receiver_id, true)
980 } else {
981 (receiver_id, sender_id, false)
982 };
983 let query = "
984 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
985 VALUES ($1, $2, $3, FALSE, TRUE)
986 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
987 SET
988 accepted = TRUE,
989 should_notify = FALSE
990 WHERE
991 NOT contacts.accepted AND
992 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
993 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
994 ";
995 let result = sqlx::query(query)
996 .bind(id_a.0)
997 .bind(id_b.0)
998 .bind(a_to_b)
999 .execute(&self.pool)
1000 .await?;
1001
1002 if result.rows_affected() == 1 {
1003 Ok(())
1004 } else {
1005 Err(anyhow!("contact already requested"))?
1006 }
1007 })
1008 }
1009
1010 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1011 test_support!(self, {
1012 let (id_a, id_b) = if responder_id < requester_id {
1013 (responder_id, requester_id)
1014 } else {
1015 (requester_id, responder_id)
1016 };
1017 let query = "
1018 DELETE FROM contacts
1019 WHERE user_id_a = $1 AND user_id_b = $2;
1020 ";
1021 let result = sqlx::query(query)
1022 .bind(id_a.0)
1023 .bind(id_b.0)
1024 .execute(&self.pool)
1025 .await?;
1026
1027 if result.rows_affected() == 1 {
1028 Ok(())
1029 } else {
1030 Err(anyhow!("no such contact"))?
1031 }
1032 })
1033 }
1034
1035 pub async fn dismiss_contact_notification(
1036 &self,
1037 user_id: UserId,
1038 contact_user_id: UserId,
1039 ) -> Result<()> {
1040 test_support!(self, {
1041 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1042 (user_id, contact_user_id, true)
1043 } else {
1044 (contact_user_id, user_id, false)
1045 };
1046
1047 let query = "
1048 UPDATE contacts
1049 SET should_notify = FALSE
1050 WHERE
1051 user_id_a = $1 AND user_id_b = $2 AND
1052 (
1053 (a_to_b = $3 AND accepted) OR
1054 (a_to_b != $3 AND NOT accepted)
1055 );
1056 ";
1057
1058 let result = sqlx::query(query)
1059 .bind(id_a.0)
1060 .bind(id_b.0)
1061 .bind(a_to_b)
1062 .execute(&self.pool)
1063 .await?;
1064
1065 if result.rows_affected() == 0 {
1066 Err(anyhow!("no such contact request"))?;
1067 }
1068
1069 Ok(())
1070 })
1071 }
1072
1073 pub async fn respond_to_contact_request(
1074 &self,
1075 responder_id: UserId,
1076 requester_id: UserId,
1077 accept: bool,
1078 ) -> Result<()> {
1079 test_support!(self, {
1080 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1081 (responder_id, requester_id, false)
1082 } else {
1083 (requester_id, responder_id, true)
1084 };
1085 let result = if accept {
1086 let query = "
1087 UPDATE contacts
1088 SET accepted = TRUE, should_notify = TRUE
1089 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1090 ";
1091 sqlx::query(query)
1092 .bind(id_a.0)
1093 .bind(id_b.0)
1094 .bind(a_to_b)
1095 .execute(&self.pool)
1096 .await?
1097 } else {
1098 let query = "
1099 DELETE FROM contacts
1100 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1101 ";
1102 sqlx::query(query)
1103 .bind(id_a.0)
1104 .bind(id_b.0)
1105 .bind(a_to_b)
1106 .execute(&self.pool)
1107 .await?
1108 };
1109 if result.rows_affected() == 1 {
1110 Ok(())
1111 } else {
1112 Err(anyhow!("no such contact request"))?
1113 }
1114 })
1115 }
1116
1117 // access tokens
1118
1119 pub async fn create_access_token_hash(
1120 &self,
1121 user_id: UserId,
1122 access_token_hash: &str,
1123 max_access_token_count: usize,
1124 ) -> Result<()> {
1125 test_support!(self, {
1126 let insert_query = "
1127 INSERT INTO access_tokens (user_id, hash)
1128 VALUES ($1, $2);
1129 ";
1130 let cleanup_query = "
1131 DELETE FROM access_tokens
1132 WHERE id IN (
1133 SELECT id from access_tokens
1134 WHERE user_id = $1
1135 ORDER BY id DESC
1136 LIMIT 10000
1137 OFFSET $3
1138 )
1139 ";
1140
1141 let mut tx = self.pool.begin().await?;
1142 sqlx::query(insert_query)
1143 .bind(user_id.0)
1144 .bind(access_token_hash)
1145 .execute(&mut tx)
1146 .await?;
1147 sqlx::query(cleanup_query)
1148 .bind(user_id.0)
1149 .bind(access_token_hash)
1150 .bind(max_access_token_count as i32)
1151 .execute(&mut tx)
1152 .await?;
1153 Ok(tx.commit().await?)
1154 })
1155 }
1156
1157 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1158 test_support!(self, {
1159 let query = "
1160 SELECT hash
1161 FROM access_tokens
1162 WHERE user_id = $1
1163 ORDER BY id DESC
1164 ";
1165 Ok(sqlx::query_scalar(query)
1166 .bind(user_id.0)
1167 .fetch_all(&self.pool)
1168 .await?)
1169 })
1170 }
1171}
1172
/// Declares a newtype `$name(pub i32)` database id.
///
/// The type is transparent to both sqlx and serde and has `u64` protocol conversions.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash,
            sqlx::Type, Serialize, Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Converts a protocol `u64` id; the `as` cast truncates out-of-range values.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Converts this id to its protocol `u64` form.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                let $name(raw) = self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1215
id_type!(UserId);
/// A row of the `users` table, decoded via sqlx `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    /// Code other people can redeem to be invited by this user, if one was assigned.
    pub invite_code: Option<String>,
    /// Number of invites this user has left to hand out.
    pub invite_count: i32,
    pub connected_once: bool,
}
1228
id_type!(ProjectId);
/// A row of the `projects` table.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Project {
    pub id: ProjectId,
    /// The user who registered (hosts) the project.
    pub host_user_id: UserId,
    /// Set once the project has been unregistered.
    pub unregistered: bool,
}
1236
/// A contact relationship as seen from one user's point of view.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// Both users accepted the relationship.
    Accepted {
        user_id: UserId,
        should_notify: bool,
    },
    /// A pending request this user sent to `user_id`.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request `user_id` sent to this user.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
1251
1252impl Contact {
1253 pub fn user_id(&self) -> UserId {
1254 match self {
1255 Contact::Accepted { user_id, .. } => *user_id,
1256 Contact::Outgoing { user_id } => *user_id,
1257 Contact::Incoming { user_id, .. } => *user_id,
1258 }
1259 }
1260}
1261
/// A contact request received from `requester_id`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    /// Presumably whether the recipient should be notified — confirm with callers.
    pub should_notify: bool,
}
1267
/// A waitlist signup submission, as stored by `Db::create_signup`.
#[derive(Clone, Deserialize, Default)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
    pub added_to_mailing_list: bool,
}
1279
/// Counts of signups whose confirmation email has not been sent, total and per platform.
///
/// Each field defaults to 0 if its column is absent from the row.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
1293
/// An invite identified by the signup's email address and its confirmation code.
#[derive(Clone, FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
1299
/// GitHub identity and initial invite allotment for a user being created.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
1306
/// Outcome of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    /// The user's `metrics_id`, as text.
    pub metrics_id: String,
    /// The user who invited this one, when created from an invite.
    pub inviting_user_id: Option<UserId>,
    /// The device id recorded on the originating signup, if any.
    pub signup_device_id: Option<String>,
}
1314
/// Generates a random 16-character invite code using `nanoid`.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}
1318
/// Generates a random 64-character email confirmation code using `nanoid`.
fn random_email_confirmation_code() -> String {
    nanoid::nanoid!(64)
}
1322
// Re-export the test database helpers so other test code can use them from this module.
#[cfg(test)]
pub use test::*;
1325
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A uniquely named, migrated in-memory SQLite database for tests.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// A detached connection kept open alongside the pool.
        /// NOTE(review): presumably this keeps the shared-cache in-memory database
        /// alive — confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A uniquely named, migrated Postgres database for tests, torn down on drop.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        pub url: String,
    }

    /// Builds the current-thread Tokio runtime (IO + time enabled) that test databases
    /// drive their queries on.
    fn build_test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    impl SqliteTestDb {
        /// Creates and migrates a fresh in-memory database, installing `background` and
        /// a dedicated runtime on the resulting `Db`.
        pub fn new(background: Arc<Background>) -> Self {
            let mut entropy = StdRng::from_entropy();
            let db_url = format!("file:zed-test-{}?mode=memory", entropy.gen::<u128>());
            let rt = build_test_runtime();

            let (mut database, keepalive_conn) = rt.block_on(async {
                let database = Db::<sqlx::Sqlite>::new(&db_url, 5).await.unwrap();
                let migrations_dir = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                database
                    .migrate(Path::new(migrations_dir), false)
                    .await
                    .unwrap();
                let keepalive_conn = database.pool.acquire().await.unwrap().detach();
                (database, keepalive_conn)
            });

            database.background = Some(background);
            database.runtime = Some(rt);

            SqliteTestDb {
                db: Some(Arc::new(database)),
                conn: keepalive_conn,
            }
        }

        /// Returns the database under test.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db
                .as_ref()
                .unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates and migrates a fresh Postgres database, installing `background` and a
        /// dedicated runtime on the resulting `Db`.
        ///
        /// Creation is serialized across tests by a process-wide lock.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _creation_guard = LOCK.lock();
            let mut entropy = StdRng::from_entropy();
            let db_url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                entropy.gen::<u128>()
            );
            let rt = build_test_runtime();

            let mut database = rt.block_on(async {
                sqlx::Postgres::create_database(&db_url)
                    .await
                    .expect("failed to create test db");
                let database = Db::<sqlx::Postgres>::new(&db_url, 5).await.unwrap();
                let migrations_dir = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                database
                    .migrate(Path::new(migrations_dir), false)
                    .await
                    .unwrap();
                database
            });

            database.background = Some(background);
            database.runtime = Some(rt);

            PostgresTestDb {
                db: Some(Arc::new(database)),
                url: db_url,
            }
        }

        /// Returns the database under test.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db
                .as_ref()
                .unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        /// Tears the test database down via `Db::teardown`.
        fn drop(&mut self) {
            self.db
                .take()
                .unwrap()
                .teardown(&self.url);
        }
    }
}