1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::{BTreeMap, HashMap, HashSet};
5use futures::{future::BoxFuture, FutureExt, StreamExt};
6use rpc::{proto, ConnectionId};
7use serde::{Deserialize, Serialize};
8use sqlx::{
9 migrate::{Migrate as _, Migration, MigrationSource},
10 types::Uuid,
11 FromRow,
12};
13use std::{
14 future::Future,
15 path::{Path, PathBuf},
16 time::Duration,
17};
18use time::{OffsetDateTime, PrimitiveDateTime};
19
20#[cfg(test)]
21pub type DefaultDb = Db<sqlx::Sqlite>;
22
23#[cfg(not(test))]
24pub type DefaultDb = Db<sqlx::Postgres>;
25
26pub struct Db<D: sqlx::Database> {
27 pool: sqlx::Pool<D>,
28 #[cfg(test)]
29 background: Option<std::sync::Arc<gpui::executor::Background>>,
30 #[cfg(test)]
31 runtime: Option<tokio::runtime::Runtime>,
32}
33
34pub trait BeginTransaction: Send + Sync {
35 type Database: sqlx::Database;
36
37 fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, Self::Database>>>;
38}
39
40// In Postgres, serializable transactions are opt-in
41impl BeginTransaction for Db<sqlx::Postgres> {
42 type Database = sqlx::Postgres;
43
44 fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Postgres>>> {
45 async move {
46 let mut tx = self.pool.begin().await?;
47 sqlx::Executor::execute(&mut tx, "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;")
48 .await?;
49 Ok(tx)
50 }
51 .boxed()
52 }
53}
54
55// In Sqlite, transactions are inherently serializable.
56impl BeginTransaction for Db<sqlx::Sqlite> {
57 type Database = sqlx::Sqlite;
58
59 fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Sqlite>>> {
60 async move { Ok(self.pool.begin().await?) }.boxed()
61 }
62}
63
/// Uniform access to the affected-row count of a backend's query result.
pub trait RowsAffected {
    /// Number of rows the executed statement affected.
    fn rows_affected(&self) -> u64;
}
67
#[cfg(test)]
impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    /// Forwards to the inherent `SqliteQueryResult::rows_affected`.
    fn rows_affected(&self) -> u64 {
        // A type-qualified path resolves to the inherent method before the
        // trait method, so this does not recurse.
        sqlx::sqlite::SqliteQueryResult::rows_affected(self)
    }
}
74
75impl RowsAffected for sqlx::postgres::PgQueryResult {
76 fn rows_affected(&self) -> u64 {
77 self.rows_affected()
78 }
79}
80
/// SQLite-backed operations, compiled only in test builds.
///
/// Fuzzy search and the signup/invite methods are stubs that panic via
/// `unimplemented!()`.
#[cfg(test)]
impl Db<sqlx::Sqlite> {
    /// Opens a SQLite connection pool for `url`.
    ///
    /// The database is created if it is missing and opened in shared-cache
    /// mode. The pool keeps at least two connections and at most
    /// `max_connections`. The test-only `background` and `runtime` fields
    /// start out as `None`.
    ///
    /// # Errors
    /// Returns an error if `url` cannot be parsed into connect options or if
    /// connecting fails.
    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
        use std::str::FromStr as _;
        // Report a malformed URL as an error rather than panicking.
        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
            .create_if_missing(true)
            .shared_cache(true);
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(2)
            .max_connections(max_connections)
            .connect_with(options)
            .await?;
        Ok(Self {
            pool,
            background: None,
            runtime: None,
        })
    }

    /// Fetches every user whose id appears in `ids`.
    ///
    /// The ids are bound as one JSON array and expanded in SQL with `json_each`.
    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
        self.transact(|tx| async {
            // Move the transaction into the future; `ids` stays borrowed.
            let mut tx = tx;
            let query = "
                SELECT users.*
                FROM users
                WHERE users.id IN (SELECT value from json_each($1))
            ";
            Ok(sqlx::query_as(query)
                .bind(&serde_json::json!(ids))
                .fetch_all(&mut tx)
                .await?)
        })
        .await
    }

    /// Returns the `metrics_id` of the user with the given `id`.
    ///
    /// Uses `fetch_one`, so a missing user surfaces as an error.
    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
        self.transact(|mut tx| async move {
            let query = "
                SELECT metrics_id
                FROM users
                WHERE id = $1
            ";
            Ok(sqlx::query_scalar(query)
                .bind(id)
                .fetch_one(&mut tx)
                .await?)
        })
        .await
    }

    /// Inserts a user and returns its id and metrics id.
    ///
    /// A freshly generated UUID string is bound as `metrics_id`. On a
    /// `github_login` conflict the existing row is kept (the update is a
    /// no-op assignment) and that row's `id` and `metrics_id` are returned.
    /// `signup_device_id` and `inviting_user_id` of the result are always `None`.
    pub async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult> {
        self.transact(|mut tx| async {
            let query = "
                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id, metrics_id
            ";

            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
                .bind(email_address)
                .bind(&params.github_login)
                .bind(&params.github_user_id)
                .bind(admin)
                .bind(Uuid::new_v4().to_string())
                .fetch_one(&mut tx)
                .await?;
            tx.commit().await?;
            Ok(NewUserResult {
                user_id,
                metrics_id,
                signup_device_id: None,
                inviting_user_id: None,
            })
        })
        .await
    }

    /// Not implemented for the SQLite backend; always panics.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }

    /// Not implemented for the SQLite backend; always panics.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }

    /// Not implemented for the SQLite backend; always panics.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }

    /// Not implemented for the SQLite backend; always panics.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }

    /// Not implemented for the SQLite backend; always panics.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
}
194
195impl Db<sqlx::Postgres> {
196 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
197 let pool = sqlx::postgres::PgPoolOptions::new()
198 .max_connections(max_connections)
199 .connect(url)
200 .await?;
201 Ok(Self {
202 pool,
203 #[cfg(test)]
204 background: None,
205 #[cfg(test)]
206 runtime: None,
207 })
208 }
209
210 #[cfg(test)]
211 pub fn teardown(&self, url: &str) {
212 self.runtime.as_ref().unwrap().block_on(async {
213 use util::ResultExt;
214 let query = "
215 SELECT pg_terminate_backend(pg_stat_activity.pid)
216 FROM pg_stat_activity
217 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
218 ";
219 sqlx::query(query).execute(&self.pool).await.log_err();
220 self.pool.close().await;
221 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
222 .await
223 .log_err();
224 })
225 }
226
227 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
228 self.transact(|tx| async {
229 let mut tx = tx;
230 let like_string = Self::fuzzy_like_string(name_query);
231 let query = "
232 SELECT users.*
233 FROM users
234 WHERE github_login ILIKE $1
235 ORDER BY github_login <-> $2
236 LIMIT $3
237 ";
238 Ok(sqlx::query_as(query)
239 .bind(like_string)
240 .bind(name_query)
241 .bind(limit as i32)
242 .fetch_all(&mut tx)
243 .await?)
244 })
245 .await
246 }
247
248 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
249 let ids = ids.iter().map(|id| id.0).collect::<Vec<_>>();
250 self.transact(|tx| async {
251 let mut tx = tx;
252 let query = "
253 SELECT users.*
254 FROM users
255 WHERE users.id = ANY ($1)
256 ";
257 Ok(sqlx::query_as(query).bind(&ids).fetch_all(&mut tx).await?)
258 })
259 .await
260 }
261
262 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
263 self.transact(|mut tx| async move {
264 let query = "
265 SELECT metrics_id::text
266 FROM users
267 WHERE id = $1
268 ";
269 Ok(sqlx::query_scalar(query)
270 .bind(id)
271 .fetch_one(&mut tx)
272 .await?)
273 })
274 .await
275 }
276
277 pub async fn create_user(
278 &self,
279 email_address: &str,
280 admin: bool,
281 params: NewUserParams,
282 ) -> Result<NewUserResult> {
283 self.transact(|mut tx| async {
284 let query = "
285 INSERT INTO users (email_address, github_login, github_user_id, admin)
286 VALUES ($1, $2, $3, $4)
287 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
288 RETURNING id, metrics_id::text
289 ";
290
291 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
292 .bind(email_address)
293 .bind(¶ms.github_login)
294 .bind(params.github_user_id)
295 .bind(admin)
296 .fetch_one(&mut tx)
297 .await?;
298 tx.commit().await?;
299
300 Ok(NewUserResult {
301 user_id,
302 metrics_id,
303 signup_device_id: None,
304 inviting_user_id: None,
305 })
306 })
307 .await
308 }
309
310 pub async fn create_user_from_invite(
311 &self,
312 invite: &Invite,
313 user: NewUserParams,
314 ) -> Result<Option<NewUserResult>> {
315 self.transact(|mut tx| async {
316 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
317 i32,
318 Option<UserId>,
319 Option<UserId>,
320 Option<String>,
321 ) = sqlx::query_as(
322 "
323 SELECT id, user_id, inviting_user_id, device_id
324 FROM signups
325 WHERE
326 email_address = $1 AND
327 email_confirmation_code = $2
328 ",
329 )
330 .bind(&invite.email_address)
331 .bind(&invite.email_confirmation_code)
332 .fetch_optional(&mut tx)
333 .await?
334 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
335
336 if existing_user_id.is_some() {
337 return Ok(None);
338 }
339
340 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
341 "
342 INSERT INTO users
343 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
344 VALUES
345 ($1, $2, $3, FALSE, $4, $5)
346 ON CONFLICT (github_login) DO UPDATE SET
347 email_address = excluded.email_address,
348 github_user_id = excluded.github_user_id,
349 admin = excluded.admin
350 RETURNING id, metrics_id::text
351 ",
352 )
353 .bind(&invite.email_address)
354 .bind(&user.github_login)
355 .bind(&user.github_user_id)
356 .bind(&user.invite_count)
357 .bind(random_invite_code())
358 .fetch_one(&mut tx)
359 .await?;
360
361 sqlx::query(
362 "
363 UPDATE signups
364 SET user_id = $1
365 WHERE id = $2
366 ",
367 )
368 .bind(&user_id)
369 .bind(&signup_id)
370 .execute(&mut tx)
371 .await?;
372
373 if let Some(inviting_user_id) = inviting_user_id {
374 let id: Option<UserId> = sqlx::query_scalar(
375 "
376 UPDATE users
377 SET invite_count = invite_count - 1
378 WHERE id = $1 AND invite_count > 0
379 RETURNING id
380 ",
381 )
382 .bind(&inviting_user_id)
383 .fetch_optional(&mut tx)
384 .await?;
385
386 if id.is_none() {
387 Err(Error::Http(
388 StatusCode::UNAUTHORIZED,
389 "no invites remaining".to_string(),
390 ))?;
391 }
392
393 sqlx::query(
394 "
395 INSERT INTO contacts
396 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
397 VALUES
398 ($1, $2, TRUE, TRUE, TRUE)
399 ON CONFLICT DO NOTHING
400 ",
401 )
402 .bind(inviting_user_id)
403 .bind(user_id)
404 .execute(&mut tx)
405 .await?;
406 }
407
408 tx.commit().await?;
409 Ok(Some(NewUserResult {
410 user_id,
411 metrics_id,
412 inviting_user_id,
413 signup_device_id,
414 }))
415 })
416 .await
417 }
418
419 pub async fn create_signup(&self, signup: Signup) -> Result<()> {
420 self.transact(|mut tx| async {
421 sqlx::query(
422 "
423 INSERT INTO signups
424 (
425 email_address,
426 email_confirmation_code,
427 email_confirmation_sent,
428 platform_linux,
429 platform_mac,
430 platform_windows,
431 platform_unknown,
432 editor_features,
433 programming_languages,
434 device_id
435 )
436 VALUES
437 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
438 RETURNING id
439 ",
440 )
441 .bind(&signup.email_address)
442 .bind(&random_email_confirmation_code())
443 .bind(&signup.platform_linux)
444 .bind(&signup.platform_mac)
445 .bind(&signup.platform_windows)
446 .bind(&signup.editor_features)
447 .bind(&signup.programming_languages)
448 .bind(&signup.device_id)
449 .execute(&mut tx)
450 .await?;
451 tx.commit().await?;
452 Ok(())
453 })
454 .await
455 }
456
457 pub async fn create_invite_from_code(
458 &self,
459 code: &str,
460 email_address: &str,
461 device_id: Option<&str>,
462 ) -> Result<Invite> {
463 self.transact(|mut tx| async {
464 let existing_user: Option<UserId> = sqlx::query_scalar(
465 "
466 SELECT id
467 FROM users
468 WHERE email_address = $1
469 ",
470 )
471 .bind(email_address)
472 .fetch_optional(&mut tx)
473 .await?;
474 if existing_user.is_some() {
475 Err(anyhow!("email address is already in use"))?;
476 }
477
478 let row: Option<(UserId, i32)> = sqlx::query_as(
479 "
480 SELECT id, invite_count
481 FROM users
482 WHERE invite_code = $1
483 ",
484 )
485 .bind(code)
486 .fetch_optional(&mut tx)
487 .await?;
488
489 let (inviter_id, invite_count) = match row {
490 Some(row) => row,
491 None => Err(Error::Http(
492 StatusCode::NOT_FOUND,
493 "invite code not found".to_string(),
494 ))?,
495 };
496
497 if invite_count == 0 {
498 Err(Error::Http(
499 StatusCode::UNAUTHORIZED,
500 "no invites remaining".to_string(),
501 ))?;
502 }
503
504 let email_confirmation_code: String = sqlx::query_scalar(
505 "
506 INSERT INTO signups
507 (
508 email_address,
509 email_confirmation_code,
510 email_confirmation_sent,
511 inviting_user_id,
512 platform_linux,
513 platform_mac,
514 platform_windows,
515 platform_unknown,
516 device_id
517 )
518 VALUES
519 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
520 ON CONFLICT (email_address)
521 DO UPDATE SET
522 inviting_user_id = excluded.inviting_user_id
523 RETURNING email_confirmation_code
524 ",
525 )
526 .bind(&email_address)
527 .bind(&random_email_confirmation_code())
528 .bind(&inviter_id)
529 .bind(&device_id)
530 .fetch_one(&mut tx)
531 .await?;
532
533 tx.commit().await?;
534
535 Ok(Invite {
536 email_address: email_address.into(),
537 email_confirmation_code,
538 })
539 })
540 .await
541 }
542
543 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
544 self.transact(|mut tx| async {
545 let emails = invites
546 .iter()
547 .map(|s| s.email_address.as_str())
548 .collect::<Vec<_>>();
549 sqlx::query(
550 "
551 UPDATE signups
552 SET email_confirmation_sent = TRUE
553 WHERE email_address = ANY ($1)
554 ",
555 )
556 .bind(&emails)
557 .execute(&mut tx)
558 .await?;
559 tx.commit().await?;
560 Ok(())
561 })
562 .await
563 }
564}
565
566impl<D> Db<D>
567where
568 Self: BeginTransaction<Database = D>,
569 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
570 D::Connection: sqlx::migrate::Migrate,
571 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
572 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
573 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
574 D::QueryResult: RowsAffected,
575 String: sqlx::Type<D>,
576 i32: sqlx::Type<D>,
577 i64: sqlx::Type<D>,
578 bool: sqlx::Type<D>,
579 str: sqlx::Type<D>,
580 Uuid: sqlx::Type<D>,
581 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
582 OffsetDateTime: sqlx::Type<D>,
583 PrimitiveDateTime: sqlx::Type<D>,
584 usize: sqlx::ColumnIndex<D::Row>,
585 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
586 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
587 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
588 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
589 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
590 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
591 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
592 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
593 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
594 for<'a> Option<ProjectId>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
595 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
596 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
597 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
598{
599 pub async fn migrate(
600 &self,
601 migrations_path: &Path,
602 ignore_checksum_mismatch: bool,
603 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
604 let migrations = MigrationSource::resolve(migrations_path)
605 .await
606 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
607
608 let mut conn = self.pool.acquire().await?;
609
610 conn.ensure_migrations_table().await?;
611 let applied_migrations: HashMap<_, _> = conn
612 .list_applied_migrations()
613 .await?
614 .into_iter()
615 .map(|m| (m.version, m))
616 .collect();
617
618 let mut new_migrations = Vec::new();
619 for migration in migrations {
620 match applied_migrations.get(&migration.version) {
621 Some(applied_migration) => {
622 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
623 {
624 Err(anyhow!(
625 "checksum mismatch for applied migration {}",
626 migration.description
627 ))?;
628 }
629 }
630 None => {
631 let elapsed = conn.apply(&migration).await?;
632 new_migrations.push((migration, elapsed));
633 }
634 }
635 }
636
637 Ok(new_migrations)
638 }
639
640 pub fn fuzzy_like_string(string: &str) -> String {
641 let mut result = String::with_capacity(string.len() * 2 + 1);
642 for c in string.chars() {
643 if c.is_alphanumeric() {
644 result.push('%');
645 result.push(c);
646 }
647 }
648 result.push('%');
649 result
650 }
651
652 // users
653
654 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
655 self.transact(|tx| async {
656 let mut tx = tx;
657 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
658 Ok(sqlx::query_as(query)
659 .bind(limit as i32)
660 .bind((page * limit) as i32)
661 .fetch_all(&mut tx)
662 .await?)
663 })
664 .await
665 }
666
667 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
668 self.transact(|tx| async {
669 let mut tx = tx;
670 let query = "
671 SELECT users.*
672 FROM users
673 WHERE id = $1
674 LIMIT 1
675 ";
676 Ok(sqlx::query_as(query)
677 .bind(&id)
678 .fetch_optional(&mut tx)
679 .await?)
680 })
681 .await
682 }
683
684 pub async fn get_users_with_no_invites(
685 &self,
686 invited_by_another_user: bool,
687 ) -> Result<Vec<User>> {
688 self.transact(|tx| async {
689 let mut tx = tx;
690 let query = format!(
691 "
692 SELECT users.*
693 FROM users
694 WHERE invite_count = 0
695 AND inviter_id IS{} NULL
696 ",
697 if invited_by_another_user { " NOT" } else { "" }
698 );
699
700 Ok(sqlx::query_as(&query).fetch_all(&mut tx).await?)
701 })
702 .await
703 }
704
705 pub async fn get_user_by_github_account(
706 &self,
707 github_login: &str,
708 github_user_id: Option<i32>,
709 ) -> Result<Option<User>> {
710 self.transact(|tx| async {
711 let mut tx = tx;
712 if let Some(github_user_id) = github_user_id {
713 let mut user = sqlx::query_as::<_, User>(
714 "
715 UPDATE users
716 SET github_login = $1
717 WHERE github_user_id = $2
718 RETURNING *
719 ",
720 )
721 .bind(github_login)
722 .bind(github_user_id)
723 .fetch_optional(&mut tx)
724 .await?;
725
726 if user.is_none() {
727 user = sqlx::query_as::<_, User>(
728 "
729 UPDATE users
730 SET github_user_id = $1
731 WHERE github_login = $2
732 RETURNING *
733 ",
734 )
735 .bind(github_user_id)
736 .bind(github_login)
737 .fetch_optional(&mut tx)
738 .await?;
739 }
740
741 Ok(user)
742 } else {
743 let user = sqlx::query_as(
744 "
745 SELECT * FROM users
746 WHERE github_login = $1
747 LIMIT 1
748 ",
749 )
750 .bind(github_login)
751 .fetch_optional(&mut tx)
752 .await?;
753 Ok(user)
754 }
755 })
756 .await
757 }
758
759 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
760 self.transact(|mut tx| async {
761 let query = "UPDATE users SET admin = $1 WHERE id = $2";
762 sqlx::query(query)
763 .bind(is_admin)
764 .bind(id.0)
765 .execute(&mut tx)
766 .await?;
767 tx.commit().await?;
768 Ok(())
769 })
770 .await
771 }
772
773 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
774 self.transact(|mut tx| async move {
775 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
776 sqlx::query(query)
777 .bind(connected_once)
778 .bind(id.0)
779 .execute(&mut tx)
780 .await?;
781 tx.commit().await?;
782 Ok(())
783 })
784 .await
785 }
786
787 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
788 self.transact(|mut tx| async move {
789 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
790 sqlx::query(query)
791 .bind(id.0)
792 .execute(&mut tx)
793 .await
794 .map(drop)?;
795 let query = "DELETE FROM users WHERE id = $1;";
796 sqlx::query(query).bind(id.0).execute(&mut tx).await?;
797 tx.commit().await?;
798 Ok(())
799 })
800 .await
801 }
802
803 // signups
804
805 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
806 self.transact(|mut tx| async move {
807 Ok(sqlx::query_as(
808 "
809 SELECT
810 COUNT(*) as count,
811 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
812 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
813 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
814 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
815 FROM (
816 SELECT *
817 FROM signups
818 WHERE
819 NOT email_confirmation_sent
820 ) AS unsent
821 ",
822 )
823 .fetch_one(&mut tx)
824 .await?)
825 })
826 .await
827 }
828
829 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
830 self.transact(|mut tx| async move {
831 Ok(sqlx::query_as(
832 "
833 SELECT
834 email_address, email_confirmation_code
835 FROM signups
836 WHERE
837 NOT email_confirmation_sent AND
838 (platform_mac OR platform_unknown)
839 LIMIT $1
840 ",
841 )
842 .bind(count as i32)
843 .fetch_all(&mut tx)
844 .await?)
845 })
846 .await
847 }
848
849 // invite codes
850
851 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
852 self.transact(|mut tx| async move {
853 if count > 0 {
854 sqlx::query(
855 "
856 UPDATE users
857 SET invite_code = $1
858 WHERE id = $2 AND invite_code IS NULL
859 ",
860 )
861 .bind(random_invite_code())
862 .bind(id)
863 .execute(&mut tx)
864 .await?;
865 }
866
867 sqlx::query(
868 "
869 UPDATE users
870 SET invite_count = $1
871 WHERE id = $2
872 ",
873 )
874 .bind(count as i32)
875 .bind(id)
876 .execute(&mut tx)
877 .await?;
878 tx.commit().await?;
879 Ok(())
880 })
881 .await
882 }
883
884 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
885 self.transact(|mut tx| async move {
886 let result: Option<(String, i32)> = sqlx::query_as(
887 "
888 SELECT invite_code, invite_count
889 FROM users
890 WHERE id = $1 AND invite_code IS NOT NULL
891 ",
892 )
893 .bind(id)
894 .fetch_optional(&mut tx)
895 .await?;
896 if let Some((code, count)) = result {
897 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
898 } else {
899 Ok(None)
900 }
901 })
902 .await
903 }
904
905 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
906 self.transact(|tx| async {
907 let mut tx = tx;
908 sqlx::query_as(
909 "
910 SELECT *
911 FROM users
912 WHERE invite_code = $1
913 ",
914 )
915 .bind(code)
916 .fetch_optional(&mut tx)
917 .await?
918 .ok_or_else(|| {
919 Error::Http(
920 StatusCode::NOT_FOUND,
921 "that invite code does not exist".to_string(),
922 )
923 })
924 })
925 .await
926 }
927
928 pub async fn create_room(
929 &self,
930 user_id: UserId,
931 connection_id: ConnectionId,
932 ) -> Result<proto::Room> {
933 self.transact(|mut tx| async move {
934 let live_kit_room = nanoid::nanoid!(30);
935 let room_id = sqlx::query_scalar(
936 "
937 INSERT INTO rooms (live_kit_room, version)
938 VALUES ($1, $2)
939 RETURNING id
940 ",
941 )
942 .bind(&live_kit_room)
943 .bind(0)
944 .fetch_one(&mut tx)
945 .await
946 .map(RoomId)?;
947
948 sqlx::query(
949 "
950 INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id)
951 VALUES ($1, $2, $3, $4, $5)
952 ",
953 )
954 .bind(room_id)
955 .bind(user_id)
956 .bind(connection_id.0 as i32)
957 .bind(user_id)
958 .bind(connection_id.0 as i32)
959 .execute(&mut tx)
960 .await?;
961
962 self.commit_room_transaction(room_id, tx).await
963 }).await
964 }
965
966 pub async fn call(
967 &self,
968 room_id: RoomId,
969 calling_user_id: UserId,
970 calling_connection_id: ConnectionId,
971 called_user_id: UserId,
972 initial_project_id: Option<ProjectId>,
973 ) -> Result<(proto::Room, proto::IncomingCall)> {
974 self.transact(|mut tx| async move {
975 sqlx::query(
976 "
977 INSERT INTO room_participants (room_id, user_id, calling_user_id, calling_connection_id, initial_project_id)
978 VALUES ($1, $2, $3, $4, $5)
979 ",
980 )
981 .bind(room_id)
982 .bind(called_user_id)
983 .bind(calling_user_id)
984 .bind(calling_connection_id.0 as i32)
985 .bind(initial_project_id)
986 .execute(&mut tx)
987 .await?;
988
989 let room = self.commit_room_transaction(room_id, tx).await?;
990 let incoming_call = Self::build_incoming_call(&room, called_user_id)
991 .ok_or_else(|| anyhow!("failed to build incoming call"))?;
992 Ok((room, incoming_call))
993 }).await
994 }
995
996 pub async fn incoming_call_for_user(
997 &self,
998 user_id: UserId,
999 ) -> Result<Option<proto::IncomingCall>> {
1000 self.transact(|mut tx| async move {
1001 let room_id = sqlx::query_scalar::<_, RoomId>(
1002 "
1003 SELECT room_id
1004 FROM room_participants
1005 WHERE user_id = $1 AND answering_connection_id IS NULL
1006 ",
1007 )
1008 .bind(user_id)
1009 .fetch_optional(&mut tx)
1010 .await?;
1011
1012 if let Some(room_id) = room_id {
1013 let room = self.get_room(room_id, &mut tx).await?;
1014 Ok(Self::build_incoming_call(&room, user_id))
1015 } else {
1016 Ok(None)
1017 }
1018 })
1019 .await
1020 }
1021
1022 fn build_incoming_call(
1023 room: &proto::Room,
1024 called_user_id: UserId,
1025 ) -> Option<proto::IncomingCall> {
1026 let pending_participant = room
1027 .pending_participants
1028 .iter()
1029 .find(|participant| participant.user_id == called_user_id.to_proto())?;
1030
1031 Some(proto::IncomingCall {
1032 room_id: room.id,
1033 calling_user_id: pending_participant.calling_user_id,
1034 participant_user_ids: room
1035 .participants
1036 .iter()
1037 .map(|participant| participant.user_id)
1038 .collect(),
1039 initial_project: room.participants.iter().find_map(|participant| {
1040 let initial_project_id = pending_participant.initial_project_id?;
1041 participant
1042 .projects
1043 .iter()
1044 .find(|project| project.id == initial_project_id)
1045 .cloned()
1046 }),
1047 })
1048 }
1049
1050 pub async fn call_failed(
1051 &self,
1052 room_id: RoomId,
1053 called_user_id: UserId,
1054 ) -> Result<proto::Room> {
1055 self.transact(|mut tx| async move {
1056 sqlx::query(
1057 "
1058 DELETE FROM room_participants
1059 WHERE room_id = $1 AND user_id = $2
1060 ",
1061 )
1062 .bind(room_id)
1063 .bind(called_user_id)
1064 .execute(&mut tx)
1065 .await?;
1066
1067 self.commit_room_transaction(room_id, tx).await
1068 })
1069 .await
1070 }
1071
1072 pub async fn decline_call(
1073 &self,
1074 expected_room_id: Option<RoomId>,
1075 user_id: UserId,
1076 ) -> Result<proto::Room> {
1077 self.transact(|mut tx| async move {
1078 let room_id = sqlx::query_scalar(
1079 "
1080 DELETE FROM room_participants
1081 WHERE user_id = $1 AND answering_connection_id IS NULL
1082 RETURNING room_id
1083 ",
1084 )
1085 .bind(user_id)
1086 .fetch_one(&mut tx)
1087 .await?;
1088 if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1089 return Err(anyhow!("declining call on unexpected room"))?;
1090 }
1091
1092 self.commit_room_transaction(room_id, tx).await
1093 })
1094 .await
1095 }
1096
1097 pub async fn cancel_call(
1098 &self,
1099 expected_room_id: Option<RoomId>,
1100 calling_connection_id: ConnectionId,
1101 called_user_id: UserId,
1102 ) -> Result<proto::Room> {
1103 self.transact(|mut tx| async move {
1104 let room_id = sqlx::query_scalar(
1105 "
1106 DELETE FROM room_participants
1107 WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL
1108 RETURNING room_id
1109 ",
1110 )
1111 .bind(called_user_id)
1112 .bind(calling_connection_id.0 as i32)
1113 .fetch_one(&mut tx)
1114 .await?;
1115 if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1116 return Err(anyhow!("canceling call on unexpected room"))?;
1117 }
1118
1119 self.commit_room_transaction(room_id, tx).await
1120 }).await
1121 }
1122
1123 pub async fn join_room(
1124 &self,
1125 room_id: RoomId,
1126 user_id: UserId,
1127 connection_id: ConnectionId,
1128 ) -> Result<proto::Room> {
1129 self.transact(|mut tx| async move {
1130 sqlx::query(
1131 "
1132 UPDATE room_participants
1133 SET answering_connection_id = $1
1134 WHERE room_id = $2 AND user_id = $3
1135 RETURNING 1
1136 ",
1137 )
1138 .bind(connection_id.0 as i32)
1139 .bind(room_id)
1140 .bind(user_id)
1141 .fetch_one(&mut tx)
1142 .await?;
1143 self.commit_room_transaction(room_id, tx).await
1144 })
1145 .await
1146 }
1147
1148 pub async fn leave_room_for_connection(
1149 &self,
1150 connection_id: ConnectionId,
1151 ) -> Result<Option<LeftRoom>> {
1152 self.transact(|mut tx| async move {
1153 // Leave room.
1154 let room_id = sqlx::query_scalar::<_, RoomId>(
1155 "
1156 DELETE FROM room_participants
1157 WHERE answering_connection_id = $1
1158 RETURNING room_id
1159 ",
1160 )
1161 .bind(connection_id.0 as i32)
1162 .fetch_optional(&mut tx)
1163 .await?;
1164
1165 if let Some(room_id) = room_id {
1166 // Cancel pending calls initiated by the leaving user.
1167 let canceled_calls_to_user_ids: Vec<UserId> = sqlx::query_scalar(
1168 "
1169 DELETE FROM room_participants
1170 WHERE calling_connection_id = $1 AND answering_connection_id IS NULL
1171 RETURNING user_id
1172 ",
1173 )
1174 .bind(connection_id.0 as i32)
1175 .fetch_all(&mut tx)
1176 .await?;
1177
1178 let mut project_collaborators = sqlx::query_as::<_, ProjectCollaborator>(
1179 "
1180 SELECT project_collaborators.*
1181 FROM projects, project_collaborators
1182 WHERE
1183 projects.room_id = $1 AND
1184 projects.id = project_collaborators.project_id AND
1185 project_collaborators.connection_id = $2
1186 ",
1187 )
1188 .bind(room_id)
1189 .bind(connection_id.0 as i32)
1190 .fetch(&mut tx);
1191
1192 let mut left_projects = HashMap::default();
1193 while let Some(collaborator) = project_collaborators.next().await {
1194 let collaborator = collaborator?;
1195 let left_project =
1196 left_projects
1197 .entry(collaborator.project_id)
1198 .or_insert(LeftProject {
1199 id: collaborator.project_id,
1200 host_user_id: Default::default(),
1201 connection_ids: Default::default(),
1202 });
1203
1204 let collaborator_connection_id =
1205 ConnectionId(collaborator.connection_id as u32);
1206 if collaborator_connection_id != connection_id || collaborator.is_host {
1207 left_project.connection_ids.push(collaborator_connection_id);
1208 }
1209
1210 if collaborator.is_host {
1211 left_project.host_user_id = collaborator.user_id;
1212 }
1213 }
1214 drop(project_collaborators);
1215
1216 sqlx::query(
1217 "
1218 DELETE FROM projects
1219 WHERE room_id = $1 AND host_connection_id = $2
1220 ",
1221 )
1222 .bind(room_id)
1223 .bind(connection_id.0 as i32)
1224 .execute(&mut tx)
1225 .await?;
1226
1227 let room = self.commit_room_transaction(room_id, tx).await?;
1228 Ok(Some(LeftRoom {
1229 room,
1230 left_projects,
1231 canceled_calls_to_user_ids,
1232 }))
1233 } else {
1234 Ok(None)
1235 }
1236 })
1237 .await
1238 }
1239
1240 pub async fn update_room_participant_location(
1241 &self,
1242 room_id: RoomId,
1243 connection_id: ConnectionId,
1244 location: proto::ParticipantLocation,
1245 ) -> Result<proto::Room> {
1246 self.transact(|tx| async {
1247 let mut tx = tx;
1248 let location_kind;
1249 let location_project_id;
1250 match location
1251 .variant
1252 .as_ref()
1253 .ok_or_else(|| anyhow!("invalid location"))?
1254 {
1255 proto::participant_location::Variant::SharedProject(project) => {
1256 location_kind = 0;
1257 location_project_id = Some(ProjectId::from_proto(project.id));
1258 }
1259 proto::participant_location::Variant::UnsharedProject(_) => {
1260 location_kind = 1;
1261 location_project_id = None;
1262 }
1263 proto::participant_location::Variant::External(_) => {
1264 location_kind = 2;
1265 location_project_id = None;
1266 }
1267 }
1268
1269 sqlx::query(
1270 "
1271 UPDATE room_participants
1272 SET location_kind = $1 AND location_project_id = $2
1273 WHERE room_id = $3 AND answering_connection_id = $4
1274 ",
1275 )
1276 .bind(location_kind)
1277 .bind(location_project_id)
1278 .bind(room_id)
1279 .bind(connection_id.0 as i32)
1280 .execute(&mut tx)
1281 .await?;
1282
1283 self.commit_room_transaction(room_id, tx).await
1284 })
1285 .await
1286 }
1287
1288 async fn commit_room_transaction(
1289 &self,
1290 room_id: RoomId,
1291 mut tx: sqlx::Transaction<'_, D>,
1292 ) -> Result<proto::Room> {
1293 sqlx::query(
1294 "
1295 UPDATE rooms
1296 SET version = version + 1
1297 WHERE id = $1
1298 ",
1299 )
1300 .bind(room_id)
1301 .execute(&mut tx)
1302 .await?;
1303 let room = self.get_room(room_id, &mut tx).await?;
1304 tx.commit().await?;
1305
1306 Ok(room)
1307 }
1308
1309 async fn get_room(
1310 &self,
1311 room_id: RoomId,
1312 tx: &mut sqlx::Transaction<'_, D>,
1313 ) -> Result<proto::Room> {
1314 let room: Room = sqlx::query_as(
1315 "
1316 SELECT *
1317 FROM rooms
1318 WHERE id = $1
1319 ",
1320 )
1321 .bind(room_id)
1322 .fetch_one(&mut *tx)
1323 .await?;
1324
1325 let mut db_participants =
1326 sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<ProjectId>, UserId, Option<ProjectId>)>(
1327 "
1328 SELECT user_id, answering_connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
1329 FROM room_participants
1330 WHERE room_id = $1
1331 ",
1332 )
1333 .bind(room_id)
1334 .fetch(&mut *tx);
1335
1336 let mut participants = Vec::new();
1337 let mut pending_participants = Vec::new();
1338 while let Some(participant) = db_participants.next().await {
1339 let (
1340 user_id,
1341 answering_connection_id,
1342 _location_kind,
1343 _location_project_id,
1344 calling_user_id,
1345 initial_project_id,
1346 ) = participant?;
1347 if let Some(answering_connection_id) = answering_connection_id {
1348 participants.push(proto::Participant {
1349 user_id: user_id.to_proto(),
1350 peer_id: answering_connection_id as u32,
1351 projects: Default::default(),
1352 location: Some(proto::ParticipantLocation {
1353 variant: Some(proto::participant_location::Variant::External(
1354 Default::default(),
1355 )),
1356 }),
1357 });
1358 } else {
1359 pending_participants.push(proto::PendingParticipant {
1360 user_id: user_id.to_proto(),
1361 calling_user_id: calling_user_id.to_proto(),
1362 initial_project_id: initial_project_id.map(|id| id.to_proto()),
1363 });
1364 }
1365 }
1366 drop(db_participants);
1367
1368 for participant in &mut participants {
1369 let mut entries = sqlx::query_as::<_, (ProjectId, String)>(
1370 "
1371 SELECT projects.id, worktrees.root_name
1372 FROM projects
1373 LEFT JOIN worktrees ON projects.id = worktrees.project_id
1374 WHERE room_id = $1 AND host_connection_id = $2
1375 ",
1376 )
1377 .bind(room_id)
1378 .bind(participant.peer_id as i32)
1379 .fetch(&mut *tx);
1380
1381 let mut projects = HashMap::default();
1382 while let Some(entry) = entries.next().await {
1383 let (project_id, worktree_root_name) = entry?;
1384 let participant_project =
1385 projects
1386 .entry(project_id)
1387 .or_insert(proto::ParticipantProject {
1388 id: project_id.to_proto(),
1389 worktree_root_names: Default::default(),
1390 });
1391 participant_project
1392 .worktree_root_names
1393 .push(worktree_root_name);
1394 }
1395
1396 participant.projects = projects.into_values().collect();
1397 }
1398 Ok(proto::Room {
1399 id: room.id.to_proto(),
1400 version: room.version as u64,
1401 live_kit_room: room.live_kit_room,
1402 participants,
1403 pending_participants,
1404 })
1405 }
1406
1407 // projects
1408
1409 pub async fn share_project(
1410 &self,
1411 expected_room_id: RoomId,
1412 connection_id: ConnectionId,
1413 worktrees: &[proto::WorktreeMetadata],
1414 ) -> Result<(ProjectId, proto::Room)> {
1415 self.transact(|mut tx| async move {
1416 let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
1417 "
1418 SELECT room_id, user_id
1419 FROM room_participants
1420 WHERE answering_connection_id = $1
1421 ",
1422 )
1423 .bind(connection_id.0 as i32)
1424 .fetch_one(&mut tx)
1425 .await?;
1426 if room_id != expected_room_id {
1427 return Err(anyhow!("shared project on unexpected room"))?;
1428 }
1429
1430 let project_id: ProjectId = sqlx::query_scalar(
1431 "
1432 INSERT INTO projects (room_id, host_user_id, host_connection_id)
1433 VALUES ($1, $2, $3)
1434 RETURNING id
1435 ",
1436 )
1437 .bind(room_id)
1438 .bind(user_id)
1439 .bind(connection_id.0 as i32)
1440 .fetch_one(&mut tx)
1441 .await?;
1442
1443 for worktree in worktrees {
1444 sqlx::query(
1445 "
1446 INSERT INTO worktrees (id, project_id, root_name)
1447 VALUES ($1, $2, $3)
1448 ",
1449 )
1450 .bind(worktree.id as i32)
1451 .bind(project_id)
1452 .bind(&worktree.root_name)
1453 .execute(&mut tx)
1454 .await?;
1455 }
1456
1457 sqlx::query(
1458 "
1459 INSERT INTO project_collaborators (
1460 project_id,
1461 connection_id,
1462 user_id,
1463 replica_id,
1464 is_host
1465 )
1466 VALUES ($1, $2, $3, $4, $5)
1467 ",
1468 )
1469 .bind(project_id)
1470 .bind(connection_id.0 as i32)
1471 .bind(user_id)
1472 .bind(0)
1473 .bind(true)
1474 .execute(&mut tx)
1475 .await?;
1476
1477 let room = self.commit_room_transaction(room_id, tx).await?;
1478 Ok((project_id, room))
1479 })
1480 .await
1481 }
1482
1483 pub async fn update_project(
1484 &self,
1485 project_id: ProjectId,
1486 connection_id: ConnectionId,
1487 worktrees: &[proto::WorktreeMetadata],
1488 ) -> Result<(proto::Room, Vec<ConnectionId>)> {
1489 self.transact(|mut tx| async move {
1490 let room_id: RoomId = sqlx::query_scalar(
1491 "
1492 SELECT room_id
1493 FROM projects
1494 WHERE id = $1 AND host_connection_id = $2
1495 ",
1496 )
1497 .bind(project_id)
1498 .bind(connection_id.0 as i32)
1499 .fetch_one(&mut tx)
1500 .await?;
1501
1502 for worktree in worktrees {
1503 sqlx::query(
1504 "
1505 INSERT INTO worktrees (project_id, id, root_name)
1506 VALUES ($1, $2, $3)
1507 ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name
1508 ",
1509 )
1510 .bind(project_id)
1511 .bind(worktree.id as i32)
1512 .bind(&worktree.root_name)
1513 .execute(&mut tx)
1514 .await?;
1515 }
1516
1517 let mut params = "?,".repeat(worktrees.len());
1518 if !worktrees.is_empty() {
1519 params.pop();
1520 }
1521 let query = format!(
1522 "
1523 DELETE FROM worktrees
1524 WHERE id NOT IN ({params})
1525 ",
1526 );
1527
1528 let mut query = sqlx::query(&query);
1529 for worktree in worktrees {
1530 query = query.bind(worktree.id as i32);
1531 }
1532 query.execute(&mut tx).await?;
1533
1534 let mut guest_connection_ids = Vec::new();
1535 {
1536 let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>(
1537 "
1538 SELECT connection_id
1539 FROM project_collaborators
1540 WHERE project_id = $1 AND is_host = FALSE
1541 ",
1542 )
1543 .fetch(&mut tx);
1544 while let Some(connection_id) = db_guest_connection_ids.next().await {
1545 guest_connection_ids.push(ConnectionId(connection_id? as u32));
1546 }
1547 }
1548
1549 let room = self.commit_room_transaction(room_id, tx).await?;
1550 Ok((room, guest_connection_ids))
1551 })
1552 .await
1553 }
1554
1555 pub async fn join_project(
1556 &self,
1557 project_id: ProjectId,
1558 connection_id: ConnectionId,
1559 ) -> Result<(Project, i32)> {
1560 self.transact(|mut tx| async move {
1561 let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
1562 "
1563 SELECT room_id, user_id
1564 FROM room_participants
1565 WHERE answering_connection_id = $1
1566 ",
1567 )
1568 .bind(connection_id.0 as i32)
1569 .fetch_one(&mut tx)
1570 .await?;
1571
1572 // Ensure project id was shared on this room.
1573 sqlx::query(
1574 "
1575 SELECT 1
1576 FROM projects
1577 WHERE project_id = $1 AND room_id = $2
1578 ",
1579 )
1580 .bind(project_id)
1581 .bind(room_id)
1582 .fetch_one(&mut tx)
1583 .await?;
1584
1585 let replica_ids = sqlx::query_scalar::<_, i32>(
1586 "
1587 SELECT replica_id
1588 FROM project_collaborators
1589 WHERE project_id = $1
1590 ",
1591 )
1592 .bind(project_id)
1593 .fetch_all(&mut tx)
1594 .await?;
1595 let replica_ids = HashSet::from_iter(replica_ids);
1596 let mut replica_id = 1;
1597 while replica_ids.contains(&replica_id) {
1598 replica_id += 1;
1599 }
1600
1601 sqlx::query(
1602 "
1603 INSERT INTO project_collaborators (
1604 project_id,
1605 connection_id,
1606 user_id,
1607 replica_id,
1608 is_host
1609 )
1610 VALUES ($1, $2, $3, $4, $5)
1611 ",
1612 )
1613 .bind(project_id)
1614 .bind(connection_id.0 as i32)
1615 .bind(user_id)
1616 .bind(replica_id)
1617 .bind(false)
1618 .execute(&mut tx)
1619 .await?;
1620
1621 tx.commit().await?;
1622 todo!()
1623 })
1624 .await
1625 // sqlx::query(
1626 // "
1627 // SELECT replica_id
1628 // FROM project_collaborators
1629 // WHERE project_id = $
1630 // ",
1631 // )
1632 // .bind(project_id)
1633 // .bind(connection_id.0 as i32)
1634 // .bind(user_id)
1635 // .bind(0)
1636 // .bind(true)
1637 // .execute(&mut tx)
1638 // .await?;
1639 // sqlx::query(
1640 // "
1641 // INSERT INTO project_collaborators (
1642 // project_id,
1643 // connection_id,
1644 // user_id,
1645 // replica_id,
1646 // is_host
1647 // )
1648 // VALUES ($1, $2, $3, $4, $5)
1649 // ",
1650 // )
1651 // .bind(project_id)
1652 // .bind(connection_id.0 as i32)
1653 // .bind(user_id)
1654 // .bind(0)
1655 // .bind(true)
1656 // .execute(&mut tx)
1657 // .await?;
1658 }
1659
    /// Stops sharing a project.
    ///
    /// NOTE: not implemented yet. The body is `todo!()`, so calling this
    /// method panics. The commented-out code below is an earlier
    /// implementation that marked the project as unregistered; it is kept for
    /// reference only.
    pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> {
        todo!()
        // test_support!(self, {
        //     sqlx::query(
        //         "
        //         UPDATE projects
        //         SET unregistered = TRUE
        //         WHERE id = $1
        //         ",
        //     )
        //     .bind(project_id)
        //     .execute(&self.pool)
        //     .await?;
        //     Ok(())
        // })
    }
