1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::HashMap;
5use futures::StreamExt;
6use serde::{Deserialize, Serialize};
7use sqlx::{
8 migrate::{Migrate as _, Migration, MigrationSource},
9 types::Uuid,
10 FromRow,
11};
12use std::{path::Path, time::Duration};
13use time::{OffsetDateTime, PrimitiveDateTime};
14
15#[cfg(test)]
16pub type DefaultDb = Db<sqlx::Sqlite>;
17
18#[cfg(not(test))]
19pub type DefaultDb = Db<sqlx::Postgres>;
20
21pub struct Db<D: sqlx::Database> {
22 pool: sqlx::Pool<D>,
23 #[cfg(test)]
24 background: Option<std::sync::Arc<gpui::executor::Background>>,
25 #[cfg(test)]
26 runtime: Option<tokio::runtime::Runtime>,
27}
28
/// Runs the given block of async database code with build-dependent driving.
///
/// The body is wrapped in an `async` block. In test builds the macro first
/// awaits `simulate_random_delay()` on `$self.background` (when set) and then
/// drives the body to completion with `block_on` on `$self.runtime` (panicking
/// if no runtime is set). In non-test builds the body is simply awaited.
///
/// The macro evaluates to whatever the body evaluates to, so it is used as the
/// tail expression of the `Db` methods.
macro_rules! test_support {
    ($self:ident, { $($token:tt)* }) => {{
        let body = async {
            $($token)*
        };

        if cfg!(test) {
            // In non-test builds the `#[cfg(test)]` items below are stripped,
            // which would leave this branch evaluating to `()` and fail to
            // type-check against the `else` branch. Diverging here keeps both
            // branches type-compatible; `cfg!(test)` is false in those builds,
            // so this line is never reached.
            #[cfg(not(test))]
            unreachable!();

            #[cfg(test)]
            if let Some(background) = $self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            #[cfg(test)]
            $self.runtime.as_ref().unwrap().block_on(body)
        } else {
            body.await
        }
    }};
}
47
/// Uniform access to the number of rows touched by a query, so code generic
/// over the database backend can inspect `D::QueryResult`.
pub trait RowsAffected {
    /// Number of rows inserted, updated, or deleted by the executed statement.
    fn rows_affected(&self) -> u64;
}
51
52impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
53 fn rows_affected(&self) -> u64 {
54 self.rows_affected()
55 }
56}
57
58impl RowsAffected for sqlx::postgres::PgQueryResult {
59 fn rows_affected(&self) -> u64 {
60 self.rows_affected()
61 }
62}
63
/// SQLite-specific operations, only compiled for tests (SQLite backs
/// `DefaultDb` in test builds).
#[cfg(test)]
impl Db<sqlx::Sqlite> {
    /// Directory containing the SQLite migration scripts.
    const MIGRATIONS_PATH: &'static str = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");

