1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::{BTreeMap, HashMap, HashSet};
5use futures::{future::BoxFuture, FutureExt, StreamExt};
6use rpc::{proto, ConnectionId};
7use serde::{Deserialize, Serialize};
8use sqlx::{
9 migrate::{Migrate as _, Migration, MigrationSource},
10 types::Uuid,
11 FromRow,
12};
13use std::{future::Future, path::Path, time::Duration};
14use time::{OffsetDateTime, PrimitiveDateTime};
15
/// Database handle type used by the server: SQLite when compiled for tests.
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// Database handle type used by the server: Postgres outside of tests.
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;

/// A connection pool for database backend `D`, plus test-only scaffolding.
pub struct Db<D: sqlx::Database> {
    /// Pool that every query and transaction in this module runs against.
    pool: sqlx::Pool<D>,
    // Test-only background executor handle. It is not read anywhere in this
    // part of the file; presumably used by test helpers elsewhere (TODO confirm).
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    // Test-only runtime; `teardown` uses it to `block_on` its async cleanup.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
29
/// Backend-specific way of opening a transaction.
///
/// The generic `Db<D>` methods require `Self: BeginTransaction<Database = D>`,
/// so each supported backend supplies its own implementation (Postgres opts
/// into serializable isolation explicitly; SQLite needs no extra setup).
pub trait BeginTransaction: Send + Sync {
    /// The sqlx database backend the transaction runs on.
    type Database: sqlx::Database;

    /// Starts a new transaction on the underlying pool.
    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, Self::Database>>>;
}
35
36// In Postgres, serializable transactions are opt-in
37impl BeginTransaction for Db<sqlx::Postgres> {
38 type Database = sqlx::Postgres;
39
40 fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Postgres>>> {
41 async move {
42 let mut tx = self.pool.begin().await?;
43 sqlx::Executor::execute(&mut tx, "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;")
44 .await?;
45 Ok(tx)
46 }
47 .boxed()
48 }
49}
50
// SQLite transactions are already serializable, so a plain `BEGIN` suffices.
#[cfg(test)]
impl BeginTransaction for Db<sqlx::Sqlite> {
    type Database = sqlx::Sqlite;

    /// Opens a plain transaction on the pool with no extra configuration.
    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Sqlite>>> {
        async move {
            let transaction = self.pool.begin().await?;
            Ok(transaction)
        }
        .boxed()
    }
}
60
/// Backend-independent access to how many rows a statement affected.
///
/// Required of `D::QueryResult` by the generic `Db<D>` implementation.
pub trait RowsAffected {
    /// Number of rows affected by the executed statement.
    fn rows_affected(&self) -> u64;
}

#[cfg(test)]
impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    fn rows_affected(&self) -> u64 {
        // Relies on sqlx's inherent `SqliteQueryResult::rows_affected`, which
        // method resolution picks ahead of this trait method; without it this
        // call would recurse.
        self.rows_affected()
    }
}