1676
1677 // contacts
1678
1679 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
1680 self.transact(|mut tx| async move {
1681 let query = "
1682 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify, (room_participants.id IS NOT NULL) as busy
1683 FROM contacts
1684 LEFT JOIN room_participants ON room_participants.user_id = $1
1685 WHERE user_id_a = $1 OR user_id_b = $1;
1686 ";
1687
1688 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool, bool)>(query)
1689 .bind(user_id)
1690 .fetch(&mut tx);
1691
1692 let mut contacts = Vec::new();
1693 while let Some(row) = rows.next().await {
1694 let (user_id_a, user_id_b, a_to_b, accepted, should_notify, busy) = row?;
1695 if user_id_a == user_id {
1696 if accepted {
1697 contacts.push(Contact::Accepted {
1698 user_id: user_id_b,
1699 should_notify: should_notify && a_to_b,
1700 busy
1701 });
1702 } else if a_to_b {
1703 contacts.push(Contact::Outgoing { user_id: user_id_b })
1704 } else {
1705 contacts.push(Contact::Incoming {
1706 user_id: user_id_b,
1707 should_notify,
1708 });
1709 }
1710 } else if accepted {
1711 contacts.push(Contact::Accepted {
1712 user_id: user_id_a,
1713 should_notify: should_notify && !a_to_b,
1714 busy
1715 });
1716 } else if a_to_b {
1717 contacts.push(Contact::Incoming {
1718 user_id: user_id_a,
1719 should_notify,
1720 });
1721 } else {
1722 contacts.push(Contact::Outgoing { user_id: user_id_a });
1723 }
1724 }
1725
1726 contacts.sort_unstable_by_key(|contact| contact.user_id());
1727
1728 Ok(contacts)
1729 })
1730 .await
1731 }
1732
1733 pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
1734 self.transact(|mut tx| async move {
1735 Ok(sqlx::query_scalar::<_, i32>(
1736 "
1737 SELECT 1
1738 FROM room_participants
1739 WHERE room_participants.user_id = $1
1740 ",
1741 )
1742 .bind(user_id)
1743 .fetch_optional(&mut tx)
1744 .await?
1745 .is_some())
1746 })
1747 .await
1748 }
1749
1750 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
1751 self.transact(|mut tx| async move {
1752 let (id_a, id_b) = if user_id_1 < user_id_2 {
1753 (user_id_1, user_id_2)
1754 } else {
1755 (user_id_2, user_id_1)
1756 };
1757
1758 let query = "
1759 SELECT 1 FROM contacts
1760 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
1761 LIMIT 1
1762 ";
1763 Ok(sqlx::query_scalar::<_, i32>(query)
1764 .bind(id_a.0)
1765 .bind(id_b.0)
1766 .fetch_optional(&mut tx)
1767 .await?
1768 .is_some())
1769 })
1770 .await
1771 }
1772
1773 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
1774 self.transact(|mut tx| async move {
1775 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
1776 (sender_id, receiver_id, true)
1777 } else {
1778 (receiver_id, sender_id, false)
1779 };
1780 let query = "
1781 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
1782 VALUES ($1, $2, $3, FALSE, TRUE)
1783 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
1784 SET
1785 accepted = TRUE,
1786 should_notify = FALSE
1787 WHERE
1788 NOT contacts.accepted AND
1789 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
1790 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
1791 ";
1792 let result = sqlx::query(query)
1793 .bind(id_a.0)
1794 .bind(id_b.0)
1795 .bind(a_to_b)
1796 .execute(&mut tx)
1797 .await?;
1798
1799 if result.rows_affected() == 1 {
1800 tx.commit().await?;
1801 Ok(())
1802 } else {
1803 Err(anyhow!("contact already requested"))?
1804 }
1805 }).await
1806 }
1807
1808 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
1809 self.transact(|mut tx| async move {
1810 let (id_a, id_b) = if responder_id < requester_id {
1811 (responder_id, requester_id)
1812 } else {
1813 (requester_id, responder_id)
1814 };
1815 let query = "
1816 DELETE FROM contacts
1817 WHERE user_id_a = $1 AND user_id_b = $2;
1818 ";
1819 let result = sqlx::query(query)
1820 .bind(id_a.0)
1821 .bind(id_b.0)
1822 .execute(&mut tx)
1823 .await?;
1824
1825 if result.rows_affected() == 1 {
1826 tx.commit().await?;
1827 Ok(())
1828 } else {
1829 Err(anyhow!("no such contact"))?
1830 }
1831 })
1832 .await
1833 }
1834
1835 pub async fn dismiss_contact_notification(
1836 &self,
1837 user_id: UserId,
1838 contact_user_id: UserId,
1839 ) -> Result<()> {
1840 self.transact(|mut tx| async move {
1841 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
1842 (user_id, contact_user_id, true)
1843 } else {
1844 (contact_user_id, user_id, false)
1845 };
1846
1847 let query = "
1848 UPDATE contacts
1849 SET should_notify = FALSE
1850 WHERE
1851 user_id_a = $1 AND user_id_b = $2 AND
1852 (
1853 (a_to_b = $3 AND accepted) OR
1854 (a_to_b != $3 AND NOT accepted)
1855 );
1856 ";
1857
1858 let result = sqlx::query(query)
1859 .bind(id_a.0)
1860 .bind(id_b.0)
1861 .bind(a_to_b)
1862 .execute(&mut tx)
1863 .await?;
1864
1865 if result.rows_affected() == 0 {
1866 Err(anyhow!("no such contact request"))?
1867 } else {
1868 tx.commit().await?;
1869 Ok(())
1870 }
1871 })
1872 .await
1873 }
1874
1875 pub async fn respond_to_contact_request(
1876 &self,
1877 responder_id: UserId,
1878 requester_id: UserId,
1879 accept: bool,
1880 ) -> Result<()> {
1881 self.transact(|mut tx| async move {
1882 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
1883 (responder_id, requester_id, false)
1884 } else {
1885 (requester_id, responder_id, true)
1886 };
1887 let result = if accept {
1888 let query = "
1889 UPDATE contacts
1890 SET accepted = TRUE, should_notify = TRUE
1891 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
1892 ";
1893 sqlx::query(query)
1894 .bind(id_a.0)
1895 .bind(id_b.0)
1896 .bind(a_to_b)
1897 .execute(&mut tx)
1898 .await?
1899 } else {
1900 let query = "
1901 DELETE FROM contacts
1902 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
1903 ";
1904 sqlx::query(query)
1905 .bind(id_a.0)
1906 .bind(id_b.0)
1907 .bind(a_to_b)
1908 .execute(&mut tx)
1909 .await?
1910 };
1911 if result.rows_affected() == 1 {
1912 tx.commit().await?;
1913 Ok(())
1914 } else {
1915 Err(anyhow!("no such contact request"))?
1916 }
1917 })
1918 .await
1919 }
1920
1921 // access tokens
1922
1923 pub async fn create_access_token_hash(
1924 &self,
1925 user_id: UserId,
1926 access_token_hash: &str,
1927 max_access_token_count: usize,
1928 ) -> Result<()> {
1929 self.transact(|tx| async {
1930 let mut tx = tx;
1931 let insert_query = "
1932 INSERT INTO access_tokens (user_id, hash)
1933 VALUES ($1, $2);
1934 ";
1935 let cleanup_query = "
1936 DELETE FROM access_tokens
1937 WHERE id IN (
1938 SELECT id from access_tokens
1939 WHERE user_id = $1
1940 ORDER BY id DESC
1941 LIMIT 10000
1942 OFFSET $3
1943 )
1944 ";
1945
1946 sqlx::query(insert_query)
1947 .bind(user_id.0)
1948 .bind(access_token_hash)
1949 .execute(&mut tx)
1950 .await?;
1951 sqlx::query(cleanup_query)
1952 .bind(user_id.0)
1953 .bind(access_token_hash)
1954 .bind(max_access_token_count as i32)
1955 .execute(&mut tx)
1956 .await?;
1957 Ok(tx.commit().await?)
1958 })
1959 .await
1960 }
1961
1962 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1963 self.transact(|mut tx| async move {
1964 let query = "
1965 SELECT hash
1966 FROM access_tokens
1967 WHERE user_id = $1
1968 ORDER BY id DESC
1969 ";
1970 Ok(sqlx::query_scalar(query)
1971 .bind(user_id.0)
1972 .fetch_all(&mut tx)
1973 .await?)
1974 })
1975 .await
1976 }
1977
1978 async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
1979 where
1980 F: Send + Fn(sqlx::Transaction<'static, D>) -> Fut,
1981 Fut: Send + Future<Output = Result<T>>,
1982 {
1983 let body = async {
1984 loop {
1985 let tx = self.begin_transaction().await?;
1986 match f(tx).await {
1987 Ok(result) => return Ok(result),
1988 Err(error) => match error {
1989 Error::Database(error)
1990 if error
1991 .as_database_error()
1992 .and_then(|error| error.code())
1993 .as_deref()
1994 == Some("hey") =>
1995 {
1996 // Retry (don't break the loop)
1997 }
1998 error @ _ => return Err(error),
1999 },
2000 }
2001 }
2002 };
2003
2004 #[cfg(test)]
2005 {
2006 if let Some(background) = self.background.as_ref() {
2007 background.simulate_random_delay().await;
2008 }
2009
2010 let result = self.runtime.as_ref().unwrap().block_on(body);
2011
2012 if let Some(background) = self.background.as_ref() {
2013 background.simulate_random_delay().await;
2014 }
2015
2016 result
2017 }
2018
2019 #[cfg(not(test))]
2020 {
2021 body.await
2022 }
2023 }
2024}
2025
// Defines a newtype id named `$name` that wraps an `i32`.
//
// The generated type is transparent for both sqlx and serde, so it is stored
// and serialized as the plain integer. It also gets a `MAX` constant,
// `from_proto` / `to_proto` conversions to and from `u64`, and a `Display`
// impl that forwards to the wrapped value.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Converts from the `u64` form with an `as` cast; values that do
            // not fit in an `i32` are truncated.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Converts to the `u64` form with an `as` cast.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