    /// Opens the SQLite database at `url`, creating it if missing and enabling
    /// the shared cache, with a pool of 2 to `max_connections` connections.
    ///
    /// # Errors
    ///
    /// Returns an error if `url` is not a valid SQLite connection string or the
    /// pool cannot connect.
    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
        use std::str::FromStr as _;
        // Propagate a malformed URL as an error instead of panicking.
        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
            .create_if_missing(true)
            .shared_cache(true);
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(2)
            .max_connections(max_connections)
            .connect_with(options)
            .await?;
        Ok(Self {
            pool,
            background: None,
            runtime: None,
        })
    }

    /// No-op: the SQLite test backend needs no teardown. Present so SQLite and
    /// Postgres expose the same test API.
    pub fn teardown(&self, _url: &str) {}

    /// Returns the `metrics_id` stored for user `id`.
    ///
    /// # Errors
    ///
    /// Fails if no such user exists or the query fails.
    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
        test_support!(self, {
            let query = "
                SELECT metrics_id
                FROM users
                WHERE id = $1
            ";
            Ok(sqlx::query_scalar(query)
                .bind(id)
                .fetch_one(&self.pool)
                .await?)
        })
    }

    /// Inserts a user with a freshly generated UUID metrics id. If a user with
    /// the same `github_login` already exists, that row is kept and its id and
    /// metrics id are returned instead.
    pub async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult> {
        test_support!(self, {
            let query = "
                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id, metrics_id
            ";

            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
                .bind(email_address)
                .bind(params.github_login)
                .bind(params.github_user_id)
                .bind(admin)
                .bind(Uuid::new_v4().to_string())
                .fetch_one(&self.pool)
                .await?;
            Ok(NewUserResult {
                user_id,
                metrics_id,
                signup_device_id: None,
                inviting_user_id: None,
            })
        })
    }

    /// Not supported by the SQLite test backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }
}
159
160impl Db<sqlx::Postgres> {
161 const MIGRATIONS_PATH: &'static str = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
162
163 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
164 let pool = sqlx::postgres::PgPoolOptions::new()
165 .max_connections(max_connections)
166 .connect(url)
167 .await?;
168 Ok(Self {
169 pool,
170 #[cfg(test)]
171 background: None,
172 #[cfg(test)]
173 runtime: None,
174 })
175 }
176
177 #[cfg(test)]
178 pub fn teardown(&self, url: &str) {
179 self.runtime.as_ref().unwrap().block_on(async {
180 use util::ResultExt;
181 let query = "
182 SELECT pg_terminate_backend(pg_stat_activity.pid)
183 FROM pg_stat_activity
184 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
185 ";
186 sqlx::query(query).execute(&self.pool).await.log_err();
187 self.pool.close().await;
188 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
189 .await
190 .log_err();
191 })
192 }
193
194 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
195 test_support!(self, {
196 let like_string = Self::fuzzy_like_string(name_query);
197 let query = "
198 SELECT users.*
199 FROM users
200 WHERE github_login ILIKE $1
201 ORDER BY github_login <-> $2
202 LIMIT $3
203 ";
204 Ok(sqlx::query_as(query)
205 .bind(like_string)
206 .bind(name_query)
207 .bind(limit as i32)
208 .fetch_all(&self.pool)
209 .await?)
210 })
211 }
212
213 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
214 test_support!(self, {
215 let query = "
216 SELECT metrics_id::text
217 FROM users
218 WHERE id = $1
219 ";
220 Ok(sqlx::query_scalar(query)
221 .bind(id)
222 .fetch_one(&self.pool)
223 .await?)
224 })
225 }
226
227 pub async fn create_user(
228 &self,
229 email_address: &str,
230 admin: bool,
231 params: NewUserParams,
232 ) -> Result<NewUserResult> {
233 test_support!(self, {
234 let query = "
235 INSERT INTO users (email_address, github_login, github_user_id, admin)
236 VALUES ($1, $2, $3, $4)
237 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
238 RETURNING id, metrics_id::text
239 ";
240
241 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
242 .bind(email_address)
243 .bind(params.github_login)
244 .bind(params.github_user_id)
245 .bind(admin)
246 .fetch_one(&self.pool)
247 .await?;
248 Ok(NewUserResult {
249 user_id,
250 metrics_id,
251 signup_device_id: None,
252 inviting_user_id: None,
253 })
254 })
255 }
256
257 pub async fn create_user_from_invite(
258 &self,
259 invite: &Invite,
260 user: NewUserParams,
261 ) -> Result<Option<NewUserResult>> {
262 test_support!(self, {
263 let mut tx = self.pool.begin().await?;
264
265 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
266 i32,
267 Option<UserId>,
268 Option<UserId>,
269 Option<String>,
270 ) = sqlx::query_as(
271 "
272 SELECT id, user_id, inviting_user_id, device_id
273 FROM signups
274 WHERE
275 email_address = $1 AND
276 email_confirmation_code = $2
277 ",
278 )
279 .bind(&invite.email_address)
280 .bind(&invite.email_confirmation_code)
281 .fetch_optional(&mut tx)
282 .await?
283 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
284
285 if existing_user_id.is_some() {
286 return Ok(None);
287 }
288
289 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
290 "
291 INSERT INTO users
292 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
293 VALUES
294 ($1, $2, $3, FALSE, $4, $5)
295 ON CONFLICT (github_login) DO UPDATE SET
296 email_address = excluded.email_address,
297 github_user_id = excluded.github_user_id,
298 admin = excluded.admin
299 RETURNING id, metrics_id::text
300 ",
301 )
302 .bind(&invite.email_address)
303 .bind(&user.github_login)
304 .bind(&user.github_user_id)
305 .bind(&user.invite_count)
306 .bind(random_invite_code())
307 .fetch_one(&mut tx)
308 .await?;
309
310 sqlx::query(
311 "
312 UPDATE signups
313 SET user_id = $1
314 WHERE id = $2
315 ",
316 )
317 .bind(&user_id)
318 .bind(&signup_id)
319 .execute(&mut tx)
320 .await?;
321
322 if let Some(inviting_user_id) = inviting_user_id {
323 let id: Option<UserId> = sqlx::query_scalar(
324 "
325 UPDATE users
326 SET invite_count = invite_count - 1
327 WHERE id = $1 AND invite_count > 0
328 RETURNING id
329 ",
330 )
331 .bind(&inviting_user_id)
332 .fetch_optional(&mut tx)
333 .await?;
334
335 if id.is_none() {
336 Err(Error::Http(
337 StatusCode::UNAUTHORIZED,
338 "no invites remaining".to_string(),
339 ))?;
340 }
341
342 sqlx::query(
343 "
344 INSERT INTO contacts
345 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
346 VALUES
347 ($1, $2, TRUE, TRUE, TRUE)
348 ON CONFLICT DO NOTHING
349 ",
350 )
351 .bind(inviting_user_id)
352 .bind(user_id)
353 .execute(&mut tx)
354 .await?;
355 }
356
357 tx.commit().await?;
358 Ok(Some(NewUserResult {
359 user_id,
360 metrics_id,
361 inviting_user_id,
362 signup_device_id,
363 }))
364 })
365 }
366
367 pub async fn create_signup(&self, signup: Signup) -> Result<()> {
368 test_support!(self, {
369 sqlx::query(
370 "
371 INSERT INTO signups
372 (
373 email_address,
374 email_confirmation_code,
375 email_confirmation_sent,
376 platform_linux,
377 platform_mac,
378 platform_windows,
379 platform_unknown,
380 editor_features,
381 programming_languages,
382 device_id
383 )
384 VALUES
385 ($1, $2, FALSE, $3, $4, $5, FALSE, $6)
386 RETURNING id
387 ",
388 )
389 .bind(&signup.email_address)
390 .bind(&random_email_confirmation_code())
391 .bind(&signup.platform_linux)
392 .bind(&signup.platform_mac)
393 .bind(&signup.platform_windows)
394 .bind(&signup.editor_features)
395 .bind(&signup.programming_languages)
396 .bind(&signup.device_id)
397 .execute(&self.pool)
398 .await?;
399 Ok(())
400 })
401 }
402
403 pub async fn create_invite_from_code(
404 &self,
405 code: &str,
406 email_address: &str,
407 device_id: Option<&str>,
408 ) -> Result<Invite> {
409 test_support!(self, {
410 let mut tx = self.pool.begin().await?;
411
412 let existing_user: Option<UserId> = sqlx::query_scalar(
413 "
414 SELECT id
415 FROM users
416 WHERE email_address = $1
417 ",
418 )
419 .bind(email_address)
420 .fetch_optional(&mut tx)
421 .await?;
422 if existing_user.is_some() {
423 Err(anyhow!("email address is already in use"))?;
424 }
425
426 let row: Option<(UserId, i32)> = sqlx::query_as(
427 "
428 SELECT id, invite_count
429 FROM users
430 WHERE invite_code = $1
431 ",
432 )
433 .bind(code)
434 .fetch_optional(&mut tx)
435 .await?;
436
437 let (inviter_id, invite_count) = match row {
438 Some(row) => row,
439 None => Err(Error::Http(
440 StatusCode::NOT_FOUND,
441 "invite code not found".to_string(),
442 ))?,
443 };
444
445 if invite_count == 0 {
446 Err(Error::Http(
447 StatusCode::UNAUTHORIZED,
448 "no invites remaining".to_string(),
449 ))?;
450 }
451
452 let email_confirmation_code: String = sqlx::query_scalar(
453 "
454 INSERT INTO signups
455 (
456 email_address,
457 email_confirmation_code,
458 email_confirmation_sent,
459 inviting_user_id,
460 platform_linux,
461 platform_mac,
462 platform_windows,
463 platform_unknown,
464 device_id
465 )
466 VALUES
467 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
468 ON CONFLICT (email_address)
469 DO UPDATE SET
470 inviting_user_id = excluded.inviting_user_id
471 RETURNING email_confirmation_code
472 ",
473 )
474 .bind(&email_address)
475 .bind(&random_email_confirmation_code())
476 .bind(&inviter_id)
477 .bind(&device_id)
478 .fetch_one(&mut tx)
479 .await?;
480
481 tx.commit().await?;
482
483 Ok(Invite {
484 email_address: email_address.into(),
485 email_confirmation_code,
486 })
487 })
488 }
489}
490
491impl<D> Db<D>
492where
493 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
494 D::Connection: sqlx::migrate::Migrate,
495 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
496 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
497 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
498 D::QueryResult: RowsAffected,
499 String: sqlx::Type<D>,
500 i32: sqlx::Type<D>,
501 i64: sqlx::Type<D>,
502 bool: sqlx::Type<D>,
503 str: sqlx::Type<D>,
504 Uuid: sqlx::Type<D>,
505 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
506 OffsetDateTime: sqlx::Type<D>,
507 PrimitiveDateTime: sqlx::Type<D>,
508 usize: sqlx::ColumnIndex<D::Row>,
509 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
510 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
511 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
512 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
513 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
514 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
515 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
516 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
517 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
518 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
519 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
520 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
521{
522 pub async fn migrate(
523 &self,
524 migrations_path: &Path,
525 ignore_checksum_mismatch: bool,
526 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
527 let migrations = MigrationSource::resolve(migrations_path)
528 .await
529 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
530
531 let mut conn = self.pool.acquire().await?;
532
533 conn.ensure_migrations_table().await?;
534 let applied_migrations: HashMap<_, _> = conn
535 .list_applied_migrations()
536 .await?
537 .into_iter()
538 .map(|m| (m.version, m))
539 .collect();
540
541 let mut new_migrations = Vec::new();
542 for migration in migrations {
543 match applied_migrations.get(&migration.version) {
544 Some(applied_migration) => {
545 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
546 {
547 Err(anyhow!(
548 "checksum mismatch for applied migration {}",
549 migration.description
550 ))?;
551 }
552 }
553 None => {
554 let elapsed = conn.apply(&migration).await?;
555 new_migrations.push((migration, elapsed));
556 }
557 }
558 }
559
560 Ok(new_migrations)
561 }
562
563 pub fn fuzzy_like_string(string: &str) -> String {
564 let mut result = String::with_capacity(string.len() * 2 + 1);
565 for c in string.chars() {
566 if c.is_alphanumeric() {
567 result.push('%');
568 result.push(c);
569 }
570 }
571 result.push('%');
572 result
573 }
574
575 // users
576
577 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
578 test_support!(self, {
579 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
580 Ok(sqlx::query_as(query)
581 .bind(limit as i32)
582 .bind((page * limit) as i32)
583 .fetch_all(&self.pool)
584 .await?)
585 })
586 }
587
588 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
589 test_support!(self, {
590 let query = "
591 SELECT users.*
592 FROM users
593 WHERE id = $1
594 LIMIT 1
595 ";
596 Ok(sqlx::query_as(query)
597 .bind(&id)
598 .fetch_optional(&self.pool)
599 .await?)
600 })
601 }
602
603 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
604 test_support!(self, {
605 let query = "
606 SELECT users.*
607 FROM users
608 WHERE users.id IN (SELECT value from json_each($1))
609 ";
610 Ok(sqlx::query_as(query)
611 .bind(&serde_json::json!(ids))
612 .fetch_all(&self.pool)
613 .await?)
614 })
615 }
616
617 pub async fn get_users_with_no_invites(
618 &self,
619 invited_by_another_user: bool,
620 ) -> Result<Vec<User>> {
621 test_support!(self, {
622 let query = format!(
623 "
624 SELECT users.*
625 FROM users
626 WHERE invite_count = 0
627 AND inviter_id IS{} NULL
628 ",
629 if invited_by_another_user { " NOT" } else { "" }
630 );
631
632 Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?)
633 })
634 }
635
636 pub async fn get_user_by_github_account(
637 &self,
638 github_login: &str,
639 github_user_id: Option<i32>,
640 ) -> Result<Option<User>> {
641 test_support!(self, {
642 if let Some(github_user_id) = github_user_id {
643 let mut user = sqlx::query_as::<_, User>(
644 "
645 UPDATE users
646 SET github_login = $1
647 WHERE github_user_id = $2
648 RETURNING *
649 ",
650 )
651 .bind(github_login)
652 .bind(github_user_id)
653 .fetch_optional(&self.pool)
654 .await?;
655
656 if user.is_none() {
657 user = sqlx::query_as::<_, User>(
658 "
659 UPDATE users
660 SET github_user_id = $1
661 WHERE github_login = $2
662 RETURNING *
663 ",
664 )
665 .bind(github_user_id)
666 .bind(github_login)
667 .fetch_optional(&self.pool)
668 .await?;
669 }
670
671 Ok(user)
672 } else {
673 let user = sqlx::query_as(
674 "
675 SELECT * FROM users
676 WHERE github_login = $1
677 LIMIT 1
678 ",
679 )
680 .bind(github_login)
681 .fetch_optional(&self.pool)
682 .await?;
683 Ok(user)
684 }
685 })
686 }
687
688 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
689 test_support!(self, {
690 let query = "UPDATE users SET admin = $1 WHERE id = $2";
691 Ok(sqlx::query(query)
692 .bind(is_admin)
693 .bind(id.0)
694 .execute(&self.pool)
695 .await
696 .map(drop)?)
697 })
698 }
699
700 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
701 test_support!(self, {
702 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
703 Ok(sqlx::query(query)
704 .bind(connected_once)
705 .bind(id.0)
706 .execute(&self.pool)
707 .await
708 .map(drop)?)
709 })
710 }
711
712 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
713 test_support!(self, {
714 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
715 sqlx::query(query)
716 .bind(id.0)
717 .execute(&self.pool)
718 .await
719 .map(drop)?;
720 let query = "DELETE FROM users WHERE id = $1;";
721 Ok(sqlx::query(query)
722 .bind(id.0)
723 .execute(&self.pool)
724 .await
725 .map(drop)?)
726 })
727 }
728
729 // signups
730
731 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
732 test_support!(self, {
733 Ok(sqlx::query_as(
734 "
735 SELECT
736 COUNT(*) as count,
737 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
738 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
739 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
740 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
741 FROM (
742 SELECT *
743 FROM signups
744 WHERE
745 NOT email_confirmation_sent
746 ) AS unsent
747 ",
748 )
749 .fetch_one(&self.pool)
750 .await?)
751 })
752 }
753
754 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
755 test_support!(self, {
756 Ok(sqlx::query_as(
757 "
758 SELECT
759 email_address, email_confirmation_code
760 FROM signups
761 WHERE
762 NOT email_confirmation_sent AND
763 (platform_mac OR platform_unknown)
764 LIMIT $1
765 ",
766 )
767 .bind(count as i32)
768 .fetch_all(&self.pool)
769 .await?)
770 })
771 }
772
773 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
774 test_support!(self, {
775 let emails = invites
776 .iter()
777 .map(|s| s.email_address.as_str())
778 .collect::<Vec<_>>();
779 sqlx::query(
780 "
781 UPDATE signups
782 SET email_confirmation_sent = TRUE
783 WHERE email_address IN (SELECT value from json_each($1))
784 ",
785 )
786 .bind(&serde_json::json!(emails))
787 .execute(&self.pool)
788 .await?;
789 Ok(())
790 })
791 }
792
793 // invite codes
794
795 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
796 test_support!(self, {
797 let mut tx = self.pool.begin().await?;
798 if count > 0 {
799 sqlx::query(
800 "
801 UPDATE users
802 SET invite_code = $1
803 WHERE id = $2 AND invite_code IS NULL
804 ",
805 )
806 .bind(random_invite_code())
807 .bind(id)
808 .execute(&mut tx)
809 .await?;
810 }
811
812 sqlx::query(
813 "
814 UPDATE users
815 SET invite_count = $1
816 WHERE id = $2
817 ",
818 )
819 .bind(count as i32)
820 .bind(id)
821 .execute(&mut tx)
822 .await?;
823 tx.commit().await?;
824 Ok(())
825 })
826 }
827
828 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
829 test_support!(self, {
830 let result: Option<(String, i32)> = sqlx::query_as(
831 "
832 SELECT invite_code, invite_count
833 FROM users
834 WHERE id = $1 AND invite_code IS NOT NULL
835 ",
836 )
837 .bind(id)
838 .fetch_optional(&self.pool)
839 .await?;
840 if let Some((code, count)) = result {
841 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
842 } else {
843 Ok(None)
844 }
845 })
846 }
847
848 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
849 test_support!(self, {
850 sqlx::query_as(
851 "
852 SELECT *
853 FROM users
854 WHERE invite_code = $1
855 ",
856 )
857 .bind(code)
858 .fetch_optional(&self.pool)
859 .await?
860 .ok_or_else(|| {
861 Error::Http(
862 StatusCode::NOT_FOUND,
863 "that invite code does not exist".to_string(),
864 )
865 })
866 })
867 }
868
869 // projects
870
871 /// Registers a new project for the given user.
872 pub async fn register_project(&self, host_user_id: UserId) -> Result<ProjectId> {
873 test_support!(self, {
874 Ok(sqlx::query_scalar(
875 "
876 INSERT INTO projects(host_user_id)
877 VALUES ($1)
878 RETURNING id
879 ",
880 )
881 .bind(host_user_id)
882 .fetch_one(&self.pool)
883 .await
884 .map(ProjectId)?)
885 })
886 }
887
888 /// Unregisters a project for the given project id.
889 pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> {
890 test_support!(self, {
891 sqlx::query(
892 "
893 UPDATE projects
894 SET unregistered = TRUE
895 WHERE id = $1
896 ",
897 )
898 .bind(project_id)
899 .execute(&self.pool)
900 .await?;
901 Ok(())
902 })
903 }
904
905 // contacts
906
907 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
908 test_support!(self, {
909 let query = "
910 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
911 FROM contacts
912 WHERE user_id_a = $1 OR user_id_b = $1;
913 ";
914
915 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
916 .bind(user_id)
917 .fetch(&self.pool);
918
919 let mut contacts = Vec::new();
920 while let Some(row) = rows.next().await {
921 let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
922
923 if user_id_a == user_id {
924 if accepted {
925 contacts.push(Contact::Accepted {
926 user_id: user_id_b,
927 should_notify: should_notify && a_to_b,
928 });
929 } else if a_to_b {
930 contacts.push(Contact::Outgoing { user_id: user_id_b })
931 } else {
932 contacts.push(Contact::Incoming {
933 user_id: user_id_b,
934 should_notify,
935 });
936 }
937 } else if accepted {
938 contacts.push(Contact::Accepted {
939 user_id: user_id_a,
940 should_notify: should_notify && !a_to_b,
941 });
942 } else if a_to_b {
943 contacts.push(Contact::Incoming {
944 user_id: user_id_a,
945 should_notify,
946 });
947 } else {
948 contacts.push(Contact::Outgoing { user_id: user_id_a });
949 }
950 }
951
952 contacts.sort_unstable_by_key(|contact| contact.user_id());
953
954 Ok(contacts)
955 })
956 }
957
958 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
959 test_support!(self, {
960 let (id_a, id_b) = if user_id_1 < user_id_2 {
961 (user_id_1, user_id_2)
962 } else {
963 (user_id_2, user_id_1)
964 };
965
966 let query = "
967 SELECT 1 FROM contacts
968 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
969 LIMIT 1
970 ";
971 Ok(sqlx::query_scalar::<_, i32>(query)
972 .bind(id_a.0)
973 .bind(id_b.0)
974 .fetch_optional(&self.pool)
975 .await?
976 .is_some())
977 })
978 }
979
980 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
981 test_support!(self, {
982 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
983 (sender_id, receiver_id, true)
984 } else {
985 (receiver_id, sender_id, false)
986 };
987 let query = "
988 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
989 VALUES ($1, $2, $3, FALSE, TRUE)
990 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
991 SET
992 accepted = TRUE,
993 should_notify = FALSE
994 WHERE
995 NOT contacts.accepted AND
996 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
997 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
998 ";
999 let result = sqlx::query(query)
1000 .bind(id_a.0)
1001 .bind(id_b.0)
1002 .bind(a_to_b)
1003 .execute(&self.pool)
1004 .await?;
1005
1006 if result.rows_affected() == 1 {
1007 Ok(())
1008 } else {
1009 Err(anyhow!("contact already requested"))?
1010 }
1011 })
1012 }
1013
1014 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1015 test_support!(self, {
1016 let (id_a, id_b) = if responder_id < requester_id {
1017 (responder_id, requester_id)
1018 } else {
1019 (requester_id, responder_id)
1020 };
1021 let query = "
1022 DELETE FROM contacts
1023 WHERE user_id_a = $1 AND user_id_b = $2;
1024 ";
1025 let result = sqlx::query(query)
1026 .bind(id_a.0)
1027 .bind(id_b.0)
1028 .execute(&self.pool)
1029 .await?;
1030
1031 if result.rows_affected() == 1 {
1032 Ok(())
1033 } else {
1034 Err(anyhow!("no such contact"))?
1035 }
1036 })
1037 }
1038
1039 pub async fn dismiss_contact_notification(
1040 &self,
1041 user_id: UserId,
1042 contact_user_id: UserId,
1043 ) -> Result<()> {
1044 test_support!(self, {
1045 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1046 (user_id, contact_user_id, true)
1047 } else {
1048 (contact_user_id, user_id, false)
1049 };
1050
1051 let query = "
1052 UPDATE contacts
1053 SET should_notify = FALSE
1054 WHERE
1055 user_id_a = $1 AND user_id_b = $2 AND
1056 (
1057 (a_to_b = $3 AND accepted) OR
1058 (a_to_b != $3 AND NOT accepted)
1059 );
1060 ";
1061
1062 let result = sqlx::query(query)
1063 .bind(id_a.0)
1064 .bind(id_b.0)
1065 .bind(a_to_b)
1066 .execute(&self.pool)
1067 .await?;
1068
1069 if result.rows_affected() == 0 {
1070 Err(anyhow!("no such contact request"))?;
1071 }
1072
1073 Ok(())
1074 })
1075 }
1076
1077 pub async fn respond_to_contact_request(
1078 &self,
1079 responder_id: UserId,
1080 requester_id: UserId,
1081 accept: bool,
1082 ) -> Result<()> {
1083 test_support!(self, {
1084 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1085 (responder_id, requester_id, false)
1086 } else {
1087 (requester_id, responder_id, true)
1088 };
1089 let result = if accept {
1090 let query = "
1091 UPDATE contacts
1092 SET accepted = TRUE, should_notify = TRUE
1093 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1094 ";
1095 sqlx::query(query)
1096 .bind(id_a.0)
1097 .bind(id_b.0)
1098 .bind(a_to_b)
1099 .execute(&self.pool)
1100 .await?
1101 } else {
1102 let query = "
1103 DELETE FROM contacts
1104 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1105 ";
1106 sqlx::query(query)
1107 .bind(id_a.0)
1108 .bind(id_b.0)
1109 .bind(a_to_b)
1110 .execute(&self.pool)
1111 .await?
1112 };
1113 if result.rows_affected() == 1 {
1114 Ok(())
1115 } else {
1116 Err(anyhow!("no such contact request"))?
1117 }
1118 })
1119 }
1120
1121 // access tokens
1122
1123 pub async fn create_access_token_hash(
1124 &self,
1125 user_id: UserId,
1126 access_token_hash: &str,
1127 max_access_token_count: usize,
1128 ) -> Result<()> {
1129 test_support!(self, {
1130 let insert_query = "
1131 INSERT INTO access_tokens (user_id, hash)
1132 VALUES ($1, $2);
1133 ";
1134 let cleanup_query = "
1135 DELETE FROM access_tokens
1136 WHERE id IN (
1137 SELECT id from access_tokens
1138 WHERE user_id = $1
1139 ORDER BY id DESC
1140 LIMIT 10000
1141 OFFSET $3
1142 )
1143 ";
1144
1145 let mut tx = self.pool.begin().await?;
1146 sqlx::query(insert_query)
1147 .bind(user_id.0)
1148 .bind(access_token_hash)
1149 .execute(&mut tx)
1150 .await?;
1151 sqlx::query(cleanup_query)
1152 .bind(user_id.0)
1153 .bind(access_token_hash)
1154 .bind(max_access_token_count as i32)
1155 .execute(&mut tx)
1156 .await?;
1157 Ok(tx.commit().await?)
1158 })
1159 }
1160
1161 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1162 test_support!(self, {
1163 let query = "
1164 SELECT hash
1165 FROM access_tokens
1166 WHERE user_id = $1
1167 ORDER BY id DESC
1168 ";
1169 Ok(sqlx::query_scalar(query)
1170 .bind(user_id.0)
1171 .fetch_all(&self.pool)
1172 .await?)
1173 })
1174 }
1175}
1176
/// Declares `$name` as a transparent `i32` newtype for database row ids.
///
/// The type is stored by sqlx and (de)serialized by serde as its inner `i32`.
/// It converts to and from `u64` (`from_proto`/`to_proto`) with `as` casts and
/// displays as its inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an id from a `u64`, truncating with an `as` cast.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Widens the id to `u64` with an `as` cast.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                let $name(raw) = self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                // Delegate so width/padding flags apply to the inner integer.
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
1219
1220id_type!(UserId);
1221#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1222pub struct User {
1223 pub id: UserId,
1224 pub github_login: String,
1225 pub github_user_id: Option<i32>,
1226 pub email_address: Option<String>,
1227 pub admin: bool,
1228 pub invite_code: Option<String>,
1229 pub invite_count: i32,
1230 pub connected_once: bool,
1231}
1232
1233id_type!(ProjectId);
1234#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
1235pub struct Project {
1236 pub id: ProjectId,
1237 pub host_user_id: UserId,
1238 pub unregistered: bool,
1239}
1240
1241#[derive(Clone, Debug, PartialEq, Eq)]
1242pub enum Contact {
1243 Accepted {
1244 user_id: UserId,
1245 should_notify: bool,
1246 },
1247 Outgoing {
1248 user_id: UserId,
1249 },
1250 Incoming {
1251 user_id: UserId,
1252 should_notify: bool,
1253 },
1254}
1255
1256impl Contact {
1257 pub fn user_id(&self) -> UserId {
1258 match self {
1259 Contact::Accepted { user_id, .. } => *user_id,
1260 Contact::Outgoing { user_id } => *user_id,
1261 Contact::Incoming { user_id, .. } => *user_id,
1262 }
1263 }
1264}
1265
1266#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
1267pub struct IncomingContactRequest {
1268 pub requester_id: UserId,
1269 pub should_notify: bool,
1270}
1271
1272#[derive(Clone, Deserialize)]
1273pub struct Signup {
1274 pub email_address: String,
1275 pub platform_mac: bool,
1276 pub platform_windows: bool,
1277 pub platform_linux: bool,
1278 pub editor_features: Vec<String>,
1279 pub programming_languages: Vec<String>,
1280 pub device_id: Option<String>,
1281}
1282
1283#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
1284pub struct WaitlistSummary {
1285 #[sqlx(default)]
1286 pub count: i64,
1287 #[sqlx(default)]
1288 pub linux_count: i64,
1289 #[sqlx(default)]
1290 pub mac_count: i64,
1291 #[sqlx(default)]
1292 pub windows_count: i64,
1293 #[sqlx(default)]
1294 pub unknown_count: i64,
1295}
1296
1297#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
1298pub struct Invite {
1299 pub email_address: String,
1300 pub email_confirmation_code: String,
1301}
1302
1303#[derive(Debug, Serialize, Deserialize)]
1304pub struct NewUserParams {
1305 pub github_login: String,
1306 pub github_user_id: i32,
1307 pub invite_count: i32,
1308}
1309
1310#[derive(Debug)]
1311pub struct NewUserResult {
1312 pub user_id: UserId,
1313 pub metrics_id: String,
1314 pub inviting_user_id: Option<UserId>,
1315 pub signup_device_id: Option<String>,
1316}
1317
1318fn random_invite_code() -> String {
1319 nanoid::nanoid!(16)
1320}
1321
1322fn random_email_confirmation_code() -> String {
1323 nanoid::nanoid!(64)
1324}
1325
1326#[cfg(test)]
1327pub use test::*;
1328
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use rand::prelude::*;
    use std::sync::Arc;

    /// A uniquely named, migrated in-memory test database plus one detached
    /// raw connection to it. The database is torn down when this is dropped.
    pub struct TestDb {
        pub db: Option<Arc<DefaultDb>>,
        pub conn: sqlx::sqlite::SqliteConnection,
        pub url: String,
    }

    impl TestDb {
        /// Creates and migrates a fresh database. Random delays are simulated
        /// with `background`, and queries run on a dedicated current-thread
        /// Tokio runtime.
        ///
        /// # Panics
        ///
        /// Panics if the runtime, database, migrations, or connection cannot
        /// be set up.
        pub fn new(background: Arc<Background>) -> Self {
            let suffix: u128 = StdRng::from_entropy().gen();
            let url = format!("file:zed-test-{}?mode=memory", suffix);
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let (mut db, conn) = runtime.block_on(async {
                let db = DefaultDb::new(&url, 5).await.unwrap();
                let migrations_path = Path::new(DefaultDb::MIGRATIONS_PATH);
                db.migrate(migrations_path, false).await.unwrap();
                let pooled = db.pool.acquire().await.unwrap();
                (db, pooled.detach())
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            let db = Some(Arc::new(db));
            Self { db, conn, url }
        }

        /// Shared handle to the database.
        ///
        /// # Panics
        ///
        /// Panics if the database has already been taken (only `drop` does so).
        pub fn db(&self) -> &Arc<DefaultDb> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        /// Tears down the database identified by `self.url`.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            db.teardown(&self.url);
        }
    }
}