impl RowsAffected for sqlx::postgres::PgQueryResult {
    fn rows_affected(&self) -> u64 {
        // Forwards to sqlx's inherent `PgQueryResult::rows_affected` (inherent
        // methods win over trait methods, so this does not recurse).
        self.rows_affected()
    }
}
77
/// SQLite backend used in tests.
///
/// Only some entry points have SQLite-compatible SQL; the rest panic with
/// `unimplemented!()`.
#[cfg(test)]
impl Db<sqlx::Sqlite> {
    /// Opens a SQLite pool for `url` with between 2 and `max_connections`
    /// connections. The database is created if missing, and shared cache mode
    /// is enabled.
    ///
    /// # Errors
    /// Returns an error if `url` is not a valid SQLite connection string or if
    /// the pool fails to connect.
    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
        use std::str::FromStr as _;
        // Report a malformed URL to the caller instead of panicking.
        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
            .create_if_missing(true)
            .shared_cache(true);
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(2)
            .max_connections(max_connections)
            .connect_with(options)
            .await?;
        Ok(Self {
            pool,
            background: None,
            runtime: None,
        })
    }

    /// Fetches the users whose ids appear in `ids`.
    ///
    /// The Postgres version binds an array and uses `= ANY ($1)`. Here the ids
    /// are bound as a JSON array and expanded with SQLite's `json_each`.
    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
        self.transact(|tx| async {
            let mut tx = tx;
            let query = "
                SELECT users.*
                FROM users
                WHERE users.id IN (SELECT value from json_each($1))
            ";
            Ok(sqlx::query_as(query)
                .bind(&serde_json::json!(ids))
                .fetch_all(&mut tx)
                .await?)
        })
        .await
    }

    /// Returns the `metrics_id` stored for user `id`.
    ///
    /// # Errors
    /// Fails if no user with that id exists.
    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
        self.transact(|mut tx| async move {
            let query = "
                SELECT metrics_id
                FROM users
                WHERE id = $1
            ";
            Ok(sqlx::query_scalar(query)
                .bind(id)
                .fetch_one(&mut tx)
                .await?)
        })
        .await
    }

    /// Inserts a user, or returns the existing one if `github_login` is taken.
    ///
    /// On a `github_login` conflict the row gets a no-op update, so that
    /// `RETURNING` still yields the existing id and metrics id. Unlike the
    /// Postgres version, `metrics_id` comes from the client as a freshly
    /// generated UUID v4 string.
    pub async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult> {
        self.transact(|mut tx| async {
            let query = "
                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id, metrics_id
            ";

            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
                .bind(email_address)
                .bind(&params.github_login)
                .bind(&params.github_user_id)
                .bind(admin)
                .bind(Uuid::new_v4().to_string())
                .fetch_one(&mut tx)
                .await?;
            tx.commit().await?;
            Ok(NewUserResult {
                user_id,
                metrics_id,
                signup_device_id: None,
                inviting_user_id: None,
            })
        })
        .await
    }

    /// Not supported on the SQLite test backend.
    ///
    /// # Panics
    /// Always.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend.
    ///
    /// # Panics
    /// Always.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend.
    ///
    /// # Panics
    /// Always.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend.
    ///
    /// # Panics
    /// Always.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend.
    ///
    /// # Panics
    /// Always.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
}
191
192impl Db<sqlx::Postgres> {
193 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
194 let pool = sqlx::postgres::PgPoolOptions::new()
195 .max_connections(max_connections)
196 .connect(url)
197 .await?;
198 Ok(Self {
199 pool,
200 #[cfg(test)]
201 background: None,
202 #[cfg(test)]
203 runtime: None,
204 })
205 }
206
207 #[cfg(test)]
208 pub fn teardown(&self, url: &str) {
209 self.runtime.as_ref().unwrap().block_on(async {
210 use util::ResultExt;
211 let query = "
212 SELECT pg_terminate_backend(pg_stat_activity.pid)
213 FROM pg_stat_activity
214 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
215 ";
216 sqlx::query(query).execute(&self.pool).await.log_err();
217 self.pool.close().await;
218 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
219 .await
220 .log_err();
221 })
222 }
223
224 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
225 self.transact(|tx| async {
226 let mut tx = tx;
227 let like_string = Self::fuzzy_like_string(name_query);
228 let query = "
229 SELECT users.*
230 FROM users
231 WHERE github_login ILIKE $1
232 ORDER BY github_login <-> $2
233 LIMIT $3
234 ";
235 Ok(sqlx::query_as(query)
236 .bind(like_string)
237 .bind(name_query)
238 .bind(limit as i32)
239 .fetch_all(&mut tx)
240 .await?)
241 })
242 .await
243 }
244
245 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
246 let ids = ids.iter().map(|id| id.0).collect::<Vec<_>>();
247 self.transact(|tx| async {
248 let mut tx = tx;
249 let query = "
250 SELECT users.*
251 FROM users
252 WHERE users.id = ANY ($1)
253 ";
254 Ok(sqlx::query_as(query).bind(&ids).fetch_all(&mut tx).await?)
255 })
256 .await
257 }
258
259 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
260 self.transact(|mut tx| async move {
261 let query = "
262 SELECT metrics_id::text
263 FROM users
264 WHERE id = $1
265 ";
266 Ok(sqlx::query_scalar(query)
267 .bind(id)
268 .fetch_one(&mut tx)
269 .await?)
270 })
271 .await
272 }
273
274 pub async fn create_user(
275 &self,
276 email_address: &str,
277 admin: bool,
278 params: NewUserParams,
279 ) -> Result<NewUserResult> {
280 self.transact(|mut tx| async {
281 let query = "
282 INSERT INTO users (email_address, github_login, github_user_id, admin)
283 VALUES ($1, $2, $3, $4)
284 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
285 RETURNING id, metrics_id::text
286 ";
287
288 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
289 .bind(email_address)
290 .bind(¶ms.github_login)
291 .bind(params.github_user_id)
292 .bind(admin)
293 .fetch_one(&mut tx)
294 .await?;
295 tx.commit().await?;
296
297 Ok(NewUserResult {
298 user_id,
299 metrics_id,
300 signup_device_id: None,
301 inviting_user_id: None,
302 })
303 })
304 .await
305 }
306
307 pub async fn create_user_from_invite(
308 &self,
309 invite: &Invite,
310 user: NewUserParams,
311 ) -> Result<Option<NewUserResult>> {
312 self.transact(|mut tx| async {
313 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
314 i32,
315 Option<UserId>,
316 Option<UserId>,
317 Option<String>,
318 ) = sqlx::query_as(
319 "
320 SELECT id, user_id, inviting_user_id, device_id
321 FROM signups
322 WHERE
323 email_address = $1 AND
324 email_confirmation_code = $2
325 ",
326 )
327 .bind(&invite.email_address)
328 .bind(&invite.email_confirmation_code)
329 .fetch_optional(&mut tx)
330 .await?
331 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
332
333 if existing_user_id.is_some() {
334 return Ok(None);
335 }
336
337 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
338 "
339 INSERT INTO users
340 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
341 VALUES
342 ($1, $2, $3, FALSE, $4, $5)
343 ON CONFLICT (github_login) DO UPDATE SET
344 email_address = excluded.email_address,
345 github_user_id = excluded.github_user_id,
346 admin = excluded.admin
347 RETURNING id, metrics_id::text
348 ",
349 )
350 .bind(&invite.email_address)
351 .bind(&user.github_login)
352 .bind(&user.github_user_id)
353 .bind(&user.invite_count)
354 .bind(random_invite_code())
355 .fetch_one(&mut tx)
356 .await?;
357
358 sqlx::query(
359 "
360 UPDATE signups
361 SET user_id = $1
362 WHERE id = $2
363 ",
364 )
365 .bind(&user_id)
366 .bind(&signup_id)
367 .execute(&mut tx)
368 .await?;
369
370 if let Some(inviting_user_id) = inviting_user_id {
371 let id: Option<UserId> = sqlx::query_scalar(
372 "
373 UPDATE users
374 SET invite_count = invite_count - 1
375 WHERE id = $1 AND invite_count > 0
376 RETURNING id
377 ",
378 )
379 .bind(&inviting_user_id)
380 .fetch_optional(&mut tx)
381 .await?;
382
383 if id.is_none() {
384 Err(Error::Http(
385 StatusCode::UNAUTHORIZED,
386 "no invites remaining".to_string(),
387 ))?;
388 }
389
390 sqlx::query(
391 "
392 INSERT INTO contacts
393 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
394 VALUES
395 ($1, $2, TRUE, TRUE, TRUE)
396 ON CONFLICT DO NOTHING
397 ",
398 )
399 .bind(inviting_user_id)
400 .bind(user_id)
401 .execute(&mut tx)
402 .await?;
403 }
404
405 tx.commit().await?;
406 Ok(Some(NewUserResult {
407 user_id,
408 metrics_id,
409 inviting_user_id,
410 signup_device_id,
411 }))
412 })
413 .await
414 }
415
416 pub async fn create_signup(&self, signup: Signup) -> Result<()> {
417 self.transact(|mut tx| async {
418 sqlx::query(
419 "
420 INSERT INTO signups
421 (
422 email_address,
423 email_confirmation_code,
424 email_confirmation_sent,
425 platform_linux,
426 platform_mac,
427 platform_windows,
428 platform_unknown,
429 editor_features,
430 programming_languages,
431 device_id
432 )
433 VALUES
434 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
435 RETURNING id
436 ",
437 )
438 .bind(&signup.email_address)
439 .bind(&random_email_confirmation_code())
440 .bind(&signup.platform_linux)
441 .bind(&signup.platform_mac)
442 .bind(&signup.platform_windows)
443 .bind(&signup.editor_features)
444 .bind(&signup.programming_languages)
445 .bind(&signup.device_id)
446 .execute(&mut tx)
447 .await?;
448 tx.commit().await?;
449 Ok(())
450 })
451 .await
452 }
453
454 pub async fn create_invite_from_code(
455 &self,
456 code: &str,
457 email_address: &str,
458 device_id: Option<&str>,
459 ) -> Result<Invite> {
460 self.transact(|mut tx| async {
461 let existing_user: Option<UserId> = sqlx::query_scalar(
462 "
463 SELECT id
464 FROM users
465 WHERE email_address = $1
466 ",
467 )
468 .bind(email_address)
469 .fetch_optional(&mut tx)
470 .await?;
471 if existing_user.is_some() {
472 Err(anyhow!("email address is already in use"))?;
473 }
474
475 let row: Option<(UserId, i32)> = sqlx::query_as(
476 "
477 SELECT id, invite_count
478 FROM users
479 WHERE invite_code = $1
480 ",
481 )
482 .bind(code)
483 .fetch_optional(&mut tx)
484 .await?;
485
486 let (inviter_id, invite_count) = match row {
487 Some(row) => row,
488 None => Err(Error::Http(
489 StatusCode::NOT_FOUND,
490 "invite code not found".to_string(),
491 ))?,
492 };
493
494 if invite_count == 0 {
495 Err(Error::Http(
496 StatusCode::UNAUTHORIZED,
497 "no invites remaining".to_string(),
498 ))?;
499 }
500
501 let email_confirmation_code: String = sqlx::query_scalar(
502 "
503 INSERT INTO signups
504 (
505 email_address,
506 email_confirmation_code,
507 email_confirmation_sent,
508 inviting_user_id,
509 platform_linux,
510 platform_mac,
511 platform_windows,
512 platform_unknown,
513 device_id
514 )
515 VALUES
516 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
517 ON CONFLICT (email_address)
518 DO UPDATE SET
519 inviting_user_id = excluded.inviting_user_id
520 RETURNING email_confirmation_code
521 ",
522 )
523 .bind(&email_address)
524 .bind(&random_email_confirmation_code())
525 .bind(&inviter_id)
526 .bind(&device_id)
527 .fetch_one(&mut tx)
528 .await?;
529
530 tx.commit().await?;
531
532 Ok(Invite {
533 email_address: email_address.into(),
534 email_confirmation_code,
535 })
536 })
537 .await
538 }
539
540 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
541 self.transact(|mut tx| async {
542 let emails = invites
543 .iter()
544 .map(|s| s.email_address.as_str())
545 .collect::<Vec<_>>();
546 sqlx::query(
547 "
548 UPDATE signups
549 SET email_confirmation_sent = TRUE
550 WHERE email_address = ANY ($1)
551 ",
552 )
553 .bind(&emails)
554 .execute(&mut tx)
555 .await?;
556 tx.commit().await?;
557 Ok(())
558 })
559 .await
560 }
561}
562
563impl<D> Db<D>
564where
565 Self: BeginTransaction<Database = D>,
566 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
567 D::Connection: sqlx::migrate::Migrate,
568 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
569 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
570 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
571 D::QueryResult: RowsAffected,
572 String: sqlx::Type<D>,
573 i32: sqlx::Type<D>,
574 i64: sqlx::Type<D>,
575 bool: sqlx::Type<D>,
576 str: sqlx::Type<D>,
577 Uuid: sqlx::Type<D>,
578 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
579 OffsetDateTime: sqlx::Type<D>,
580 PrimitiveDateTime: sqlx::Type<D>,
581 usize: sqlx::ColumnIndex<D::Row>,
582 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
583 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
584 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
585 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
586 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
587 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
588 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
589 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
590 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
591 for<'a> Option<ProjectId>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
592 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
593 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
594 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
595{
596 pub async fn migrate(
597 &self,
598 migrations_path: &Path,
599 ignore_checksum_mismatch: bool,
600 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
601 let migrations = MigrationSource::resolve(migrations_path)
602 .await
603 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
604
605 let mut conn = self.pool.acquire().await?;
606
607 conn.ensure_migrations_table().await?;
608 let applied_migrations: HashMap<_, _> = conn
609 .list_applied_migrations()
610 .await?
611 .into_iter()
612 .map(|m| (m.version, m))
613 .collect();
614
615 let mut new_migrations = Vec::new();
616 for migration in migrations {
617 match applied_migrations.get(&migration.version) {
618 Some(applied_migration) => {
619 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
620 {
621 Err(anyhow!(
622 "checksum mismatch for applied migration {}",
623 migration.description
624 ))?;
625 }
626 }
627 None => {
628 let elapsed = conn.apply(&migration).await?;
629 new_migrations.push((migration, elapsed));
630 }
631 }
632 }
633
634 Ok(new_migrations)
635 }
636
637 pub fn fuzzy_like_string(string: &str) -> String {
638 let mut result = String::with_capacity(string.len() * 2 + 1);
639 for c in string.chars() {
640 if c.is_alphanumeric() {
641 result.push('%');
642 result.push(c);
643 }
644 }
645 result.push('%');
646 result
647 }
648
649 // users
650
651 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
652 self.transact(|tx| async {
653 let mut tx = tx;
654 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
655 Ok(sqlx::query_as(query)
656 .bind(limit as i32)
657 .bind((page * limit) as i32)
658 .fetch_all(&mut tx)
659 .await?)
660 })
661 .await
662 }
663
664 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
665 self.transact(|tx| async {
666 let mut tx = tx;
667 let query = "
668 SELECT users.*
669 FROM users
670 WHERE id = $1
671 LIMIT 1
672 ";
673 Ok(sqlx::query_as(query)
674 .bind(&id)
675 .fetch_optional(&mut tx)
676 .await?)
677 })
678 .await
679 }
680
681 pub async fn get_users_with_no_invites(
682 &self,
683 invited_by_another_user: bool,
684 ) -> Result<Vec<User>> {
685 self.transact(|tx| async {
686 let mut tx = tx;
687 let query = format!(
688 "
689 SELECT users.*
690 FROM users
691 WHERE invite_count = 0
692 AND inviter_id IS{} NULL
693 ",
694 if invited_by_another_user { " NOT" } else { "" }
695 );
696
697 Ok(sqlx::query_as(&query).fetch_all(&mut tx).await?)
698 })
699 .await
700 }
701
702 pub async fn get_user_by_github_account(
703 &self,
704 github_login: &str,
705 github_user_id: Option<i32>,
706 ) -> Result<Option<User>> {
707 self.transact(|tx| async {
708 let mut tx = tx;
709 if let Some(github_user_id) = github_user_id {
710 let mut user = sqlx::query_as::<_, User>(
711 "
712 UPDATE users
713 SET github_login = $1
714 WHERE github_user_id = $2
715 RETURNING *
716 ",
717 )
718 .bind(github_login)
719 .bind(github_user_id)
720 .fetch_optional(&mut tx)
721 .await?;
722
723 if user.is_none() {
724 user = sqlx::query_as::<_, User>(
725 "
726 UPDATE users
727 SET github_user_id = $1
728 WHERE github_login = $2
729 RETURNING *
730 ",
731 )
732 .bind(github_user_id)
733 .bind(github_login)
734 .fetch_optional(&mut tx)
735 .await?;
736 }
737
738 Ok(user)
739 } else {
740 let user = sqlx::query_as(
741 "
742 SELECT * FROM users
743 WHERE github_login = $1
744 LIMIT 1
745 ",
746 )
747 .bind(github_login)
748 .fetch_optional(&mut tx)
749 .await?;
750 Ok(user)
751 }
752 })
753 .await
754 }
755
756 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
757 self.transact(|mut tx| async {
758 let query = "UPDATE users SET admin = $1 WHERE id = $2";
759 sqlx::query(query)
760 .bind(is_admin)
761 .bind(id.0)
762 .execute(&mut tx)
763 .await?;
764 tx.commit().await?;
765 Ok(())
766 })
767 .await
768 }
769
770 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
771 self.transact(|mut tx| async move {
772 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
773 sqlx::query(query)
774 .bind(connected_once)
775 .bind(id.0)
776 .execute(&mut tx)
777 .await?;
778 tx.commit().await?;
779 Ok(())
780 })
781 .await
782 }
783
784 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
785 self.transact(|mut tx| async move {
786 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
787 sqlx::query(query)
788 .bind(id.0)
789 .execute(&mut tx)
790 .await
791 .map(drop)?;
792 let query = "DELETE FROM users WHERE id = $1;";
793 sqlx::query(query).bind(id.0).execute(&mut tx).await?;
794 tx.commit().await?;
795 Ok(())
796 })
797 .await
798 }
799
800 // signups
801
802 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
803 self.transact(|mut tx| async move {
804 Ok(sqlx::query_as(
805 "
806 SELECT
807 COUNT(*) as count,
808 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
809 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
810 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
811 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
812 FROM (
813 SELECT *
814 FROM signups
815 WHERE
816 NOT email_confirmation_sent
817 ) AS unsent
818 ",
819 )
820 .fetch_one(&mut tx)
821 .await?)
822 })
823 .await
824 }
825
826 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
827 self.transact(|mut tx| async move {
828 Ok(sqlx::query_as(
829 "
830 SELECT
831 email_address, email_confirmation_code
832 FROM signups
833 WHERE
834 NOT email_confirmation_sent AND
835 (platform_mac OR platform_unknown)
836 LIMIT $1
837 ",
838 )
839 .bind(count as i32)
840 .fetch_all(&mut tx)
841 .await?)
842 })
843 .await
844 }
845
846 // invite codes
847
848 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
849 self.transact(|mut tx| async move {
850 if count > 0 {
851 sqlx::query(
852 "
853 UPDATE users
854 SET invite_code = $1
855 WHERE id = $2 AND invite_code IS NULL
856 ",
857 )
858 .bind(random_invite_code())
859 .bind(id)
860 .execute(&mut tx)
861 .await?;
862 }
863
864 sqlx::query(
865 "
866 UPDATE users
867 SET invite_count = $1
868 WHERE id = $2
869 ",
870 )
871 .bind(count as i32)
872 .bind(id)
873 .execute(&mut tx)
874 .await?;
875 tx.commit().await?;
876 Ok(())
877 })
878 .await
879 }
880
881 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
882 self.transact(|mut tx| async move {
883 let result: Option<(String, i32)> = sqlx::query_as(
884 "
885 SELECT invite_code, invite_count
886 FROM users
887 WHERE id = $1 AND invite_code IS NOT NULL
888 ",
889 )
890 .bind(id)
891 .fetch_optional(&mut tx)
892 .await?;
893 if let Some((code, count)) = result {
894 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
895 } else {
896 Ok(None)
897 }
898 })
899 .await
900 }
901
902 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
903 self.transact(|tx| async {
904 let mut tx = tx;
905 sqlx::query_as(
906 "
907 SELECT *
908 FROM users
909 WHERE invite_code = $1
910 ",
911 )
912 .bind(code)
913 .fetch_optional(&mut tx)
914 .await?
915 .ok_or_else(|| {
916 Error::Http(
917 StatusCode::NOT_FOUND,
918 "that invite code does not exist".to_string(),
919 )
920 })
921 })
922 .await
923 }
924
925 pub async fn create_room(
926 &self,
927 user_id: UserId,
928 connection_id: ConnectionId,
929 ) -> Result<proto::Room> {
930 self.transact(|mut tx| async move {
931 let live_kit_room = nanoid::nanoid!(30);
932 let room_id = sqlx::query_scalar(
933 "
934 INSERT INTO rooms (live_kit_room)
935 VALUES ($1)
936 RETURNING id
937 ",
938 )
939 .bind(&live_kit_room)
940 .fetch_one(&mut tx)
941 .await
942 .map(RoomId)?;
943
944 sqlx::query(
945 "
946 INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id)
947 VALUES ($1, $2, $3, $4, $5)
948 ",
949 )
950 .bind(room_id)
951 .bind(user_id)
952 .bind(connection_id.0 as i32)
953 .bind(user_id)
954 .bind(connection_id.0 as i32)
955 .execute(&mut tx)
956 .await?;
957
958 let room = self.get_room(room_id, &mut tx).await?;
959 tx.commit().await?;
960 Ok(room)
961 }).await
962 }
963
964 pub async fn call(
965 &self,
966 room_id: RoomId,
967 calling_user_id: UserId,
968 calling_connection_id: ConnectionId,
969 called_user_id: UserId,
970 initial_project_id: Option<ProjectId>,
971 ) -> Result<(proto::Room, proto::IncomingCall)> {
972 self.transact(|mut tx| async move {
973 sqlx::query(
974 "
975 INSERT INTO room_participants (room_id, user_id, calling_user_id, calling_connection_id, initial_project_id)
976 VALUES ($1, $2, $3, $4, $5)
977 ",
978 )
979 .bind(room_id)
980 .bind(called_user_id)
981 .bind(calling_user_id)
982 .bind(calling_connection_id.0 as i32)
983 .bind(initial_project_id)
984 .execute(&mut tx)
985 .await?;
986
987 let room = self.get_room(room_id, &mut tx).await?;
988 tx.commit().await?;
989
990 let incoming_call = Self::build_incoming_call(&room, called_user_id)
991 .ok_or_else(|| anyhow!("failed to build incoming call"))?;
992 Ok((room, incoming_call))
993 }).await
994 }
995
996 pub async fn incoming_call_for_user(
997 &self,
998 user_id: UserId,
999 ) -> Result<Option<proto::IncomingCall>> {
1000 self.transact(|mut tx| async move {
1001 let room_id = sqlx::query_scalar::<_, RoomId>(
1002 "
1003 SELECT room_id
1004 FROM room_participants
1005 WHERE user_id = $1 AND answering_connection_id IS NULL
1006 ",
1007 )
1008 .bind(user_id)
1009 .fetch_optional(&mut tx)
1010 .await?;
1011
1012 if let Some(room_id) = room_id {
1013 let room = self.get_room(room_id, &mut tx).await?;
1014 Ok(Self::build_incoming_call(&room, user_id))
1015 } else {
1016 Ok(None)
1017 }
1018 })
1019 .await
1020 }
1021
1022 fn build_incoming_call(
1023 room: &proto::Room,
1024 called_user_id: UserId,
1025 ) -> Option<proto::IncomingCall> {
1026 let pending_participant = room
1027 .pending_participants
1028 .iter()
1029 .find(|participant| participant.user_id == called_user_id.to_proto())?;
1030
1031 Some(proto::IncomingCall {
1032 room_id: room.id,
1033 calling_user_id: pending_participant.calling_user_id,
1034 participant_user_ids: room
1035 .participants
1036 .iter()
1037 .map(|participant| participant.user_id)
1038 .collect(),
1039 initial_project: room.participants.iter().find_map(|participant| {
1040 let initial_project_id = pending_participant.initial_project_id?;
1041 participant
1042 .projects
1043 .iter()
1044 .find(|project| project.id == initial_project_id)
1045 .cloned()
1046 }),
1047 })
1048 }
1049
1050 pub async fn call_failed(
1051 &self,
1052 room_id: RoomId,
1053 called_user_id: UserId,
1054 ) -> Result<proto::Room> {
1055 self.transact(|mut tx| async move {
1056 sqlx::query(
1057 "
1058 DELETE FROM room_participants
1059 WHERE room_id = $1 AND user_id = $2
1060 ",
1061 )
1062 .bind(room_id)
1063 .bind(called_user_id)
1064 .execute(&mut tx)
1065 .await?;
1066
1067 let room = self.get_room(room_id, &mut tx).await?;
1068 tx.commit().await?;
1069 Ok(room)
1070 })
1071 .await
1072 }
1073
1074 pub async fn decline_call(
1075 &self,
1076 expected_room_id: Option<RoomId>,
1077 user_id: UserId,
1078 ) -> Result<proto::Room> {
1079 self.transact(|mut tx| async move {
1080 let room_id = sqlx::query_scalar(
1081 "
1082 DELETE FROM room_participants
1083 WHERE user_id = $1 AND answering_connection_id IS NULL
1084 RETURNING room_id
1085 ",
1086 )
1087 .bind(user_id)
1088 .fetch_one(&mut tx)
1089 .await?;
1090 if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1091 return Err(anyhow!("declining call on unexpected room"))?;
1092 }
1093
1094 let room = self.get_room(room_id, &mut tx).await?;
1095 tx.commit().await?;
1096 Ok(room)
1097 })
1098 .await
1099 }
1100
1101 pub async fn cancel_call(
1102 &self,
1103 expected_room_id: Option<RoomId>,
1104 calling_connection_id: ConnectionId,
1105 called_user_id: UserId,
1106 ) -> Result<proto::Room> {
1107 self.transact(|mut tx| async move {
1108 let room_id = sqlx::query_scalar(
1109 "
1110 DELETE FROM room_participants
1111 WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL
1112 RETURNING room_id
1113 ",
1114 )
1115 .bind(called_user_id)
1116 .bind(calling_connection_id.0 as i32)
1117 .fetch_one(&mut tx)
1118 .await?;
1119 if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1120 return Err(anyhow!("canceling call on unexpected room"))?;
1121 }
1122
1123 let room = self.get_room(room_id, &mut tx).await?;
1124 tx.commit().await?;
1125 Ok(room)
1126 }).await
1127 }
1128
1129 pub async fn join_room(
1130 &self,
1131 room_id: RoomId,
1132 user_id: UserId,
1133 connection_id: ConnectionId,
1134 ) -> Result<proto::Room> {
1135 self.transact(|mut tx| async move {
1136 sqlx::query(
1137 "
1138 UPDATE room_participants
1139 SET answering_connection_id = $1
1140 WHERE room_id = $2 AND user_id = $3
1141 RETURNING 1
1142 ",
1143 )
1144 .bind(connection_id.0 as i32)
1145 .bind(room_id)
1146 .bind(user_id)
1147 .fetch_one(&mut tx)
1148 .await?;
1149
1150 let room = self.get_room(room_id, &mut tx).await?;
1151 tx.commit().await?;
1152 Ok(room)
1153 })
1154 .await
1155 }
1156
1157 pub async fn leave_room(&self, connection_id: ConnectionId) -> Result<Option<LeftRoom>> {
1158 self.transact(|mut tx| async move {
1159 // Leave room.
1160 let room_id = sqlx::query_scalar::<_, RoomId>(
1161 "
1162 DELETE FROM room_participants
1163 WHERE answering_connection_id = $1
1164 RETURNING room_id
1165 ",
1166 )
1167 .bind(connection_id.0 as i32)
1168 .fetch_optional(&mut tx)
1169 .await?;
1170
1171 if let Some(room_id) = room_id {
1172 // Cancel pending calls initiated by the leaving user.
1173 let canceled_calls_to_user_ids: Vec<UserId> = sqlx::query_scalar(
1174 "
1175 DELETE FROM room_participants
1176 WHERE calling_connection_id = $1 AND answering_connection_id IS NULL
1177 RETURNING user_id
1178 ",
1179 )
1180 .bind(connection_id.0 as i32)
1181 .fetch_all(&mut tx)
1182 .await?;
1183
1184 let project_ids = sqlx::query_scalar::<_, ProjectId>(
1185 "
1186 SELECT project_id
1187 FROM project_collaborators
1188 WHERE connection_id = $1
1189 ",
1190 )
1191 .bind(connection_id.0 as i32)
1192 .fetch_all(&mut tx)
1193 .await?;
1194
1195 // Leave projects.
1196 let mut left_projects = HashMap::default();
1197 if !project_ids.is_empty() {
1198 let mut params = "?,".repeat(project_ids.len());
1199 params.pop();
1200 let query = format!(
1201 "
1202 SELECT *
1203 FROM project_collaborators
1204 WHERE project_id IN ({params})
1205 "
1206 );
1207 let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query);
1208 for project_id in project_ids {
1209 query = query.bind(project_id);
1210 }
1211
1212 let mut project_collaborators = query.fetch(&mut tx);
1213 while let Some(collaborator) = project_collaborators.next().await {
1214 let collaborator = collaborator?;
1215 let left_project =
1216 left_projects
1217 .entry(collaborator.project_id)
1218 .or_insert(LeftProject {
1219 id: collaborator.project_id,
1220 host_user_id: Default::default(),
1221 connection_ids: Default::default(),
1222 host_connection_id: Default::default(),
1223 });
1224
1225 let collaborator_connection_id =
1226 ConnectionId(collaborator.connection_id as u32);
1227 if collaborator_connection_id != connection_id {
1228 left_project.connection_ids.push(collaborator_connection_id);
1229 }
1230
1231 if collaborator.is_host {
1232 left_project.host_user_id = collaborator.user_id;
1233 left_project.host_connection_id =
1234 ConnectionId(collaborator.connection_id as u32);
1235 }
1236 }
1237 }
1238 sqlx::query(
1239 "
1240 DELETE FROM project_collaborators
1241 WHERE connection_id = $1
1242 ",
1243 )
1244 .bind(connection_id.0 as i32)
1245 .execute(&mut tx)
1246 .await?;
1247
1248 // Unshare projects.
1249 sqlx::query(
1250 "
1251 DELETE FROM projects
1252 WHERE room_id = $1 AND host_connection_id = $2
1253 ",
1254 )
1255 .bind(room_id)
1256 .bind(connection_id.0 as i32)
1257 .execute(&mut tx)
1258 .await?;
1259
1260 let room = self.get_room(room_id, &mut tx).await?;
1261 tx.commit().await?;
1262
1263 Ok(Some(LeftRoom {
1264 room,
1265 left_projects,
1266 canceled_calls_to_user_ids,
1267 }))
1268 } else {
1269 Ok(None)
1270 }
1271 })
1272 .await
1273 }
1274
1275 pub async fn update_room_participant_location(
1276 &self,
1277 room_id: RoomId,
1278 connection_id: ConnectionId,
1279 location: proto::ParticipantLocation,
1280 ) -> Result<proto::Room> {
1281 self.transact(|tx| async {
1282 let mut tx = tx;
1283 let location_kind;
1284 let location_project_id;
1285 match location
1286 .variant
1287 .as_ref()
1288 .ok_or_else(|| anyhow!("invalid location"))?
1289 {
1290 proto::participant_location::Variant::SharedProject(project) => {
1291 location_kind = 0;
1292 location_project_id = Some(ProjectId::from_proto(project.id));
1293 }
1294 proto::participant_location::Variant::UnsharedProject(_) => {
1295 location_kind = 1;
1296 location_project_id = None;
1297 }
1298 proto::participant_location::Variant::External(_) => {
1299 location_kind = 2;
1300 location_project_id = None;
1301 }
1302 }
1303
1304 sqlx::query(
1305 "
1306 UPDATE room_participants
1307 SET location_kind = $1, location_project_id = $2
1308 WHERE room_id = $3 AND answering_connection_id = $4
1309 RETURNING 1
1310 ",
1311 )
1312 .bind(location_kind)
1313 .bind(location_project_id)
1314 .bind(room_id)
1315 .bind(connection_id.0 as i32)
1316 .fetch_one(&mut tx)
1317 .await?;
1318
1319 let room = self.get_room(room_id, &mut tx).await?;
1320 tx.commit().await?;
1321 Ok(room)
1322 })
1323 .await
1324 }
1325
1326 async fn get_guest_connection_ids(
1327 &self,
1328 project_id: ProjectId,
1329 tx: &mut sqlx::Transaction<'_, D>,
1330 ) -> Result<Vec<ConnectionId>> {
1331 let mut guest_connection_ids = Vec::new();
1332 let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>(
1333 "
1334 SELECT connection_id
1335 FROM project_collaborators
1336 WHERE project_id = $1 AND is_host = FALSE
1337 ",
1338 )
1339 .bind(project_id)
1340 .fetch(tx);
1341 while let Some(connection_id) = db_guest_connection_ids.next().await {
1342 guest_connection_ids.push(ConnectionId(connection_id? as u32));
1343 }
1344 Ok(guest_connection_ids)
1345 }
1346
1347 async fn get_room(
1348 &self,
1349 room_id: RoomId,
1350 tx: &mut sqlx::Transaction<'_, D>,
1351 ) -> Result<proto::Room> {
1352 let room: Room = sqlx::query_as(
1353 "
1354 SELECT *
1355 FROM rooms
1356 WHERE id = $1
1357 ",
1358 )
1359 .bind(room_id)
1360 .fetch_one(&mut *tx)
1361 .await?;
1362
1363 let mut db_participants =
1364 sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<ProjectId>, UserId, Option<ProjectId>)>(
1365 "
1366 SELECT user_id, answering_connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
1367 FROM room_participants
1368 WHERE room_id = $1
1369 ",
1370 )
1371 .bind(room_id)
1372 .fetch(&mut *tx);
1373
1374 let mut participants = HashMap::default();
1375 let mut pending_participants = Vec::new();
1376 while let Some(participant) = db_participants.next().await {
1377 let (
1378 user_id,
1379 answering_connection_id,
1380 location_kind,
1381 location_project_id,
1382 calling_user_id,
1383 initial_project_id,
1384 ) = participant?;
1385 if let Some(answering_connection_id) = answering_connection_id {
1386 let location = match (location_kind, location_project_id) {
1387 (Some(0), Some(project_id)) => {
1388 Some(proto::participant_location::Variant::SharedProject(
1389 proto::participant_location::SharedProject {
1390 id: project_id.to_proto(),
1391 },
1392 ))
1393 }
1394 (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
1395 Default::default(),
1396 )),
1397 _ => Some(proto::participant_location::Variant::External(
1398 Default::default(),
1399 )),
1400 };
1401 participants.insert(
1402 answering_connection_id,
1403 proto::Participant {
1404 user_id: user_id.to_proto(),
1405 peer_id: answering_connection_id as u32,
1406 projects: Default::default(),
1407 location: Some(proto::ParticipantLocation { variant: location }),
1408 },
1409 );
1410 } else {
1411 pending_participants.push(proto::PendingParticipant {
1412 user_id: user_id.to_proto(),
1413 calling_user_id: calling_user_id.to_proto(),
1414 initial_project_id: initial_project_id.map(|id| id.to_proto()),
1415 });
1416 }
1417 }
1418 drop(db_participants);
1419
1420 let mut rows = sqlx::query_as::<_, (i32, ProjectId, Option<String>)>(
1421 "
1422 SELECT host_connection_id, projects.id, worktrees.root_name
1423 FROM projects
1424 LEFT JOIN worktrees ON projects.id = worktrees.project_id
1425 WHERE room_id = $1
1426 ",
1427 )
1428 .bind(room_id)
1429 .fetch(&mut *tx);
1430
1431 while let Some(row) = rows.next().await {
1432 let (connection_id, project_id, worktree_root_name) = row?;
1433 if let Some(participant) = participants.get_mut(&connection_id) {
1434 let project = if let Some(project) = participant
1435 .projects
1436 .iter_mut()
1437 .find(|project| project.id == project_id.to_proto())
1438 {
1439 project
1440 } else {
1441 participant.projects.push(proto::ParticipantProject {
1442 id: project_id.to_proto(),
1443 worktree_root_names: Default::default(),
1444 });
1445 participant.projects.last_mut().unwrap()
1446 };
1447 project.worktree_root_names.extend(worktree_root_name);
1448 }
1449 }
1450
1451 Ok(proto::Room {
1452 id: room.id.to_proto(),
1453 live_kit_room: room.live_kit_room,
1454 participants: participants.into_values().collect(),
1455 pending_participants,
1456 })
1457 }
1458
1459 // projects
1460
1461 pub async fn project_count_excluding_admins(&self) -> Result<usize> {
1462 self.transact(|mut tx| async move {
1463 Ok(sqlx::query_scalar::<_, i32>(
1464 "
1465 SELECT COUNT(*)
1466 FROM projects, users
1467 WHERE projects.host_user_id = users.id AND users.admin IS FALSE
1468 ",
1469 )
1470 .fetch_one(&mut tx)
1471 .await? as usize)
1472 })
1473 .await
1474 }
1475
1476 pub async fn share_project(
1477 &self,
1478 expected_room_id: RoomId,
1479 connection_id: ConnectionId,
1480 worktrees: &[proto::WorktreeMetadata],
1481 ) -> Result<(ProjectId, proto::Room)> {
1482 self.transact(|mut tx| async move {
1483 let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
1484 "
1485 SELECT room_id, user_id
1486 FROM room_participants
1487 WHERE answering_connection_id = $1
1488 ",
1489 )
1490 .bind(connection_id.0 as i32)
1491 .fetch_one(&mut tx)
1492 .await?;
1493 if room_id != expected_room_id {
1494 return Err(anyhow!("shared project on unexpected room"))?;
1495 }
1496
1497 let project_id: ProjectId = sqlx::query_scalar(
1498 "
1499 INSERT INTO projects (room_id, host_user_id, host_connection_id)
1500 VALUES ($1, $2, $3)
1501 RETURNING id
1502 ",
1503 )
1504 .bind(room_id)
1505 .bind(user_id)
1506 .bind(connection_id.0 as i32)
1507 .fetch_one(&mut tx)
1508 .await?;
1509
1510 if !worktrees.is_empty() {
1511 let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
1512 params.pop();
1513 let query = format!(
1514 "
1515 INSERT INTO worktrees (
1516 project_id,
1517 id,
1518 root_name,
1519 abs_path,
1520 visible,
1521 scan_id,
1522 is_complete
1523 )
1524 VALUES {params}
1525 "
1526 );
1527
1528 let mut query = sqlx::query(&query);
1529 for worktree in worktrees {
1530 query = query
1531 .bind(project_id)
1532 .bind(worktree.id as i32)
1533 .bind(&worktree.root_name)
1534 .bind(&worktree.abs_path)
1535 .bind(worktree.visible)
1536 .bind(0)
1537 .bind(false);
1538 }
1539 query.execute(&mut tx).await?;
1540 }
1541
1542 sqlx::query(
1543 "
1544 INSERT INTO project_collaborators (
1545 project_id,
1546 connection_id,
1547 user_id,
1548 replica_id,
1549 is_host
1550 )
1551 VALUES ($1, $2, $3, $4, $5)
1552 ",
1553 )
1554 .bind(project_id)
1555 .bind(connection_id.0 as i32)
1556 .bind(user_id)
1557 .bind(0)
1558 .bind(true)
1559 .execute(&mut tx)
1560 .await?;
1561
1562 let room = self.get_room(room_id, &mut tx).await?;
1563 tx.commit().await?;
1564
1565 Ok((project_id, room))
1566 })
1567 .await
1568 }
1569
1570 pub async fn unshare_project(
1571 &self,
1572 project_id: ProjectId,
1573 connection_id: ConnectionId,
1574 ) -> Result<(proto::Room, Vec<ConnectionId>)> {
1575 self.transact(|mut tx| async move {
1576 let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1577 let room_id: RoomId = sqlx::query_scalar(
1578 "
1579 DELETE FROM projects
1580 WHERE id = $1 AND host_connection_id = $2
1581 RETURNING room_id
1582 ",
1583 )
1584 .bind(project_id)
1585 .bind(connection_id.0 as i32)
1586 .fetch_one(&mut tx)
1587 .await?;
1588 let room = self.get_room(room_id, &mut tx).await?;
1589 tx.commit().await?;
1590
1591 Ok((room, guest_connection_ids))
1592 })
1593 .await
1594 }
1595
1596 pub async fn update_project(
1597 &self,
1598 project_id: ProjectId,
1599 connection_id: ConnectionId,
1600 worktrees: &[proto::WorktreeMetadata],
1601 ) -> Result<(proto::Room, Vec<ConnectionId>)> {
1602 self.transact(|mut tx| async move {
1603 let room_id: RoomId = sqlx::query_scalar(
1604 "
1605 SELECT room_id
1606 FROM projects
1607 WHERE id = $1 AND host_connection_id = $2
1608 ",
1609 )
1610 .bind(project_id)
1611 .bind(connection_id.0 as i32)
1612 .fetch_one(&mut tx)
1613 .await?;
1614
1615 if !worktrees.is_empty() {
1616 let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
1617 params.pop();
1618 let query = format!(
1619 "
1620 INSERT INTO worktrees (
1621 project_id,
1622 id,
1623 root_name,
1624 abs_path,
1625 visible,
1626 scan_id,
1627 is_complete
1628 )
1629 VALUES {params}
1630 ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name
1631 "
1632 );
1633
1634 let mut query = sqlx::query(&query);
1635 for worktree in worktrees {
1636 query = query
1637 .bind(project_id)
1638 .bind(worktree.id as i32)
1639 .bind(&worktree.root_name)
1640 .bind(&worktree.abs_path)
1641 .bind(worktree.visible)
1642 .bind(0)
1643 .bind(false)
1644 }
1645 query.execute(&mut tx).await?;
1646 }
1647
1648 let mut params = "?,".repeat(worktrees.len());
1649 if !worktrees.is_empty() {
1650 params.pop();
1651 }
1652 let query = format!(
1653 "
1654 DELETE FROM worktrees
1655 WHERE project_id = ? AND id NOT IN ({params})
1656 ",
1657 );
1658
1659 let mut query = sqlx::query(&query).bind(project_id);
1660 for worktree in worktrees {
1661 query = query.bind(WorktreeId(worktree.id as i32));
1662 }
1663 query.execute(&mut tx).await?;
1664
1665 let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1666 let room = self.get_room(room_id, &mut tx).await?;
1667 tx.commit().await?;
1668
1669 Ok((room, guest_connection_ids))
1670 })
1671 .await
1672 }
1673
1674 pub async fn update_worktree(
1675 &self,
1676 update: &proto::UpdateWorktree,
1677 connection_id: ConnectionId,
1678 ) -> Result<Vec<ConnectionId>> {
1679 self.transact(|mut tx| async move {
1680 let project_id = ProjectId::from_proto(update.project_id);
1681 let worktree_id = WorktreeId::from_proto(update.worktree_id);
1682
1683 // Ensure the update comes from the host.
1684 sqlx::query(
1685 "
1686 SELECT 1
1687 FROM projects
1688 WHERE id = $1 AND host_connection_id = $2
1689 ",
1690 )
1691 .bind(project_id)
1692 .bind(connection_id.0 as i32)
1693 .fetch_one(&mut tx)
1694 .await?;
1695
1696 // Update metadata.
1697 sqlx::query(
1698 "
1699 UPDATE worktrees
1700 SET
1701 root_name = $1,
1702 scan_id = $2,
1703 is_complete = $3,
1704 abs_path = $4
1705 WHERE project_id = $5 AND id = $6
1706 RETURNING 1
1707 ",
1708 )
1709 .bind(&update.root_name)
1710 .bind(update.scan_id as i64)
1711 .bind(update.is_last_update)
1712 .bind(&update.abs_path)
1713 .bind(project_id)
1714 .bind(worktree_id)
1715 .fetch_one(&mut tx)
1716 .await?;
1717
1718 if !update.updated_entries.is_empty() {
1719 let mut params =
1720 "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),".repeat(update.updated_entries.len());
1721 params.pop();
1722
1723 let query = format!(
1724 "
1725 INSERT INTO worktree_entries (
1726 project_id,
1727 worktree_id,
1728 id,
1729 is_dir,
1730 path,
1731 inode,
1732 mtime_seconds,
1733 mtime_nanos,
1734 is_symlink,
1735 is_ignored
1736 )
1737 VALUES {params}
1738 ON CONFLICT (project_id, worktree_id, id) DO UPDATE SET
1739 is_dir = excluded.is_dir,
1740 path = excluded.path,
1741 inode = excluded.inode,
1742 mtime_seconds = excluded.mtime_seconds,
1743 mtime_nanos = excluded.mtime_nanos,
1744 is_symlink = excluded.is_symlink,
1745 is_ignored = excluded.is_ignored
1746 "
1747 );
1748 let mut query = sqlx::query(&query);
1749 for entry in &update.updated_entries {
1750 let mtime = entry.mtime.clone().unwrap_or_default();
1751 query = query
1752 .bind(project_id)
1753 .bind(worktree_id)
1754 .bind(entry.id as i64)
1755 .bind(entry.is_dir)
1756 .bind(&entry.path)
1757 .bind(entry.inode as i64)
1758 .bind(mtime.seconds as i64)
1759 .bind(mtime.nanos as i32)
1760 .bind(entry.is_symlink)
1761 .bind(entry.is_ignored);
1762 }
1763 query.execute(&mut tx).await?;
1764 }
1765
1766 if !update.removed_entries.is_empty() {
1767 let mut params = "?,".repeat(update.removed_entries.len());
1768 params.pop();
1769 let query = format!(
1770 "
1771 DELETE FROM worktree_entries
1772 WHERE project_id = ? AND worktree_id = ? AND id IN ({params})
1773 "
1774 );
1775
1776 let mut query = sqlx::query(&query).bind(project_id).bind(worktree_id);
1777 for entry_id in &update.removed_entries {
1778 query = query.bind(*entry_id as i64);
1779 }
1780 query.execute(&mut tx).await?;
1781 }
1782
1783 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1784 tx.commit().await?;
1785 Ok(connection_ids)
1786 })
1787 .await
1788 }
1789
1790 pub async fn update_diagnostic_summary(
1791 &self,
1792 update: &proto::UpdateDiagnosticSummary,
1793 connection_id: ConnectionId,
1794 ) -> Result<Vec<ConnectionId>> {
1795 self.transact(|mut tx| async {
1796 let project_id = ProjectId::from_proto(update.project_id);
1797 let worktree_id = WorktreeId::from_proto(update.worktree_id);
1798 let summary = update
1799 .summary
1800 .as_ref()
1801 .ok_or_else(|| anyhow!("invalid summary"))?;
1802
1803 // Ensure the update comes from the host.
1804 sqlx::query(
1805 "
1806 SELECT 1
1807 FROM projects
1808 WHERE id = $1 AND host_connection_id = $2
1809 ",
1810 )
1811 .bind(project_id)
1812 .bind(connection_id.0 as i32)
1813 .fetch_one(&mut tx)
1814 .await?;
1815
1816 // Update summary.
1817 sqlx::query(
1818 "
1819 INSERT INTO worktree_diagnostic_summaries (
1820 project_id,
1821 worktree_id,
1822 path,
1823 language_server_id,
1824 error_count,
1825 warning_count
1826 )
1827 VALUES ($1, $2, $3, $4, $5, $6)
1828 ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET
1829 language_server_id = excluded.language_server_id,
1830 error_count = excluded.error_count,
1831 warning_count = excluded.warning_count
1832 ",
1833 )
1834 .bind(project_id)
1835 .bind(worktree_id)
1836 .bind(&summary.path)
1837 .bind(summary.language_server_id as i64)
1838 .bind(summary.error_count as i32)
1839 .bind(summary.warning_count as i32)
1840 .execute(&mut tx)
1841 .await?;
1842
1843 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1844 tx.commit().await?;
1845 Ok(connection_ids)
1846 })
1847 .await
1848 }
1849
1850 pub async fn start_language_server(
1851 &self,
1852 update: &proto::StartLanguageServer,
1853 connection_id: ConnectionId,
1854 ) -> Result<Vec<ConnectionId>> {
1855 self.transact(|mut tx| async {
1856 let project_id = ProjectId::from_proto(update.project_id);
1857 let server = update
1858 .server
1859 .as_ref()
1860 .ok_or_else(|| anyhow!("invalid language server"))?;
1861
1862 // Ensure the update comes from the host.
1863 sqlx::query(
1864 "
1865 SELECT 1
1866 FROM projects
1867 WHERE id = $1 AND host_connection_id = $2
1868 ",
1869 )
1870 .bind(project_id)
1871 .bind(connection_id.0 as i32)
1872 .fetch_one(&mut tx)
1873 .await?;
1874
1875 // Add the newly-started language server.
1876 sqlx::query(
1877 "
1878 INSERT INTO language_servers (project_id, id, name)
1879 VALUES ($1, $2, $3)
1880 ON CONFLICT (project_id, id) DO UPDATE SET
1881 name = excluded.name
1882 ",
1883 )
1884 .bind(project_id)
1885 .bind(server.id as i64)
1886 .bind(&server.name)
1887 .execute(&mut tx)
1888 .await?;
1889
1890 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1891 tx.commit().await?;
1892 Ok(connection_ids)
1893 })
1894 .await
1895 }
1896
/// Adds `connection_id` as a guest collaborator of `project_id` and loads the
/// project's current state.
///
/// The connection must be an answered participant of the room the project was
/// shared in. The new collaborator gets the lowest replica id, starting at 1,
/// that no existing collaborator of the project already uses.
///
/// Returns the project, together with the assigned replica id. The project
/// includes:
/// - all collaborators, including the new one;
/// - its worktrees, each with its entries and diagnostic summaries;
/// - its language servers.
///
/// # Errors
/// Fails if the connection is not an answered room participant, if the
/// project is not in that participant's room, or on any database error.
pub async fn join_project(
    &self,
    project_id: ProjectId,
    connection_id: ConnectionId,
) -> Result<(Project, ReplicaId)> {
    self.transact(|mut tx| async move {
        // Resolve the joining connection to its room and user.
        let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
            "
            SELECT room_id, user_id
            FROM room_participants
            WHERE answering_connection_id = $1
            ",
        )
        .bind(connection_id.0 as i32)
        .fetch_one(&mut tx)
        .await?;

        // Ensure project id was shared on this room.
        sqlx::query(
            "
            SELECT 1
            FROM projects
            WHERE id = $1 AND room_id = $2
            ",
        )
        .bind(project_id)
        .bind(room_id)
        .fetch_one(&mut tx)
        .await?;

        // Pick the smallest replica id >= 1 not taken by an existing collaborator.
        let mut collaborators = sqlx::query_as::<_, ProjectCollaborator>(
            "
            SELECT *
            FROM project_collaborators
            WHERE project_id = $1
            ",
        )
        .bind(project_id)
        .fetch_all(&mut tx)
        .await?;
        let replica_ids = collaborators
            .iter()
            .map(|c| c.replica_id)
            .collect::<HashSet<_>>();
        let mut replica_id = ReplicaId(1);
        while replica_ids.contains(&replica_id) {
            replica_id.0 += 1;
        }
        let new_collaborator = ProjectCollaborator {
            project_id,
            connection_id: connection_id.0 as i32,
            user_id,
            replica_id,
            is_host: false,
        };

        sqlx::query(
            "
            INSERT INTO project_collaborators (
                project_id,
                connection_id,
                user_id,
                replica_id,
                is_host
            )
            VALUES ($1, $2, $3, $4, $5)
            ",
        )
        .bind(new_collaborator.project_id)
        .bind(new_collaborator.connection_id)
        .bind(new_collaborator.user_id)
        .bind(new_collaborator.replica_id)
        .bind(new_collaborator.is_host)
        .execute(&mut tx)
        .await?;
        collaborators.push(new_collaborator);

        // Load worktree metadata, keyed (and therefore ordered) by worktree id.
        let worktree_rows = sqlx::query_as::<_, WorktreeRow>(
            "
            SELECT *
            FROM worktrees
            WHERE project_id = $1
            ",
        )
        .bind(project_id)
        .fetch_all(&mut tx)
        .await?;
        let mut worktrees = worktree_rows
            .into_iter()
            .map(|worktree_row| {
                (
                    worktree_row.id,
                    Worktree {
                        id: worktree_row.id,
                        abs_path: worktree_row.abs_path,
                        root_name: worktree_row.root_name,
                        visible: worktree_row.visible,
                        entries: Default::default(),
                        diagnostic_summaries: Default::default(),
                        scan_id: worktree_row.scan_id as u64,
                        is_complete: worktree_row.is_complete,
                    },
                )
            })
            .collect::<BTreeMap<_, _>>();

        // Populate worktree entries.
        // Each stream lives in its own scope so its borrow of `tx` ends before the next query.
        // Rows whose worktree_id has no loaded worktree are skipped.
        {
            let mut entries = sqlx::query_as::<_, WorktreeEntry>(
                "
                SELECT *
                FROM worktree_entries
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch(&mut tx);
            while let Some(entry) = entries.next().await {
                let entry = entry?;
                if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) {
                    worktree.entries.push(proto::Entry {
                        id: entry.id as u64,
                        is_dir: entry.is_dir,
                        path: entry.path,
                        inode: entry.inode as u64,
                        mtime: Some(proto::Timestamp {
                            seconds: entry.mtime_seconds as u64,
                            nanos: entry.mtime_nanos as u32,
                        }),
                        is_symlink: entry.is_symlink,
                        is_ignored: entry.is_ignored,
                    });
                }
            }
        }

        // Populate worktree diagnostic summaries.
        {
            let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>(
                "
                SELECT *
                FROM worktree_diagnostic_summaries
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch(&mut tx);
            while let Some(summary) = summaries.next().await {
                let summary = summary?;
                if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) {
                    worktree
                        .diagnostic_summaries
                        .push(proto::DiagnosticSummary {
                            path: summary.path,
                            language_server_id: summary.language_server_id as u64,
                            error_count: summary.error_count as u32,
                            warning_count: summary.warning_count as u32,
                        });
                }
            }
        }

        // Populate language servers.
        let language_servers = sqlx::query_as::<_, LanguageServer>(
            "
            SELECT *
            FROM language_servers
            WHERE project_id = $1
            ",
        )
        .bind(project_id)
        .fetch_all(&mut tx)
        .await?;

        tx.commit().await?;
        Ok((
            Project {
                collaborators,
                worktrees,
                language_servers: language_servers
                    .into_iter()
                    .map(|language_server| proto::LanguageServer {
                        id: language_server.id.to_proto(),
                        name: language_server.name,
                    })
                    .collect(),
            },
            replica_id as ReplicaId,
        ))
    })
    .await
}
2089
2090 pub async fn leave_project(
2091 &self,
2092 project_id: ProjectId,
2093 connection_id: ConnectionId,
2094 ) -> Result<LeftProject> {
2095 self.transact(|mut tx| async move {
2096 let result = sqlx::query(
2097 "
2098 DELETE FROM project_collaborators
2099 WHERE project_id = $1 AND connection_id = $2
2100 ",
2101 )
2102 .bind(project_id)
2103 .bind(connection_id.0 as i32)
2104 .execute(&mut tx)
2105 .await?;
2106
2107 if result.rows_affected() == 0 {
2108 Err(anyhow!("not a collaborator on this project"))?;
2109 }
2110
2111 let connection_ids = sqlx::query_scalar::<_, i32>(
2112 "
2113 SELECT connection_id
2114 FROM project_collaborators
2115 WHERE project_id = $1
2116 ",
2117 )
2118 .bind(project_id)
2119 .fetch_all(&mut tx)
2120 .await?
2121 .into_iter()
2122 .map(|id| ConnectionId(id as u32))
2123 .collect();
2124
2125 let (host_user_id, host_connection_id) = sqlx::query_as::<_, (i32, i32)>(
2126 "
2127 SELECT host_user_id, host_connection_id
2128 FROM projects
2129 WHERE id = $1
2130 ",
2131 )
2132 .bind(project_id)
2133 .fetch_one(&mut tx)
2134 .await?;
2135
2136 tx.commit().await?;
2137
2138 Ok(LeftProject {
2139 id: project_id,
2140 host_user_id: UserId(host_user_id),
2141 host_connection_id: ConnectionId(host_connection_id as u32),
2142 connection_ids,
2143 })
2144 })
2145 .await
2146 }
2147
2148 pub async fn project_collaborators(
2149 &self,
2150 project_id: ProjectId,
2151 connection_id: ConnectionId,
2152 ) -> Result<Vec<ProjectCollaborator>> {
2153 self.transact(|mut tx| async move {
2154 let collaborators = sqlx::query_as::<_, ProjectCollaborator>(
2155 "
2156 SELECT *
2157 FROM project_collaborators
2158 WHERE project_id = $1
2159 ",
2160 )
2161 .bind(project_id)
2162 .fetch_all(&mut tx)
2163 .await?;
2164
2165 if collaborators
2166 .iter()
2167 .any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
2168 {
2169 Ok(collaborators)
2170 } else {
2171 Err(anyhow!("no such project"))?
2172 }
2173 })
2174 .await
2175 }
2176
2177 pub async fn project_connection_ids(
2178 &self,
2179 project_id: ProjectId,
2180 connection_id: ConnectionId,
2181 ) -> Result<HashSet<ConnectionId>> {
2182 self.transact(|mut tx| async move {
2183 let connection_ids = sqlx::query_scalar::<_, i32>(
2184 "
2185 SELECT connection_id
2186 FROM project_collaborators
2187 WHERE project_id = $1
2188 ",
2189 )
2190 .bind(project_id)
2191 .fetch_all(&mut tx)
2192 .await?;
2193
2194 if connection_ids.contains(&(connection_id.0 as i32)) {
2195 Ok(connection_ids
2196 .into_iter()
2197 .map(|connection_id| ConnectionId(connection_id as u32))
2198 .collect())
2199 } else {
2200 Err(anyhow!("no such project"))?
2201 }
2202 })
2203 .await
2204 }
2205
2206 // contacts
2207
2208 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
2209 self.transact(|mut tx| async move {
2210 let query = "
2211 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify, (room_participants.id IS NOT NULL) as busy
2212 FROM contacts
2213 LEFT JOIN room_participants ON room_participants.user_id = $1
2214 WHERE user_id_a = $1 OR user_id_b = $1;
2215 ";
2216
2217 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool, bool)>(query)
2218 .bind(user_id)
2219 .fetch(&mut tx);
2220
2221 let mut contacts = Vec::new();
2222 while let Some(row) = rows.next().await {
2223 let (user_id_a, user_id_b, a_to_b, accepted, should_notify, busy) = row?;
2224 if user_id_a == user_id {
2225 if accepted {
2226 contacts.push(Contact::Accepted {
2227 user_id: user_id_b,
2228 should_notify: should_notify && a_to_b,
2229 busy
2230 });
2231 } else if a_to_b {
2232 contacts.push(Contact::Outgoing { user_id: user_id_b })
2233 } else {
2234 contacts.push(Contact::Incoming {
2235 user_id: user_id_b,
2236 should_notify,
2237 });
2238 }
2239 } else if accepted {
2240 contacts.push(Contact::Accepted {
2241 user_id: user_id_a,
2242 should_notify: should_notify && !a_to_b,
2243 busy
2244 });
2245 } else if a_to_b {
2246 contacts.push(Contact::Incoming {
2247 user_id: user_id_a,
2248 should_notify,
2249 });
2250 } else {
2251 contacts.push(Contact::Outgoing { user_id: user_id_a });
2252 }
2253 }
2254
2255 contacts.sort_unstable_by_key(|contact| contact.user_id());
2256
2257 Ok(contacts)
2258 })
2259 .await
2260 }
2261
2262 pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
2263 self.transact(|mut tx| async move {
2264 Ok(sqlx::query_scalar::<_, i32>(
2265 "
2266 SELECT 1
2267 FROM room_participants
2268 WHERE room_participants.user_id = $1
2269 ",
2270 )
2271 .bind(user_id)
2272 .fetch_optional(&mut tx)
2273 .await?
2274 .is_some())
2275 })
2276 .await
2277 }
2278
2279 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
2280 self.transact(|mut tx| async move {
2281 let (id_a, id_b) = if user_id_1 < user_id_2 {
2282 (user_id_1, user_id_2)
2283 } else {
2284 (user_id_2, user_id_1)
2285 };
2286
2287 let query = "
2288 SELECT 1 FROM contacts
2289 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
2290 LIMIT 1
2291 ";
2292 Ok(sqlx::query_scalar::<_, i32>(query)
2293 .bind(id_a.0)
2294 .bind(id_b.0)
2295 .fetch_optional(&mut tx)
2296 .await?
2297 .is_some())
2298 })
2299 .await
2300 }
2301
2302 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
2303 self.transact(|mut tx| async move {
2304 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
2305 (sender_id, receiver_id, true)
2306 } else {
2307 (receiver_id, sender_id, false)
2308 };
2309 let query = "
2310 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
2311 VALUES ($1, $2, $3, FALSE, TRUE)
2312 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
2313 SET
2314 accepted = TRUE,
2315 should_notify = FALSE
2316 WHERE
2317 NOT contacts.accepted AND
2318 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
2319 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
2320 ";
2321 let result = sqlx::query(query)
2322 .bind(id_a.0)
2323 .bind(id_b.0)
2324 .bind(a_to_b)
2325 .execute(&mut tx)
2326 .await?;
2327
2328 if result.rows_affected() == 1 {
2329 tx.commit().await?;
2330 Ok(())
2331 } else {
2332 Err(anyhow!("contact already requested"))?
2333 }
2334 }).await
2335 }
2336
2337 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2338 self.transact(|mut tx| async move {
2339 let (id_a, id_b) = if responder_id < requester_id {
2340 (responder_id, requester_id)
2341 } else {
2342 (requester_id, responder_id)
2343 };
2344 let query = "
2345 DELETE FROM contacts
2346 WHERE user_id_a = $1 AND user_id_b = $2;
2347 ";
2348 let result = sqlx::query(query)
2349 .bind(id_a.0)
2350 .bind(id_b.0)
2351 .execute(&mut tx)
2352 .await?;
2353
2354 if result.rows_affected() == 1 {
2355 tx.commit().await?;
2356 Ok(())
2357 } else {
2358 Err(anyhow!("no such contact"))?
2359 }
2360 })
2361 .await
2362 }
2363
2364 pub async fn dismiss_contact_notification(
2365 &self,
2366 user_id: UserId,
2367 contact_user_id: UserId,
2368 ) -> Result<()> {
2369 self.transact(|mut tx| async move {
2370 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
2371 (user_id, contact_user_id, true)
2372 } else {
2373 (contact_user_id, user_id, false)
2374 };
2375
2376 let query = "
2377 UPDATE contacts
2378 SET should_notify = FALSE
2379 WHERE
2380 user_id_a = $1 AND user_id_b = $2 AND
2381 (
2382 (a_to_b = $3 AND accepted) OR
2383 (a_to_b != $3 AND NOT accepted)
2384 );
2385 ";
2386
2387 let result = sqlx::query(query)
2388 .bind(id_a.0)
2389 .bind(id_b.0)
2390 .bind(a_to_b)
2391 .execute(&mut tx)
2392 .await?;
2393
2394 if result.rows_affected() == 0 {
2395 Err(anyhow!("no such contact request"))?
2396 } else {
2397 tx.commit().await?;
2398 Ok(())
2399 }
2400 })
2401 .await
2402 }
2403
2404 pub async fn respond_to_contact_request(
2405 &self,
2406 responder_id: UserId,
2407 requester_id: UserId,
2408 accept: bool,
2409 ) -> Result<()> {
2410 self.transact(|mut tx| async move {
2411 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
2412 (responder_id, requester_id, false)
2413 } else {
2414 (requester_id, responder_id, true)
2415 };
2416 let result = if accept {
2417 let query = "
2418 UPDATE contacts
2419 SET accepted = TRUE, should_notify = TRUE
2420 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
2421 ";
2422 sqlx::query(query)
2423 .bind(id_a.0)
2424 .bind(id_b.0)
2425 .bind(a_to_b)
2426 .execute(&mut tx)
2427 .await?
2428 } else {
2429 let query = "
2430 DELETE FROM contacts
2431 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
2432 ";
2433 sqlx::query(query)
2434 .bind(id_a.0)
2435 .bind(id_b.0)
2436 .bind(a_to_b)
2437 .execute(&mut tx)
2438 .await?
2439 };
2440 if result.rows_affected() == 1 {
2441 tx.commit().await?;
2442 Ok(())
2443 } else {
2444 Err(anyhow!("no such contact request"))?
2445 }
2446 })
2447 .await
2448 }
2449
2450 // access tokens
2451
2452 pub async fn create_access_token_hash(
2453 &self,
2454 user_id: UserId,
2455 access_token_hash: &str,
2456 max_access_token_count: usize,
2457 ) -> Result<()> {
2458 self.transact(|tx| async {
2459 let mut tx = tx;
2460 let insert_query = "
2461 INSERT INTO access_tokens (user_id, hash)
2462 VALUES ($1, $2);
2463 ";
2464 let cleanup_query = "
2465 DELETE FROM access_tokens
2466 WHERE id IN (
2467 SELECT id from access_tokens
2468 WHERE user_id = $1
2469 ORDER BY id DESC
2470 LIMIT 10000
2471 OFFSET $3
2472 )
2473 ";
2474
2475 sqlx::query(insert_query)
2476 .bind(user_id.0)
2477 .bind(access_token_hash)
2478 .execute(&mut tx)
2479 .await?;
2480 sqlx::query(cleanup_query)
2481 .bind(user_id.0)
2482 .bind(access_token_hash)
2483 .bind(max_access_token_count as i32)
2484 .execute(&mut tx)
2485 .await?;
2486 Ok(tx.commit().await?)
2487 })
2488 .await
2489 }
2490
2491 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
2492 self.transact(|mut tx| async move {
2493 let query = "
2494 SELECT hash
2495 FROM access_tokens
2496 WHERE user_id = $1
2497 ORDER BY id DESC
2498 ";
2499 Ok(sqlx::query_scalar(query)
2500 .bind(user_id.0)
2501 .fetch_all(&mut tx)
2502 .await?)
2503 })
2504 .await
2505 }
2506
/// Runs `f` with a freshly begun transaction (see `BeginTransaction`).
///
/// If `f` fails with a database error whose SQLSTATE code is `40001`
/// (serialization failure), the whole closure is retried with a new
/// transaction. Any other error is returned unchanged.
///
/// `f` receives ownership of the transaction and is responsible for
/// committing it.
/// NOTE(review): there is no retry limit, so a persistently conflicting
/// transaction loops forever.
///
/// Under `cfg(test)`, the body is driven by `block_on` on `self.runtime`,
/// which panics if `runtime` is `None`. When `self.background` is set, the run
/// is preceded and followed by a simulated random delay.
async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
where
    F: Send + Fn(sqlx::Transaction<'static, D>) -> Fut,
    Fut: Send + Future<Output = Result<T>>,
{
    let body = async {
        loop {
            let tx = self.begin_transaction().await?;
            match f(tx).await {
                Ok(result) => return Ok(result),
                Err(error) => match error {
                    // SQLSTATE 40001 = serialization_failure: safe to rerun `f`.
                    Error::Database(error)
                        if error
                            .as_database_error()
                            .and_then(|error| error.code())
                            .as_deref()
                            == Some("40001") =>
                    {
                        // Retry (don't break the loop)
                    }
                    error @ _ => return Err(error),
                },
            }
        }
    };

    #[cfg(test)]
    {
        if let Some(background) = self.background.as_ref() {
            background.simulate_random_delay().await;
        }

        // Tests drive the database future on a dedicated runtime.
        let result = self.runtime.as_ref().unwrap().block_on(body);

        if let Some(background) = self.background.as_ref() {
            background.simulate_random_delay().await;
        }

        result
    }

    #[cfg(not(test))]
    {
        body.await
    }
}
2553}
2554
/// Declares a newtype id wrapping an `i32`.
///
/// The generated type is `#[sqlx(transparent)]` and `#[serde(transparent)]`,
/// so it is stored and serialized as the bare integer. It also gets:
/// - a `MAX` constant;
/// - `from_proto`/`to_proto` conversions to and from `u64`;
/// - a `Display` impl that prints the inner value.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                // NOTE(review): `as i32` silently truncates values above i32::MAX.
                Self(value as i32)
            }

            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                // NOTE(review): a negative inner value sign-extends to a huge u64.
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
2597
id_type!(UserId);
/// A user account, decoded from a database row via `FromRow` and serializable
/// with serde.
// NOTE(review): presumably a row of a `users` table — confirm against the queries.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// GitHub's numeric account id; `None` when the column is NULL.
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    // NOTE(review): presumably this user's own shareable invite code — confirm.
    pub invite_code: Option<String>,
    // NOTE(review): presumably the number of invites this user may still send — confirm.
    pub invite_count: i32,
    // NOTE(review): presumably set once the user has connected at least once — confirm.
    pub connected_once: bool,
}
2610
id_type!(RoomId);
/// A room record, decoded from a database row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Room {
    pub id: RoomId,
    // NOTE(review): presumably the name of the LiveKit room backing this room — confirm.
    pub live_kit_room: String,
}
2617
id_type!(ProjectId);
/// Aggregated state of a project: its collaborators, its worktrees keyed (and
/// therefore ordered) by `WorktreeId`, and its language servers.
// NOTE(review): unlike most types here this is not `FromRow`, so it is presumably
// assembled by hand from several queries — confirm.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<WorktreeId, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
2624
id_type!(ReplicaId);
/// One participant in a project, decoded from a database row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
pub struct ProjectCollaborator {
    pub project_id: ProjectId,
    /// Raw connection id as stored in the database (`i32`, not `ConnectionId`).
    pub connection_id: i32,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    /// Whether this collaborator is the project's host.
    pub is_host: bool,
}
2634
id_type!(WorktreeId);
/// Worktree metadata exactly as decoded from a database row via `FromRow`.
///
/// `scan_id` is signed (`i64`) here, while the public `Worktree` stores it as
/// `u64`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeRow {
    pub id: WorktreeId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub scan_id: i64,
    pub is_complete: bool,
}
2645
/// A worktree together with its file entries and diagnostic summaries, in
/// protocol (`proto`) form.
pub struct Worktree {
    pub id: WorktreeId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    /// Unsigned here; the database row (`WorktreeRow`) carries it as `i64`.
    pub scan_id: u64,
    pub is_complete: bool,
}
2656
/// A single file or directory entry of a worktree, decoded from a database row
/// via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeEntry {
    // NOTE(review): presumably the entry's id within its worktree — confirm.
    id: i64,
    worktree_id: WorktreeId,
    is_dir: bool,
    path: String,
    inode: i64,
    /// Modification time, split into whole seconds plus a nanosecond part.
    mtime_seconds: i64,
    mtime_nanos: i32,
    is_symlink: bool,
    is_ignored: bool,
}
2669
/// Per-path, per-language-server error and warning counts for a worktree,
/// decoded from a database row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeDiagnosticSummary {
    worktree_id: WorktreeId,
    path: String,
    language_server_id: i64,
    error_count: i32,
    warning_count: i32,
}
2678
id_type!(LanguageServerId);
/// A language server's id and name, decoded from a database row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct LanguageServer {
    id: LanguageServerId,
    name: String,
}
2685
/// Information about a project that a connection has left.
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    // NOTE(review): presumably the connections still in the project that need to
    // be notified — confirm against the producer.
    pub connection_ids: Vec<ConnectionId>,
}
2692
/// The outcome of a connection leaving a room.
pub struct LeftRoom {
    /// The room's state, in protocol form.
    pub room: proto::Room,
    /// Projects that were left as part of leaving the room, keyed by project id.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    // NOTE(review): presumably users whose pending incoming calls were canceled
    // because of this departure — confirm.
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
2698
2699#[derive(Clone, Debug, PartialEq, Eq)]
2700pub enum Contact {
2701 Accepted {
2702 user_id: UserId,
2703 should_notify: bool,
2704 busy: bool,
2705 },
2706 Outgoing {
2707 user_id: UserId,
2708 },
2709 Incoming {
2710 user_id: UserId,
2711 should_notify: bool,
2712 },
2713}
2714
2715impl Contact {
2716 pub fn user_id(&self) -> UserId {
2717 match self {
2718 Contact::Accepted { user_id, .. } => *user_id,
2719 Contact::Outgoing { user_id } => *user_id,
2720 Contact::Incoming { user_id, .. } => *user_id,
2721 }
2722 }
2723}
2724
/// A pending contact request received from `requester_id`.
///
/// The derived ordering compares `requester_id` first, then `should_notify`
/// (declaration order).
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
2730
/// A signup submission, deserialized with serde.
// NOTE(review): presumably the body of a waitlist-signup request — confirm.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    /// Platforms the user selected; each is an independent flag.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
2741
/// Aggregate counts over the waitlist, decoded from a database row via `FromRow`.
///
/// Every field is `#[sqlx(default)]`: if the row has no column with that name,
/// the field falls back to `i64::default()` (0) instead of failing to decode.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
2755
/// An invite addressed to an email, along with its confirmation code.
// NOTE(review): the code is presumably produced by `random_email_confirmation_code`
// — confirm.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
2761
/// Parameters supplied when creating a new user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    /// GitHub's numeric account id (required here, unlike `User::github_user_id`).
    pub github_user_id: i32,
    pub invite_count: i32,
}
2768
/// The outcome of creating a new user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    // NOTE(review): presumably the user whose invite led to this signup, if any — confirm.
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
2776
/// Generates a random 16-character code with `nanoid!`.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}