2068
// Id of a user.
id_type!(UserId);
/// A user record; derives `FromRow` so it can be loaded directly from a query row.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    pub invite_code: Option<String>,
    pub invite_count: i32,
    pub connected_once: bool,
}
2081
// Id of a room.
id_type!(RoomId);
/// A row of the `rooms` table, as loaded by `get_room`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Room {
    pub id: RoomId,
    /// Counter incremented by `commit_room_transaction` on every room change.
    pub version: i32,
    pub live_kit_room: String,
}
2089
// Id of a shared project.
id_type!(ProjectId);
/// In-memory view of a shared project; this is the project type returned by
/// `join_project`.
pub struct Project {
    pub id: ProjectId,
    pub collaborators: Vec<ProjectCollaborator>,
    // Keyed by a `u64` — presumably the worktree id; confirm against callers.
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
2097
/// A collaborator of a project. The fields mirror the columns written to
/// `project_collaborators` by `share_project` and `join_project`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
pub struct ProjectCollaborator {
    pub project_id: ProjectId,
    pub connection_id: i32,
    pub user_id: UserId,
    /// `0` for the host (see `share_project`); guests get ids starting at `1`.
    pub replica_id: i32,
    pub is_host: bool,
}
2106
/// State of a single worktree inside a [`Project`].
#[derive(Default)]
pub struct Worktree {
    pub abs_path: PathBuf,
    pub root_name: String,
    pub visible: bool,
    // Keyed by a `u64` — presumably the entry id; confirm against callers.
    pub entries: BTreeMap<u64, proto::Entry>,
    // Keyed by path.
    pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
    pub scan_id: u64,
    pub is_complete: bool,
}
2117
/// A project affected by a participant leaving a room; values of this type
/// are collected in [`LeftRoom::left_projects`].
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub connection_ids: Vec<ConnectionId>,
}
2123
/// Result of leaving a room: the updated room, the affected projects keyed by
/// id, and the users whose calls were canceled.
pub struct LeftRoom {
    pub room: proto::Room,
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
2129
/// A contact relationship as seen by one user; produced by `get_contacts`.
/// In every variant `user_id` is the other user of the pair.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// The request has been accepted.
    Accepted {
        user_id: UserId,
        should_notify: bool,
        busy: bool,
    },
    /// A pending request sent by the viewing user.
    Outgoing {
        user_id: UserId,
    },
    /// A pending request sent to the viewing user.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
2145
2146impl Contact {
2147 pub fn user_id(&self) -> UserId {
2148 match self {
2149 Contact::Accepted { user_id, .. } => *user_id,
2150 Contact::Outgoing { user_id } => *user_id,
2151 Contact::Incoming { user_id, .. } => *user_id,
2152 }
2153 }
2154}
2155
/// An incoming contact request: who sent it and whether the recipient should
/// be notified.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
2161
/// Signup data; derives `Deserialize` so it can be parsed from serialized input.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
2172
/// Waitlist counts, in total and per platform. Every field is marked
/// `#[sqlx(default)]`, so a column missing from the query row falls back to
/// the field's default value.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
2186
/// An invite: an email address paired with its email confirmation code.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
2192
/// Parameters for creating a new user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
2199
/// Outcome of creating a new user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
2207
/// Generates a random invite code using `nanoid!` with a size argument of 16.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}
2211
/// Generates a random email confirmation code using `nanoid!` with a size
/// argument of 64.
fn random_email_confirmation_code() -> String {
    nanoid::nanoid!(64)
}
2215
2216#[cfg(test)]
2217pub use test::*;
2218
// Test-only database fixtures. Each fixture owns a `Db` together with the
// tokio runtime that `Db::transact` uses in test builds.
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A SQLite test database with a randomly generated name.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// A connection detached from the pool, held for the fixture's lifetime.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A Postgres test database with a randomly generated name.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        /// Connection URL of the database; passed to `teardown` on drop.
        pub url: String,
    }

    impl SqliteTestDb {
        /// Creates a randomly named `mode=memory` SQLite database, runs the
        /// migrations from `migrations.sqlite`, and attaches `background` and
        /// a current-thread tokio runtime to the `Db`.
        ///
        /// Panics if connecting or migrating fails.
        pub fn new(background: Arc<Background>) -> Self {
            let mut rng = StdRng::from_entropy();
            let url = format!("file:zed-test-{}?mode=memory", rng.gen::<u128>());
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(migrations_path.as_ref(), false).await.unwrap();
                // NOTE(review): presumably held so the in-memory database
                // outlives idle pool connections — confirm.
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the wrapped database handle.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates a randomly named database on the local Postgres server,
        /// runs the migrations from `migrations`, and attaches `background`
        /// and a current-thread tokio runtime to the `Db`.
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Held until this function returns, so database creation and
            // migration do not run concurrently within the process.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the wrapped database handle.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        // Takes the database handle out of the fixture and calls `teardown`
        // with the database URL.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            db.teardown(&self.url);
        }
    }
}