/// Generates a random 64-character code with `nanoid!`. It is longer than the
/// invite code.
fn random_email_confirmation_code() -> String {
    nanoid::nanoid!(64)
}
2784
#[cfg(test)]
pub use test::*;

/// Test-only helpers that create throwaway SQLite or Postgres databases.
///
/// Each database gets its own single-threaded tokio runtime, which `transact`
/// uses to drive sqlx.
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A migrated, in-memory SQLite database.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// A connection detached from the pool and owned by this value.
        // NOTE(review): presumably held so the in-memory database outlives pool
        // churn — confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A freshly created and migrated Postgres database, dropped again via
    /// `teardown` when this value is dropped.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        pub url: String,
    }

    /// Builds the current-thread tokio runtime (IO and timers enabled) that each
    /// test database runs its queries on.
    fn new_test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    impl SqliteTestDb {
        pub fn new(background: Arc<Background>) -> Self {
            // A random suffix keeps databases from concurrently running tests apart.
            let suffix: u128 = StdRng::from_entropy().gen();
            let url = format!("file:zed-test-{}?mode=memory", suffix);
            let runtime = new_test_runtime();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_dir =
                    Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"));
                db.migrate(migrations_dir, false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });
            db.background = Some(background);
            db.runtime = Some(runtime);

            SqliteTestDb {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }
            // Held until `new` returns, so only one test at a time creates and
            // migrates a database.
            let _guard = LOCK.lock();

            let suffix: u128 = StdRng::from_entropy().gen();
            let url = format!("postgres://postgres@localhost/zed-test-{}", suffix);
            let runtime = new_test_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_dir = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
                db.migrate(migrations_dir, false).await.unwrap();
                db
            });
            db.background = Some(background);
            db.runtime = Some(runtime);

            PostgresTestDb {
                db: Some(Arc::new(db)),
                url,
            }
        }

        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        fn drop(&mut self) {
            self.db.take().unwrap().teardown(&self.url);
        }
    }
}