1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::{BTreeMap, HashMap, HashSet};
5use futures::{future::BoxFuture, FutureExt, StreamExt};
6use rpc::{proto, ConnectionId};
7use serde::{Deserialize, Serialize};
8use sqlx::{
9 migrate::{Migrate as _, Migration, MigrationSource},
10 types::Uuid,
11 FromRow,
12};
13use std::{future::Future, path::Path, time::Duration};
14use time::{OffsetDateTime, PrimitiveDateTime};
15
16#[cfg(test)]
17pub type DefaultDb = Db<sqlx::Sqlite>;
18
19#[cfg(not(test))]
20pub type DefaultDb = Db<sqlx::Postgres>;
21
22pub struct Db<D: sqlx::Database> {
23 pool: sqlx::Pool<D>,
24 #[cfg(test)]
25 background: Option<std::sync::Arc<gpui::executor::Background>>,
26 #[cfg(test)]
27 runtime: Option<tokio::runtime::Runtime>,
28}
29
30pub trait BeginTransaction: Send + Sync {
31 type Database: sqlx::Database;
32
33 fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, Self::Database>>>;
34}
35
36// In Postgres, serializable transactions are opt-in
37impl BeginTransaction for Db<sqlx::Postgres> {
38 type Database = sqlx::Postgres;
39
40 fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Postgres>>> {
41 async move {
42 let mut tx = self.pool.begin().await?;
43 sqlx::Executor::execute(&mut tx, "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;")
44 .await?;
45 Ok(tx)
46 }
47 .boxed()
48 }
49}
50
// SQLite transactions are already serializable, so no extra statement is needed.
#[cfg(test)]
impl BeginTransaction for Db<sqlx::Sqlite> {
    type Database = sqlx::Sqlite;

    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Sqlite>>> {
        let pool = &self.pool;
        async move {
            let transaction = pool.begin().await?;
            Ok(transaction)
        }
        .boxed()
    }
}
60
/// Backend-agnostic access to how many rows an executed statement touched.
pub trait RowsAffected {
    /// Number of rows affected by the statement that produced this result.
    fn rows_affected(&self) -> u64;
}
64
/// Forwards to SQLite's inherent `rows_affected`.
#[cfg(test)]
impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    fn rows_affected(&self) -> u64 {
        // Type-relative paths resolve to the inherent method before trait methods,
        // so this does not recurse into the trait impl.
        sqlx::sqlite::SqliteQueryResult::rows_affected(self)
    }
}
71
72impl RowsAffected for sqlx::postgres::PgQueryResult {
73 fn rows_affected(&self) -> u64 {
74 self.rows_affected()
75 }
76}
77
#[cfg(test)]
impl Db<sqlx::Sqlite> {
    /// Connects to the SQLite database at `url`, creating it if it does not exist.
    ///
    /// The pool uses shared-cache mode and keeps between two and
    /// `max_connections` connections. Test-only fields start out as `None`.
    ///
    /// # Errors
    /// Fails if `url` cannot be parsed as SQLite connect options or if the pool
    /// cannot connect.
    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
        use std::str::FromStr as _;
        // Report a malformed URL as an error instead of panicking.
        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
            .create_if_missing(true)
            .shared_cache(true);
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(2)
            .max_connections(max_connections)
            .connect_with(options)
            .await?;
        Ok(Self {
            pool,
            background: None,
            runtime: None,
        })
    }

    /// Fetches every user whose id appears in `ids`.
    ///
    /// The ids are bound as a single JSON array and expanded in SQL with `json_each`.
    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
        self.transact(|tx| async {
            let mut tx = tx;
            let query = "
                SELECT users.*
                FROM users
                WHERE users.id IN (SELECT value from json_each($1))
            ";
            Ok(sqlx::query_as(query)
                .bind(&serde_json::json!(ids))
                .fetch_all(&mut tx)
                .await?)
        })
        .await
    }

    /// Returns the stored `metrics_id` of user `id`; errors if the user does not exist.
    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
        self.transact(|mut tx| async move {
            let query = "
                SELECT metrics_id
                FROM users
                WHERE id = $1
            ";
            Ok(sqlx::query_scalar(query)
                .bind(id)
                .fetch_one(&mut tx)
                .await?)
        })
        .await
    }

    /// Inserts a user with a freshly generated UUID `metrics_id` and returns its ids.
    ///
    /// On a `github_login` conflict the existing row is kept (its login is
    /// rewritten to itself) and that row's id and metrics id are returned.
    pub async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult> {
        self.transact(|mut tx| async {
            let query = "
                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id, metrics_id
            ";

            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
                .bind(email_address)
                .bind(&params.github_login)
                .bind(&params.github_user_id)
                .bind(admin)
                .bind(Uuid::new_v4().to_string())
                .fetch_one(&mut tx)
                .await?;
            tx.commit().await?;
            Ok(NewUserResult {
                user_id,
                metrics_id,
                signup_device_id: None,
                inviting_user_id: None,
            })
        })
        .await
    }

    /// Not supported by the SQLite test backend; panics if called.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics if called.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics if called.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics if called.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics if called.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
}
191
192impl Db<sqlx::Postgres> {
193 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
194 let pool = sqlx::postgres::PgPoolOptions::new()
195 .max_connections(max_connections)
196 .connect(url)
197 .await?;
198 Ok(Self {
199 pool,
200 #[cfg(test)]
201 background: None,
202 #[cfg(test)]
203 runtime: None,
204 })
205 }
206
207 #[cfg(test)]
208 pub fn teardown(&self, url: &str) {
209 self.runtime.as_ref().unwrap().block_on(async {
210 use util::ResultExt;
211 let query = "
212 SELECT pg_terminate_backend(pg_stat_activity.pid)
213 FROM pg_stat_activity
214 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
215 ";
216 sqlx::query(query).execute(&self.pool).await.log_err();
217 self.pool.close().await;
218 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
219 .await
220 .log_err();
221 })
222 }
223
224 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
225 self.transact(|tx| async {
226 let mut tx = tx;
227 let like_string = Self::fuzzy_like_string(name_query);
228 let query = "
229 SELECT users.*
230 FROM users
231 WHERE github_login ILIKE $1
232 ORDER BY github_login <-> $2
233 LIMIT $3
234 ";
235 Ok(sqlx::query_as(query)
236 .bind(like_string)
237 .bind(name_query)
238 .bind(limit as i32)
239 .fetch_all(&mut tx)
240 .await?)
241 })
242 .await
243 }
244
245 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
246 let ids = ids.iter().map(|id| id.0).collect::<Vec<_>>();
247 self.transact(|tx| async {
248 let mut tx = tx;
249 let query = "
250 SELECT users.*
251 FROM users
252 WHERE users.id = ANY ($1)
253 ";
254 Ok(sqlx::query_as(query).bind(&ids).fetch_all(&mut tx).await?)
255 })
256 .await
257 }
258
259 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
260 self.transact(|mut tx| async move {
261 let query = "
262 SELECT metrics_id::text
263 FROM users
264 WHERE id = $1
265 ";
266 Ok(sqlx::query_scalar(query)
267 .bind(id)
268 .fetch_one(&mut tx)
269 .await?)
270 })
271 .await
272 }
273
274 pub async fn create_user(
275 &self,
276 email_address: &str,
277 admin: bool,
278 params: NewUserParams,
279 ) -> Result<NewUserResult> {
280 self.transact(|mut tx| async {
281 let query = "
282 INSERT INTO users (email_address, github_login, github_user_id, admin)
283 VALUES ($1, $2, $3, $4)
284 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
285 RETURNING id, metrics_id::text
286 ";
287
288 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
289 .bind(email_address)
290 .bind(¶ms.github_login)
291 .bind(params.github_user_id)
292 .bind(admin)
293 .fetch_one(&mut tx)
294 .await?;
295 tx.commit().await?;
296
297 Ok(NewUserResult {
298 user_id,
299 metrics_id,
300 signup_device_id: None,
301 inviting_user_id: None,
302 })
303 })
304 .await
305 }
306
307 pub async fn create_user_from_invite(
308 &self,
309 invite: &Invite,
310 user: NewUserParams,
311 ) -> Result<Option<NewUserResult>> {
312 self.transact(|mut tx| async {
313 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
314 i32,
315 Option<UserId>,
316 Option<UserId>,
317 Option<String>,
318 ) = sqlx::query_as(
319 "
320 SELECT id, user_id, inviting_user_id, device_id
321 FROM signups
322 WHERE
323 email_address = $1 AND
324 email_confirmation_code = $2
325 ",
326 )
327 .bind(&invite.email_address)
328 .bind(&invite.email_confirmation_code)
329 .fetch_optional(&mut tx)
330 .await?
331 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
332
333 if existing_user_id.is_some() {
334 return Ok(None);
335 }
336
337 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
338 "
339 INSERT INTO users
340 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
341 VALUES
342 ($1, $2, $3, FALSE, $4, $5)
343 ON CONFLICT (github_login) DO UPDATE SET
344 email_address = excluded.email_address,
345 github_user_id = excluded.github_user_id,
346 admin = excluded.admin
347 RETURNING id, metrics_id::text
348 ",
349 )
350 .bind(&invite.email_address)
351 .bind(&user.github_login)
352 .bind(&user.github_user_id)
353 .bind(&user.invite_count)
354 .bind(random_invite_code())
355 .fetch_one(&mut tx)
356 .await?;
357
358 sqlx::query(
359 "
360 UPDATE signups
361 SET user_id = $1
362 WHERE id = $2
363 ",
364 )
365 .bind(&user_id)
366 .bind(&signup_id)
367 .execute(&mut tx)
368 .await?;
369
370 if let Some(inviting_user_id) = inviting_user_id {
371 let id: Option<UserId> = sqlx::query_scalar(
372 "
373 UPDATE users
374 SET invite_count = invite_count - 1
375 WHERE id = $1 AND invite_count > 0
376 RETURNING id
377 ",
378 )
379 .bind(&inviting_user_id)
380 .fetch_optional(&mut tx)
381 .await?;
382
383 if id.is_none() {
384 Err(Error::Http(
385 StatusCode::UNAUTHORIZED,
386 "no invites remaining".to_string(),
387 ))?;
388 }
389
390 sqlx::query(
391 "
392 INSERT INTO contacts
393 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
394 VALUES
395 ($1, $2, TRUE, TRUE, TRUE)
396 ON CONFLICT DO NOTHING
397 ",
398 )
399 .bind(inviting_user_id)
400 .bind(user_id)
401 .execute(&mut tx)
402 .await?;
403 }
404
405 tx.commit().await?;
406 Ok(Some(NewUserResult {
407 user_id,
408 metrics_id,
409 inviting_user_id,
410 signup_device_id,
411 }))
412 })
413 .await
414 }
415
416 pub async fn create_signup(&self, signup: Signup) -> Result<()> {
417 self.transact(|mut tx| async {
418 sqlx::query(
419 "
420 INSERT INTO signups
421 (
422 email_address,
423 email_confirmation_code,
424 email_confirmation_sent,
425 platform_linux,
426 platform_mac,
427 platform_windows,
428 platform_unknown,
429 editor_features,
430 programming_languages,
431 device_id
432 )
433 VALUES
434 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
435 RETURNING id
436 ",
437 )
438 .bind(&signup.email_address)
439 .bind(&random_email_confirmation_code())
440 .bind(&signup.platform_linux)
441 .bind(&signup.platform_mac)
442 .bind(&signup.platform_windows)
443 .bind(&signup.editor_features)
444 .bind(&signup.programming_languages)
445 .bind(&signup.device_id)
446 .execute(&mut tx)
447 .await?;
448 tx.commit().await?;
449 Ok(())
450 })
451 .await
452 }
453
454 pub async fn create_invite_from_code(
455 &self,
456 code: &str,
457 email_address: &str,
458 device_id: Option<&str>,
459 ) -> Result<Invite> {
460 self.transact(|mut tx| async {
461 let existing_user: Option<UserId> = sqlx::query_scalar(
462 "
463 SELECT id
464 FROM users
465 WHERE email_address = $1
466 ",
467 )
468 .bind(email_address)
469 .fetch_optional(&mut tx)
470 .await?;
471 if existing_user.is_some() {
472 Err(anyhow!("email address is already in use"))?;
473 }
474
475 let row: Option<(UserId, i32)> = sqlx::query_as(
476 "
477 SELECT id, invite_count
478 FROM users
479 WHERE invite_code = $1
480 ",
481 )
482 .bind(code)
483 .fetch_optional(&mut tx)
484 .await?;
485
486 let (inviter_id, invite_count) = match row {
487 Some(row) => row,
488 None => Err(Error::Http(
489 StatusCode::NOT_FOUND,
490 "invite code not found".to_string(),
491 ))?,
492 };
493
494 if invite_count == 0 {
495 Err(Error::Http(
496 StatusCode::UNAUTHORIZED,
497 "no invites remaining".to_string(),
498 ))?;
499 }
500
501 let email_confirmation_code: String = sqlx::query_scalar(
502 "
503 INSERT INTO signups
504 (
505 email_address,
506 email_confirmation_code,
507 email_confirmation_sent,
508 inviting_user_id,
509 platform_linux,
510 platform_mac,
511 platform_windows,
512 platform_unknown,
513 device_id
514 )
515 VALUES
516 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
517 ON CONFLICT (email_address)
518 DO UPDATE SET
519 inviting_user_id = excluded.inviting_user_id
520 RETURNING email_confirmation_code
521 ",
522 )
523 .bind(&email_address)
524 .bind(&random_email_confirmation_code())
525 .bind(&inviter_id)
526 .bind(&device_id)
527 .fetch_one(&mut tx)
528 .await?;
529
530 tx.commit().await?;
531
532 Ok(Invite {
533 email_address: email_address.into(),
534 email_confirmation_code,
535 })
536 })
537 .await
538 }
539
540 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
541 self.transact(|mut tx| async {
542 let emails = invites
543 .iter()
544 .map(|s| s.email_address.as_str())
545 .collect::<Vec<_>>();
546 sqlx::query(
547 "
548 UPDATE signups
549 SET email_confirmation_sent = TRUE
550 WHERE email_address = ANY ($1)
551 ",
552 )
553 .bind(&emails)
554 .execute(&mut tx)
555 .await?;
556 tx.commit().await?;
557 Ok(())
558 })
559 .await
560 }
561}
562
563impl<D> Db<D>
564where
565 Self: BeginTransaction<Database = D>,
566 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
567 D::Connection: sqlx::migrate::Migrate,
568 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
569 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
570 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
571 D::QueryResult: RowsAffected,
572 String: sqlx::Type<D>,
573 i32: sqlx::Type<D>,
574 i64: sqlx::Type<D>,
575 bool: sqlx::Type<D>,
576 str: sqlx::Type<D>,
577 Uuid: sqlx::Type<D>,
578 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
579 OffsetDateTime: sqlx::Type<D>,
580 PrimitiveDateTime: sqlx::Type<D>,
581 usize: sqlx::ColumnIndex<D::Row>,
582 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
583 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
584 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
585 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
586 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
587 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
588 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
589 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
590 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
591 for<'a> Option<ProjectId>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
592 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
593 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
594 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
595{
596 pub async fn migrate(
597 &self,
598 migrations_path: &Path,
599 ignore_checksum_mismatch: bool,
600 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
601 let migrations = MigrationSource::resolve(migrations_path)
602 .await
603 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
604
605 let mut conn = self.pool.acquire().await?;
606
607 conn.ensure_migrations_table().await?;
608 let applied_migrations: HashMap<_, _> = conn
609 .list_applied_migrations()
610 .await?
611 .into_iter()
612 .map(|m| (m.version, m))
613 .collect();
614
615 let mut new_migrations = Vec::new();
616 for migration in migrations {
617 match applied_migrations.get(&migration.version) {
618 Some(applied_migration) => {
619 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
620 {
621 Err(anyhow!(
622 "checksum mismatch for applied migration {}",
623 migration.description
624 ))?;
625 }
626 }
627 None => {
628 let elapsed = conn.apply(&migration).await?;
629 new_migrations.push((migration, elapsed));
630 }
631 }
632 }
633
634 Ok(new_migrations)
635 }
636
637 pub fn fuzzy_like_string(string: &str) -> String {
638 let mut result = String::with_capacity(string.len() * 2 + 1);
639 for c in string.chars() {
640 if c.is_alphanumeric() {
641 result.push('%');
642 result.push(c);
643 }
644 }
645 result.push('%');
646 result
647 }
648
649 // users
650
651 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
652 self.transact(|tx| async {
653 let mut tx = tx;
654 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
655 Ok(sqlx::query_as(query)
656 .bind(limit as i32)
657 .bind((page * limit) as i32)
658 .fetch_all(&mut tx)
659 .await?)
660 })
661 .await
662 }
663
664 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
665 self.transact(|tx| async {
666 let mut tx = tx;
667 let query = "
668 SELECT users.*
669 FROM users
670 WHERE id = $1
671 LIMIT 1
672 ";
673 Ok(sqlx::query_as(query)
674 .bind(&id)
675 .fetch_optional(&mut tx)
676 .await?)
677 })
678 .await
679 }
680
681 pub async fn get_users_with_no_invites(
682 &self,
683 invited_by_another_user: bool,
684 ) -> Result<Vec<User>> {
685 self.transact(|tx| async {
686 let mut tx = tx;
687 let query = format!(
688 "
689 SELECT users.*
690 FROM users
691 WHERE invite_count = 0
692 AND inviter_id IS{} NULL
693 ",
694 if invited_by_another_user { " NOT" } else { "" }
695 );
696
697 Ok(sqlx::query_as(&query).fetch_all(&mut tx).await?)
698 })
699 .await
700 }
701
702 pub async fn get_user_by_github_account(
703 &self,
704 github_login: &str,
705 github_user_id: Option<i32>,
706 ) -> Result<Option<User>> {
707 self.transact(|tx| async {
708 let mut tx = tx;
709 if let Some(github_user_id) = github_user_id {
710 let mut user = sqlx::query_as::<_, User>(
711 "
712 UPDATE users
713 SET github_login = $1
714 WHERE github_user_id = $2
715 RETURNING *
716 ",
717 )
718 .bind(github_login)
719 .bind(github_user_id)
720 .fetch_optional(&mut tx)
721 .await?;
722
723 if user.is_none() {
724 user = sqlx::query_as::<_, User>(
725 "
726 UPDATE users
727 SET github_user_id = $1
728 WHERE github_login = $2
729 RETURNING *
730 ",
731 )
732 .bind(github_user_id)
733 .bind(github_login)
734 .fetch_optional(&mut tx)
735 .await?;
736 }
737
738 Ok(user)
739 } else {
740 let user = sqlx::query_as(
741 "
742 SELECT * FROM users
743 WHERE github_login = $1
744 LIMIT 1
745 ",
746 )
747 .bind(github_login)
748 .fetch_optional(&mut tx)
749 .await?;
750 Ok(user)
751 }
752 })
753 .await
754 }
755
756 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
757 self.transact(|mut tx| async {
758 let query = "UPDATE users SET admin = $1 WHERE id = $2";
759 sqlx::query(query)
760 .bind(is_admin)
761 .bind(id.0)
762 .execute(&mut tx)
763 .await?;
764 tx.commit().await?;
765 Ok(())
766 })
767 .await
768 }
769
770 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
771 self.transact(|mut tx| async move {
772 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
773 sqlx::query(query)
774 .bind(connected_once)
775 .bind(id.0)
776 .execute(&mut tx)
777 .await?;
778 tx.commit().await?;
779 Ok(())
780 })
781 .await
782 }
783
784 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
785 self.transact(|mut tx| async move {
786 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
787 sqlx::query(query)
788 .bind(id.0)
789 .execute(&mut tx)
790 .await
791 .map(drop)?;
792 let query = "DELETE FROM users WHERE id = $1;";
793 sqlx::query(query).bind(id.0).execute(&mut tx).await?;
794 tx.commit().await?;
795 Ok(())
796 })
797 .await
798 }
799
800 // signups
801
802 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
803 self.transact(|mut tx| async move {
804 Ok(sqlx::query_as(
805 "
806 SELECT
807 COUNT(*) as count,
808 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
809 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
810 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
811 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
812 FROM (
813 SELECT *
814 FROM signups
815 WHERE
816 NOT email_confirmation_sent
817 ) AS unsent
818 ",
819 )
820 .fetch_one(&mut tx)
821 .await?)
822 })
823 .await
824 }
825
826 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
827 self.transact(|mut tx| async move {
828 Ok(sqlx::query_as(
829 "
830 SELECT
831 email_address, email_confirmation_code
832 FROM signups
833 WHERE
834 NOT email_confirmation_sent AND
835 (platform_mac OR platform_unknown)
836 LIMIT $1
837 ",
838 )
839 .bind(count as i32)
840 .fetch_all(&mut tx)
841 .await?)
842 })
843 .await
844 }
845
846 // invite codes
847
848 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
849 self.transact(|mut tx| async move {
850 if count > 0 {
851 sqlx::query(
852 "
853 UPDATE users
854 SET invite_code = $1
855 WHERE id = $2 AND invite_code IS NULL
856 ",
857 )
858 .bind(random_invite_code())
859 .bind(id)
860 .execute(&mut tx)
861 .await?;
862 }
863
864 sqlx::query(
865 "
866 UPDATE users
867 SET invite_count = $1
868 WHERE id = $2
869 ",
870 )
871 .bind(count as i32)
872 .bind(id)
873 .execute(&mut tx)
874 .await?;
875 tx.commit().await?;
876 Ok(())
877 })
878 .await
879 }
880
881 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
882 self.transact(|mut tx| async move {
883 let result: Option<(String, i32)> = sqlx::query_as(
884 "
885 SELECT invite_code, invite_count
886 FROM users
887 WHERE id = $1 AND invite_code IS NOT NULL
888 ",
889 )
890 .bind(id)
891 .fetch_optional(&mut tx)
892 .await?;
893 if let Some((code, count)) = result {
894 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
895 } else {
896 Ok(None)
897 }
898 })
899 .await
900 }
901
902 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
903 self.transact(|tx| async {
904 let mut tx = tx;
905 sqlx::query_as(
906 "
907 SELECT *
908 FROM users
909 WHERE invite_code = $1
910 ",
911 )
912 .bind(code)
913 .fetch_optional(&mut tx)
914 .await?
915 .ok_or_else(|| {
916 Error::Http(
917 StatusCode::NOT_FOUND,
918 "that invite code does not exist".to_string(),
919 )
920 })
921 })
922 .await
923 }
924
925 pub async fn create_room(
926 &self,
927 user_id: UserId,
928 connection_id: ConnectionId,
929 ) -> Result<proto::Room> {
930 self.transact(|mut tx| async move {
931 let live_kit_room = nanoid::nanoid!(30);
932 let room_id = sqlx::query_scalar(
933 "
934 INSERT INTO rooms (live_kit_room, version)
935 VALUES ($1, $2)
936 RETURNING id
937 ",
938 )
939 .bind(&live_kit_room)
940 .bind(0)
941 .fetch_one(&mut tx)
942 .await
943 .map(RoomId)?;
944
945 sqlx::query(
946 "
947 INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id)
948 VALUES ($1, $2, $3, $4, $5)
949 ",
950 )
951 .bind(room_id)
952 .bind(user_id)
953 .bind(connection_id.0 as i32)
954 .bind(user_id)
955 .bind(connection_id.0 as i32)
956 .execute(&mut tx)
957 .await?;
958
959 self.commit_room_transaction(room_id, tx).await
960 }).await
961 }
962
963 pub async fn call(
964 &self,
965 room_id: RoomId,
966 calling_user_id: UserId,
967 calling_connection_id: ConnectionId,
968 called_user_id: UserId,
969 initial_project_id: Option<ProjectId>,
970 ) -> Result<(proto::Room, proto::IncomingCall)> {
971 self.transact(|mut tx| async move {
972 sqlx::query(
973 "
974 INSERT INTO room_participants (room_id, user_id, calling_user_id, calling_connection_id, initial_project_id)
975 VALUES ($1, $2, $3, $4, $5)
976 ",
977 )
978 .bind(room_id)
979 .bind(called_user_id)
980 .bind(calling_user_id)
981 .bind(calling_connection_id.0 as i32)
982 .bind(initial_project_id)
983 .execute(&mut tx)
984 .await?;
985
986 let room = self.commit_room_transaction(room_id, tx).await?;
987 let incoming_call = Self::build_incoming_call(&room, called_user_id)
988 .ok_or_else(|| anyhow!("failed to build incoming call"))?;
989 Ok((room, incoming_call))
990 }).await
991 }
992
993 pub async fn incoming_call_for_user(
994 &self,
995 user_id: UserId,
996 ) -> Result<Option<proto::IncomingCall>> {
997 self.transact(|mut tx| async move {
998 let room_id = sqlx::query_scalar::<_, RoomId>(
999 "
1000 SELECT room_id
1001 FROM room_participants
1002 WHERE user_id = $1 AND answering_connection_id IS NULL
1003 ",
1004 )
1005 .bind(user_id)
1006 .fetch_optional(&mut tx)
1007 .await?;
1008
1009 if let Some(room_id) = room_id {
1010 let room = self.get_room(room_id, &mut tx).await?;
1011 Ok(Self::build_incoming_call(&room, user_id))
1012 } else {
1013 Ok(None)
1014 }
1015 })
1016 .await
1017 }
1018
1019 fn build_incoming_call(
1020 room: &proto::Room,
1021 called_user_id: UserId,
1022 ) -> Option<proto::IncomingCall> {
1023 let pending_participant = room
1024 .pending_participants
1025 .iter()
1026 .find(|participant| participant.user_id == called_user_id.to_proto())?;
1027
1028 Some(proto::IncomingCall {
1029 room_id: room.id,
1030 calling_user_id: pending_participant.calling_user_id,
1031 participant_user_ids: room
1032 .participants
1033 .iter()
1034 .map(|participant| participant.user_id)
1035 .collect(),
1036 initial_project: room.participants.iter().find_map(|participant| {
1037 let initial_project_id = pending_participant.initial_project_id?;
1038 participant
1039 .projects
1040 .iter()
1041 .find(|project| project.id == initial_project_id)
1042 .cloned()
1043 }),
1044 })
1045 }
1046
1047 pub async fn call_failed(
1048 &self,
1049 room_id: RoomId,
1050 called_user_id: UserId,
1051 ) -> Result<proto::Room> {
1052 self.transact(|mut tx| async move {
1053 sqlx::query(
1054 "
1055 DELETE FROM room_participants
1056 WHERE room_id = $1 AND user_id = $2
1057 ",
1058 )
1059 .bind(room_id)
1060 .bind(called_user_id)
1061 .execute(&mut tx)
1062 .await?;
1063
1064 self.commit_room_transaction(room_id, tx).await
1065 })
1066 .await
1067 }
1068
1069 pub async fn decline_call(
1070 &self,
1071 expected_room_id: Option<RoomId>,
1072 user_id: UserId,
1073 ) -> Result<proto::Room> {
1074 self.transact(|mut tx| async move {
1075 let room_id = sqlx::query_scalar(
1076 "
1077 DELETE FROM room_participants
1078 WHERE user_id = $1 AND answering_connection_id IS NULL
1079 RETURNING room_id
1080 ",
1081 )
1082 .bind(user_id)
1083 .fetch_one(&mut tx)
1084 .await?;
1085 if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1086 return Err(anyhow!("declining call on unexpected room"))?;
1087 }
1088
1089 self.commit_room_transaction(room_id, tx).await
1090 })
1091 .await
1092 }
1093
1094 pub async fn cancel_call(
1095 &self,
1096 expected_room_id: Option<RoomId>,
1097 calling_connection_id: ConnectionId,
1098 called_user_id: UserId,
1099 ) -> Result<proto::Room> {
1100 self.transact(|mut tx| async move {
1101 let room_id = sqlx::query_scalar(
1102 "
1103 DELETE FROM room_participants
1104 WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL
1105 RETURNING room_id
1106 ",
1107 )
1108 .bind(called_user_id)
1109 .bind(calling_connection_id.0 as i32)
1110 .fetch_one(&mut tx)
1111 .await?;
1112 if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1113 return Err(anyhow!("canceling call on unexpected room"))?;
1114 }
1115
1116 self.commit_room_transaction(room_id, tx).await
1117 }).await
1118 }
1119
1120 pub async fn join_room(
1121 &self,
1122 room_id: RoomId,
1123 user_id: UserId,
1124 connection_id: ConnectionId,
1125 ) -> Result<proto::Room> {
1126 self.transact(|mut tx| async move {
1127 sqlx::query(
1128 "
1129 UPDATE room_participants
1130 SET answering_connection_id = $1
1131 WHERE room_id = $2 AND user_id = $3
1132 RETURNING 1
1133 ",
1134 )
1135 .bind(connection_id.0 as i32)
1136 .bind(room_id)
1137 .bind(user_id)
1138 .fetch_one(&mut tx)
1139 .await?;
1140 self.commit_room_transaction(room_id, tx).await
1141 })
1142 .await
1143 }
1144
1145 pub async fn leave_room(&self, connection_id: ConnectionId) -> Result<Option<LeftRoom>> {
1146 self.transact(|mut tx| async move {
1147 // Leave room.
1148 let room_id = sqlx::query_scalar::<_, RoomId>(
1149 "
1150 DELETE FROM room_participants
1151 WHERE answering_connection_id = $1
1152 RETURNING room_id
1153 ",
1154 )
1155 .bind(connection_id.0 as i32)
1156 .fetch_optional(&mut tx)
1157 .await?;
1158
1159 if let Some(room_id) = room_id {
1160 // Cancel pending calls initiated by the leaving user.
1161 let canceled_calls_to_user_ids: Vec<UserId> = sqlx::query_scalar(
1162 "
1163 DELETE FROM room_participants
1164 WHERE calling_connection_id = $1 AND answering_connection_id IS NULL
1165 RETURNING user_id
1166 ",
1167 )
1168 .bind(connection_id.0 as i32)
1169 .fetch_all(&mut tx)
1170 .await?;
1171
1172 let project_ids = sqlx::query_scalar::<_, ProjectId>(
1173 "
1174 SELECT project_id
1175 FROM project_collaborators
1176 WHERE connection_id = $1
1177 ",
1178 )
1179 .bind(connection_id.0 as i32)
1180 .fetch_all(&mut tx)
1181 .await?;
1182
1183 // Leave projects.
1184 let mut left_projects = HashMap::default();
1185 if !project_ids.is_empty() {
1186 let mut params = "?,".repeat(project_ids.len());
1187 params.pop();
1188 let query = format!(
1189 "
1190 SELECT *
1191 FROM project_collaborators
1192 WHERE project_id IN ({params})
1193 "
1194 );
1195 let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query);
1196 for project_id in project_ids {
1197 query = query.bind(project_id);
1198 }
1199
1200 let mut project_collaborators = query.fetch(&mut tx);
1201 while let Some(collaborator) = project_collaborators.next().await {
1202 let collaborator = collaborator?;
1203 let left_project =
1204 left_projects
1205 .entry(collaborator.project_id)
1206 .or_insert(LeftProject {
1207 id: collaborator.project_id,
1208 host_user_id: Default::default(),
1209 connection_ids: Default::default(),
1210 host_connection_id: Default::default(),
1211 });
1212
1213 let collaborator_connection_id =
1214 ConnectionId(collaborator.connection_id as u32);
1215 if collaborator_connection_id != connection_id {
1216 left_project.connection_ids.push(collaborator_connection_id);
1217 }
1218
1219 if collaborator.is_host {
1220 left_project.host_user_id = collaborator.user_id;
1221 left_project.host_connection_id =
1222 ConnectionId(collaborator.connection_id as u32);
1223 }
1224 }
1225 }
1226 sqlx::query(
1227 "
1228 DELETE FROM project_collaborators
1229 WHERE connection_id = $1
1230 ",
1231 )
1232 .bind(connection_id.0 as i32)
1233 .execute(&mut tx)
1234 .await?;
1235
1236 // Unshare projects.
1237 sqlx::query(
1238 "
1239 DELETE FROM projects
1240 WHERE room_id = $1 AND host_connection_id = $2
1241 ",
1242 )
1243 .bind(room_id)
1244 .bind(connection_id.0 as i32)
1245 .execute(&mut tx)
1246 .await?;
1247
1248 let room = self.commit_room_transaction(room_id, tx).await?;
1249 Ok(Some(LeftRoom {
1250 room,
1251 left_projects,
1252 canceled_calls_to_user_ids,
1253 }))
1254 } else {
1255 Ok(None)
1256 }
1257 })
1258 .await
1259 }
1260
1261 pub async fn update_room_participant_location(
1262 &self,
1263 room_id: RoomId,
1264 connection_id: ConnectionId,
1265 location: proto::ParticipantLocation,
1266 ) -> Result<proto::Room> {
1267 self.transact(|tx| async {
1268 let mut tx = tx;
1269 let location_kind;
1270 let location_project_id;
1271 match location
1272 .variant
1273 .as_ref()
1274 .ok_or_else(|| anyhow!("invalid location"))?
1275 {
1276 proto::participant_location::Variant::SharedProject(project) => {
1277 location_kind = 0;
1278 location_project_id = Some(ProjectId::from_proto(project.id));
1279 }
1280 proto::participant_location::Variant::UnsharedProject(_) => {
1281 location_kind = 1;
1282 location_project_id = None;
1283 }
1284 proto::participant_location::Variant::External(_) => {
1285 location_kind = 2;
1286 location_project_id = None;
1287 }
1288 }
1289
1290 sqlx::query(
1291 "
1292 UPDATE room_participants
1293 SET location_kind = $1, location_project_id = $2
1294 WHERE room_id = $3 AND answering_connection_id = $4
1295 RETURNING 1
1296 ",
1297 )
1298 .bind(location_kind)
1299 .bind(location_project_id)
1300 .bind(room_id)
1301 .bind(connection_id.0 as i32)
1302 .fetch_one(&mut tx)
1303 .await?;
1304
1305 self.commit_room_transaction(room_id, tx).await
1306 })
1307 .await
1308 }
1309
1310 async fn commit_room_transaction(
1311 &self,
1312 room_id: RoomId,
1313 mut tx: sqlx::Transaction<'_, D>,
1314 ) -> Result<proto::Room> {
1315 sqlx::query(
1316 "
1317 UPDATE rooms
1318 SET version = version + 1
1319 WHERE id = $1
1320 ",
1321 )
1322 .bind(room_id)
1323 .execute(&mut tx)
1324 .await?;
1325 let room = self.get_room(room_id, &mut tx).await?;
1326 tx.commit().await?;
1327
1328 Ok(room)
1329 }
1330
1331 async fn get_guest_connection_ids(
1332 &self,
1333 project_id: ProjectId,
1334 tx: &mut sqlx::Transaction<'_, D>,
1335 ) -> Result<Vec<ConnectionId>> {
1336 let mut guest_connection_ids = Vec::new();
1337 let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>(
1338 "
1339 SELECT connection_id
1340 FROM project_collaborators
1341 WHERE project_id = $1 AND is_host = FALSE
1342 ",
1343 )
1344 .bind(project_id)
1345 .fetch(tx);
1346 while let Some(connection_id) = db_guest_connection_ids.next().await {
1347 guest_connection_ids.push(ConnectionId(connection_id? as u32));
1348 }
1349 Ok(guest_connection_ids)
1350 }
1351
    /// Loads `room_id` through `tx` and converts it into its protobuf form.
    ///
    /// Participant rows with an `answering_connection_id` become
    /// `proto::Participant`s whose `peer_id` is that connection id. Rows
    /// without one are reported as `pending_participants`. Each answered
    /// participant is then given the projects whose `host_connection_id` is
    /// its connection, together with the root names of those projects'
    /// worktrees.
    async fn get_room(
        &self,
        room_id: RoomId,
        tx: &mut sqlx::Transaction<'_, D>,
    ) -> Result<proto::Room> {
        let room: Room = sqlx::query_as(
            "
            SELECT *
            FROM rooms
            WHERE id = $1
            ",
        )
        .bind(room_id)
        .fetch_one(&mut *tx)
        .await?;

        let mut db_participants =
            sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<ProjectId>, UserId, Option<ProjectId>)>(
                "
                SELECT user_id, answering_connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
                FROM room_participants
                WHERE room_id = $1
                ",
            )
            .bind(room_id)
            .fetch(&mut *tx);

        // Answered participants, keyed by their answering connection id so
        // that hosted projects can be attached to them below.
        let mut participants = HashMap::default();
        let mut pending_participants = Vec::new();
        while let Some(participant) = db_participants.next().await {
            let (
                user_id,
                answering_connection_id,
                location_kind,
                location_project_id,
                calling_user_id,
                initial_project_id,
            ) = participant?;
            if let Some(answering_connection_id) = answering_connection_id {
                // Decode the stored location: kind 0 with a project id is a
                // shared project, kind 1 is an unshared project, and anything
                // else (including a NULL kind, or kind 0 without a project id)
                // falls back to `External`.
                let location = match (location_kind, location_project_id) {
                    (Some(0), Some(project_id)) => {
                        Some(proto::participant_location::Variant::SharedProject(
                            proto::participant_location::SharedProject {
                                id: project_id.to_proto(),
                            },
                        ))
                    }
                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
                        Default::default(),
                    )),
                    _ => Some(proto::participant_location::Variant::External(
                        Default::default(),
                    )),
                };
                participants.insert(
                    answering_connection_id,
                    proto::Participant {
                        user_id: user_id.to_proto(),
                        peer_id: answering_connection_id as u32,
                        projects: Default::default(),
                        location: Some(proto::ParticipantLocation { variant: location }),
                    },
                );
            } else {
                // No answering connection: presumably the user has been called
                // but has not answered yet.
                pending_participants.push(proto::PendingParticipant {
                    user_id: user_id.to_proto(),
                    calling_user_id: calling_user_id.to_proto(),
                    initial_project_id: initial_project_id.map(|id| id.to_proto()),
                });
            }
        }
        // Release the stream's mutable borrow of `tx` before the next query.
        drop(db_participants);

        // One row per (project, worktree). The LEFT JOIN yields a NULL root
        // name for a project that has no worktrees.
        let mut rows = sqlx::query_as::<_, (i32, ProjectId, Option<String>)>(
            "
            SELECT host_connection_id, projects.id, worktrees.root_name
            FROM projects
            LEFT JOIN worktrees ON projects.id = worktrees.project_id
            WHERE room_id = $1
            ",
        )
        .bind(room_id)
        .fetch(&mut *tx);

        while let Some(row) = rows.next().await {
            let (connection_id, project_id, worktree_root_name) = row?;
            // Projects whose host connection is not an answered participant
            // of this room are skipped.
            if let Some(participant) = participants.get_mut(&connection_id) {
                let project = if let Some(project) = participant
                    .projects
                    .iter_mut()
                    .find(|project| project.id == project_id.to_proto())
                {
                    project
                } else {
                    participant.projects.push(proto::ParticipantProject {
                        id: project_id.to_proto(),
                        worktree_root_names: Default::default(),
                    });
                    participant.projects.last_mut().unwrap()
                };
                // Extending with an `Option` adds the root name only when present.
                project.worktree_root_names.extend(worktree_root_name);
            }
        }

        Ok(proto::Room {
            id: room.id.to_proto(),
            version: room.version as u64,
            live_kit_room: room.live_kit_room,
            // Built from a hash map, so participant order is presumably
            // arbitrary; callers should not rely on it.
            participants: participants.into_values().collect(),
            pending_participants,
        })
    }
1464
1465 // projects
1466
1467 pub async fn share_project(
1468 &self,
1469 expected_room_id: RoomId,
1470 connection_id: ConnectionId,
1471 worktrees: &[proto::WorktreeMetadata],
1472 ) -> Result<(ProjectId, proto::Room)> {
1473 self.transact(|mut tx| async move {
1474 let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
1475 "
1476 SELECT room_id, user_id
1477 FROM room_participants
1478 WHERE answering_connection_id = $1
1479 ",
1480 )
1481 .bind(connection_id.0 as i32)
1482 .fetch_one(&mut tx)
1483 .await?;
1484 if room_id != expected_room_id {
1485 return Err(anyhow!("shared project on unexpected room"))?;
1486 }
1487
1488 let project_id: ProjectId = sqlx::query_scalar(
1489 "
1490 INSERT INTO projects (room_id, host_user_id, host_connection_id)
1491 VALUES ($1, $2, $3)
1492 RETURNING id
1493 ",
1494 )
1495 .bind(room_id)
1496 .bind(user_id)
1497 .bind(connection_id.0 as i32)
1498 .fetch_one(&mut tx)
1499 .await?;
1500
1501 if !worktrees.is_empty() {
1502 let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
1503 params.pop();
1504 let query = format!(
1505 "
1506 INSERT INTO worktrees (
1507 project_id,
1508 id,
1509 root_name,
1510 abs_path,
1511 visible,
1512 scan_id,
1513 is_complete
1514 )
1515 VALUES {params}
1516 "
1517 );
1518
1519 let mut query = sqlx::query(&query);
1520 for worktree in worktrees {
1521 query = query
1522 .bind(project_id)
1523 .bind(worktree.id as i32)
1524 .bind(&worktree.root_name)
1525 .bind(&worktree.abs_path)
1526 .bind(worktree.visible)
1527 .bind(0)
1528 .bind(false);
1529 }
1530 query.execute(&mut tx).await?;
1531 }
1532
1533 sqlx::query(
1534 "
1535 INSERT INTO project_collaborators (
1536 project_id,
1537 connection_id,
1538 user_id,
1539 replica_id,
1540 is_host
1541 )
1542 VALUES ($1, $2, $3, $4, $5)
1543 ",
1544 )
1545 .bind(project_id)
1546 .bind(connection_id.0 as i32)
1547 .bind(user_id)
1548 .bind(0)
1549 .bind(true)
1550 .execute(&mut tx)
1551 .await?;
1552
1553 let room = self.commit_room_transaction(room_id, tx).await?;
1554 Ok((project_id, room))
1555 })
1556 .await
1557 }
1558
1559 pub async fn unshare_project(
1560 &self,
1561 project_id: ProjectId,
1562 connection_id: ConnectionId,
1563 ) -> Result<(proto::Room, Vec<ConnectionId>)> {
1564 self.transact(|mut tx| async move {
1565 let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1566 let room_id: RoomId = sqlx::query_scalar(
1567 "
1568 DELETE FROM projects
1569 WHERE id = $1 AND host_connection_id = $2
1570 RETURNING room_id
1571 ",
1572 )
1573 .bind(project_id)
1574 .bind(connection_id.0 as i32)
1575 .fetch_one(&mut tx)
1576 .await?;
1577 let room = self.commit_room_transaction(room_id, tx).await?;
1578
1579 Ok((room, guest_connection_ids))
1580 })
1581 .await
1582 }
1583
1584 pub async fn update_project(
1585 &self,
1586 project_id: ProjectId,
1587 connection_id: ConnectionId,
1588 worktrees: &[proto::WorktreeMetadata],
1589 ) -> Result<(proto::Room, Vec<ConnectionId>)> {
1590 self.transact(|mut tx| async move {
1591 let room_id: RoomId = sqlx::query_scalar(
1592 "
1593 SELECT room_id
1594 FROM projects
1595 WHERE id = $1 AND host_connection_id = $2
1596 ",
1597 )
1598 .bind(project_id)
1599 .bind(connection_id.0 as i32)
1600 .fetch_one(&mut tx)
1601 .await?;
1602
1603 if !worktrees.is_empty() {
1604 let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
1605 params.pop();
1606 let query = format!(
1607 "
1608 INSERT INTO worktrees (
1609 project_id,
1610 id,
1611 root_name,
1612 abs_path,
1613 visible,
1614 scan_id,
1615 is_complete
1616 )
1617 VALUES {params}
1618 ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name
1619 "
1620 );
1621
1622 let mut query = sqlx::query(&query);
1623 for worktree in worktrees {
1624 query = query
1625 .bind(project_id)
1626 .bind(worktree.id as i32)
1627 .bind(&worktree.root_name)
1628 .bind(&worktree.abs_path)
1629 .bind(worktree.visible)
1630 .bind(0)
1631 .bind(false)
1632 }
1633 query.execute(&mut tx).await?;
1634 }
1635
1636 let mut params = "?,".repeat(worktrees.len());
1637 if !worktrees.is_empty() {
1638 params.pop();
1639 }
1640 let query = format!(
1641 "
1642 DELETE FROM worktrees
1643 WHERE project_id = ? AND id NOT IN ({params})
1644 ",
1645 );
1646
1647 let mut query = sqlx::query(&query).bind(project_id);
1648 for worktree in worktrees {
1649 query = query.bind(WorktreeId(worktree.id as i32));
1650 }
1651 query.execute(&mut tx).await?;
1652
1653 let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1654 let room = self.commit_room_transaction(room_id, tx).await?;
1655
1656 Ok((room, guest_connection_ids))
1657 })
1658 .await
1659 }
1660
1661 pub async fn update_worktree(
1662 &self,
1663 update: &proto::UpdateWorktree,
1664 connection_id: ConnectionId,
1665 ) -> Result<Vec<ConnectionId>> {
1666 self.transact(|mut tx| async move {
1667 let project_id = ProjectId::from_proto(update.project_id);
1668 let worktree_id = WorktreeId::from_proto(update.worktree_id);
1669
1670 // Ensure the update comes from the host.
1671 sqlx::query(
1672 "
1673 SELECT 1
1674 FROM projects
1675 WHERE id = $1 AND host_connection_id = $2
1676 ",
1677 )
1678 .bind(project_id)
1679 .bind(connection_id.0 as i32)
1680 .fetch_one(&mut tx)
1681 .await?;
1682
1683 // Update metadata.
1684 sqlx::query(
1685 "
1686 UPDATE worktrees
1687 SET
1688 root_name = $1,
1689 scan_id = $2,
1690 is_complete = $3,
1691 abs_path = $4
1692 WHERE project_id = $5 AND id = $6
1693 RETURNING 1
1694 ",
1695 )
1696 .bind(&update.root_name)
1697 .bind(update.scan_id as i64)
1698 .bind(update.is_last_update)
1699 .bind(&update.abs_path)
1700 .bind(project_id)
1701 .bind(worktree_id)
1702 .fetch_one(&mut tx)
1703 .await?;
1704
1705 if !update.updated_entries.is_empty() {
1706 let mut params =
1707 "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),".repeat(update.updated_entries.len());
1708 params.pop();
1709
1710 let query = format!(
1711 "
1712 INSERT INTO worktree_entries (
1713 project_id,
1714 worktree_id,
1715 id,
1716 is_dir,
1717 path,
1718 inode,
1719 mtime_seconds,
1720 mtime_nanos,
1721 is_symlink,
1722 is_ignored
1723 )
1724 VALUES {params}
1725 ON CONFLICT (project_id, worktree_id, id) DO UPDATE SET
1726 is_dir = excluded.is_dir,
1727 path = excluded.path,
1728 inode = excluded.inode,
1729 mtime_seconds = excluded.mtime_seconds,
1730 mtime_nanos = excluded.mtime_nanos,
1731 is_symlink = excluded.is_symlink,
1732 is_ignored = excluded.is_ignored
1733 "
1734 );
1735 let mut query = sqlx::query(&query);
1736 for entry in &update.updated_entries {
1737 let mtime = entry.mtime.clone().unwrap_or_default();
1738 query = query
1739 .bind(project_id)
1740 .bind(worktree_id)
1741 .bind(entry.id as i64)
1742 .bind(entry.is_dir)
1743 .bind(&entry.path)
1744 .bind(entry.inode as i64)
1745 .bind(mtime.seconds as i64)
1746 .bind(mtime.nanos as i32)
1747 .bind(entry.is_symlink)
1748 .bind(entry.is_ignored);
1749 }
1750 query.execute(&mut tx).await?;
1751 }
1752
1753 if !update.removed_entries.is_empty() {
1754 let mut params = "?,".repeat(update.removed_entries.len());
1755 params.pop();
1756 let query = format!(
1757 "
1758 DELETE FROM worktree_entries
1759 WHERE project_id = ? AND worktree_id = ? AND id IN ({params})
1760 "
1761 );
1762
1763 let mut query = sqlx::query(&query).bind(project_id).bind(worktree_id);
1764 for entry_id in &update.removed_entries {
1765 query = query.bind(*entry_id as i64);
1766 }
1767 query.execute(&mut tx).await?;
1768 }
1769
1770 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1771 tx.commit().await?;
1772 Ok(connection_ids)
1773 })
1774 .await
1775 }
1776
1777 pub async fn update_diagnostic_summary(
1778 &self,
1779 update: &proto::UpdateDiagnosticSummary,
1780 connection_id: ConnectionId,
1781 ) -> Result<Vec<ConnectionId>> {
1782 self.transact(|mut tx| async {
1783 let project_id = ProjectId::from_proto(update.project_id);
1784 let worktree_id = WorktreeId::from_proto(update.worktree_id);
1785 let summary = update
1786 .summary
1787 .as_ref()
1788 .ok_or_else(|| anyhow!("invalid summary"))?;
1789
1790 // Ensure the update comes from the host.
1791 sqlx::query(
1792 "
1793 SELECT 1
1794 FROM projects
1795 WHERE id = $1 AND host_connection_id = $2
1796 ",
1797 )
1798 .bind(project_id)
1799 .bind(connection_id.0 as i32)
1800 .fetch_one(&mut tx)
1801 .await?;
1802
1803 // Update summary.
1804 sqlx::query(
1805 "
1806 INSERT INTO worktree_diagnostic_summaries (
1807 project_id,
1808 worktree_id,
1809 path,
1810 language_server_id,
1811 error_count,
1812 warning_count
1813 )
1814 VALUES ($1, $2, $3, $4, $5, $6)
1815 ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET
1816 language_server_id = excluded.language_server_id,
1817 error_count = excluded.error_count,
1818 warning_count = excluded.warning_count
1819 ",
1820 )
1821 .bind(project_id)
1822 .bind(worktree_id)
1823 .bind(&summary.path)
1824 .bind(summary.language_server_id as i64)
1825 .bind(summary.error_count as i32)
1826 .bind(summary.warning_count as i32)
1827 .execute(&mut tx)
1828 .await?;
1829
1830 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1831 tx.commit().await?;
1832 Ok(connection_ids)
1833 })
1834 .await
1835 }
1836
1837 pub async fn start_language_server(
1838 &self,
1839 update: &proto::StartLanguageServer,
1840 connection_id: ConnectionId,
1841 ) -> Result<Vec<ConnectionId>> {
1842 self.transact(|mut tx| async {
1843 let project_id = ProjectId::from_proto(update.project_id);
1844 let server = update
1845 .server
1846 .as_ref()
1847 .ok_or_else(|| anyhow!("invalid language server"))?;
1848
1849 // Ensure the update comes from the host.
1850 sqlx::query(
1851 "
1852 SELECT 1
1853 FROM projects
1854 WHERE id = $1 AND host_connection_id = $2
1855 ",
1856 )
1857 .bind(project_id)
1858 .bind(connection_id.0 as i32)
1859 .fetch_one(&mut tx)
1860 .await?;
1861
1862 // Add the newly-started language server.
1863 sqlx::query(
1864 "
1865 INSERT INTO language_servers (project_id, id, name)
1866 VALUES ($1, $2, $3)
1867 ON CONFLICT (project_id, id) DO UPDATE SET
1868 name = excluded.name
1869 ",
1870 )
1871 .bind(project_id)
1872 .bind(server.id as i64)
1873 .bind(&server.name)
1874 .execute(&mut tx)
1875 .await?;
1876
1877 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1878 tx.commit().await?;
1879 Ok(connection_ids)
1880 })
1881 .await
1882 }
1883
1884 pub async fn join_project(
1885 &self,
1886 project_id: ProjectId,
1887 connection_id: ConnectionId,
1888 ) -> Result<(Project, ReplicaId)> {
1889 self.transact(|mut tx| async move {
1890 let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
1891 "
1892 SELECT room_id, user_id
1893 FROM room_participants
1894 WHERE answering_connection_id = $1
1895 ",
1896 )
1897 .bind(connection_id.0 as i32)
1898 .fetch_one(&mut tx)
1899 .await?;
1900
1901 // Ensure project id was shared on this room.
1902 sqlx::query(
1903 "
1904 SELECT 1
1905 FROM projects
1906 WHERE id = $1 AND room_id = $2
1907 ",
1908 )
1909 .bind(project_id)
1910 .bind(room_id)
1911 .fetch_one(&mut tx)
1912 .await?;
1913
1914 let mut collaborators = sqlx::query_as::<_, ProjectCollaborator>(
1915 "
1916 SELECT *
1917 FROM project_collaborators
1918 WHERE project_id = $1
1919 ",
1920 )
1921 .bind(project_id)
1922 .fetch_all(&mut tx)
1923 .await?;
1924 let replica_ids = collaborators
1925 .iter()
1926 .map(|c| c.replica_id)
1927 .collect::<HashSet<_>>();
1928 let mut replica_id = ReplicaId(1);
1929 while replica_ids.contains(&replica_id) {
1930 replica_id.0 += 1;
1931 }
1932 let new_collaborator = ProjectCollaborator {
1933 project_id,
1934 connection_id: connection_id.0 as i32,
1935 user_id,
1936 replica_id,
1937 is_host: false,
1938 };
1939
1940 sqlx::query(
1941 "
1942 INSERT INTO project_collaborators (
1943 project_id,
1944 connection_id,
1945 user_id,
1946 replica_id,
1947 is_host
1948 )
1949 VALUES ($1, $2, $3, $4, $5)
1950 ",
1951 )
1952 .bind(new_collaborator.project_id)
1953 .bind(new_collaborator.connection_id)
1954 .bind(new_collaborator.user_id)
1955 .bind(new_collaborator.replica_id)
1956 .bind(new_collaborator.is_host)
1957 .execute(&mut tx)
1958 .await?;
1959 collaborators.push(new_collaborator);
1960
1961 let worktree_rows = sqlx::query_as::<_, WorktreeRow>(
1962 "
1963 SELECT *
1964 FROM worktrees
1965 WHERE project_id = $1
1966 ",
1967 )
1968 .bind(project_id)
1969 .fetch_all(&mut tx)
1970 .await?;
1971 let mut worktrees = worktree_rows
1972 .into_iter()
1973 .map(|worktree_row| {
1974 (
1975 worktree_row.id,
1976 Worktree {
1977 id: worktree_row.id,
1978 abs_path: worktree_row.abs_path,
1979 root_name: worktree_row.root_name,
1980 visible: worktree_row.visible,
1981 entries: Default::default(),
1982 diagnostic_summaries: Default::default(),
1983 scan_id: worktree_row.scan_id as u64,
1984 is_complete: worktree_row.is_complete,
1985 },
1986 )
1987 })
1988 .collect::<BTreeMap<_, _>>();
1989
1990 // Populate worktree entries.
1991 {
1992 let mut entries = sqlx::query_as::<_, WorktreeEntry>(
1993 "
1994 SELECT *
1995 FROM worktree_entries
1996 WHERE project_id = $1
1997 ",
1998 )
1999 .bind(project_id)
2000 .fetch(&mut tx);
2001 while let Some(entry) = entries.next().await {
2002 let entry = entry?;
2003 if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) {
2004 worktree.entries.push(proto::Entry {
2005 id: entry.id as u64,
2006 is_dir: entry.is_dir,
2007 path: entry.path,
2008 inode: entry.inode as u64,
2009 mtime: Some(proto::Timestamp {
2010 seconds: entry.mtime_seconds as u64,
2011 nanos: entry.mtime_nanos as u32,
2012 }),
2013 is_symlink: entry.is_symlink,
2014 is_ignored: entry.is_ignored,
2015 });
2016 }
2017 }
2018 }
2019
2020 // Populate worktree diagnostic summaries.
2021 {
2022 let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>(
2023 "
2024 SELECT *
2025 FROM worktree_diagnostic_summaries
2026 WHERE project_id = $1
2027 ",
2028 )
2029 .bind(project_id)
2030 .fetch(&mut tx);
2031 while let Some(summary) = summaries.next().await {
2032 let summary = summary?;
2033 if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) {
2034 worktree
2035 .diagnostic_summaries
2036 .push(proto::DiagnosticSummary {
2037 path: summary.path,
2038 language_server_id: summary.language_server_id as u64,
2039 error_count: summary.error_count as u32,
2040 warning_count: summary.warning_count as u32,
2041 });
2042 }
2043 }
2044 }
2045
2046 // Populate language servers.
2047 let language_servers = sqlx::query_as::<_, LanguageServer>(
2048 "
2049 SELECT *
2050 FROM language_servers
2051 WHERE project_id = $1
2052 ",
2053 )
2054 .bind(project_id)
2055 .fetch_all(&mut tx)
2056 .await?;
2057
2058 tx.commit().await?;
2059 Ok((
2060 Project {
2061 collaborators,
2062 worktrees,
2063 language_servers: language_servers
2064 .into_iter()
2065 .map(|language_server| proto::LanguageServer {
2066 id: language_server.id.to_proto(),
2067 name: language_server.name,
2068 })
2069 .collect(),
2070 },
2071 replica_id as ReplicaId,
2072 ))
2073 })
2074 .await
2075 }
2076
2077 pub async fn leave_project(
2078 &self,
2079 project_id: ProjectId,
2080 connection_id: ConnectionId,
2081 ) -> Result<LeftProject> {
2082 self.transact(|mut tx| async move {
2083 let result = sqlx::query(
2084 "
2085 DELETE FROM project_collaborators
2086 WHERE project_id = $1 AND connection_id = $2
2087 ",
2088 )
2089 .bind(project_id)
2090 .bind(connection_id.0 as i32)
2091 .execute(&mut tx)
2092 .await?;
2093
2094 if result.rows_affected() == 0 {
2095 Err(anyhow!("not a collaborator on this project"))?;
2096 }
2097
2098 let connection_ids = sqlx::query_scalar::<_, i32>(
2099 "
2100 SELECT connection_id
2101 FROM project_collaborators
2102 WHERE project_id = $1
2103 ",
2104 )
2105 .bind(project_id)
2106 .fetch_all(&mut tx)
2107 .await?
2108 .into_iter()
2109 .map(|id| ConnectionId(id as u32))
2110 .collect();
2111
2112 let (host_user_id, host_connection_id) = sqlx::query_as::<_, (i32, i32)>(
2113 "
2114 SELECT host_user_id, host_connection_id
2115 FROM projects
2116 WHERE id = $1
2117 ",
2118 )
2119 .bind(project_id)
2120 .fetch_one(&mut tx)
2121 .await?;
2122
2123 tx.commit().await?;
2124
2125 Ok(LeftProject {
2126 id: project_id,
2127 host_user_id: UserId(host_user_id),
2128 host_connection_id: ConnectionId(host_connection_id as u32),
2129 connection_ids,
2130 })
2131 })
2132 .await
2133 }
2134
2135 pub async fn project_collaborators(
2136 &self,
2137 project_id: ProjectId,
2138 connection_id: ConnectionId,
2139 ) -> Result<Vec<ProjectCollaborator>> {
2140 self.transact(|mut tx| async move {
2141 let collaborators = sqlx::query_as::<_, ProjectCollaborator>(
2142 "
2143 SELECT *
2144 FROM project_collaborators
2145 WHERE project_id = $1
2146 ",
2147 )
2148 .bind(project_id)
2149 .fetch_all(&mut tx)
2150 .await?;
2151
2152 if collaborators
2153 .iter()
2154 .any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
2155 {
2156 Ok(collaborators)
2157 } else {
2158 Err(anyhow!("no such project"))?
2159 }
2160 })
2161 .await
2162 }
2163
2164 pub async fn project_connection_ids(
2165 &self,
2166 project_id: ProjectId,
2167 connection_id: ConnectionId,
2168 ) -> Result<HashSet<ConnectionId>> {
2169 self.transact(|mut tx| async move {
2170 let connection_ids = sqlx::query_scalar::<_, i32>(
2171 "
2172 SELECT connection_id
2173 FROM project_collaborators
2174 WHERE project_id = $1
2175 ",
2176 )
2177 .bind(project_id)
2178 .fetch_all(&mut tx)
2179 .await?;
2180
2181 if connection_ids.contains(&(connection_id.0 as i32)) {
2182 Ok(connection_ids
2183 .into_iter()
2184 .map(|connection_id| ConnectionId(connection_id as u32))
2185 .collect())
2186 } else {
2187 Err(anyhow!("no such project"))?
2188 }
2189 })
2190 .await
2191 }
2192
2193 // contacts
2194
2195 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
2196 self.transact(|mut tx| async move {
2197 let query = "
2198 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify, (room_participants.id IS NOT NULL) as busy
2199 FROM contacts
2200 LEFT JOIN room_participants ON room_participants.user_id = $1
2201 WHERE user_id_a = $1 OR user_id_b = $1;
2202 ";
2203
2204 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool, bool)>(query)
2205 .bind(user_id)
2206 .fetch(&mut tx);
2207
2208 let mut contacts = Vec::new();
2209 while let Some(row) = rows.next().await {
2210 let (user_id_a, user_id_b, a_to_b, accepted, should_notify, busy) = row?;
2211 if user_id_a == user_id {
2212 if accepted {
2213 contacts.push(Contact::Accepted {
2214 user_id: user_id_b,
2215 should_notify: should_notify && a_to_b,
2216 busy
2217 });
2218 } else if a_to_b {
2219 contacts.push(Contact::Outgoing { user_id: user_id_b })
2220 } else {
2221 contacts.push(Contact::Incoming {
2222 user_id: user_id_b,
2223 should_notify,
2224 });
2225 }
2226 } else if accepted {
2227 contacts.push(Contact::Accepted {
2228 user_id: user_id_a,
2229 should_notify: should_notify && !a_to_b,
2230 busy
2231 });
2232 } else if a_to_b {
2233 contacts.push(Contact::Incoming {
2234 user_id: user_id_a,
2235 should_notify,
2236 });
2237 } else {
2238 contacts.push(Contact::Outgoing { user_id: user_id_a });
2239 }
2240 }
2241
2242 contacts.sort_unstable_by_key(|contact| contact.user_id());
2243
2244 Ok(contacts)
2245 })
2246 .await
2247 }
2248
2249 pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
2250 self.transact(|mut tx| async move {
2251 Ok(sqlx::query_scalar::<_, i32>(
2252 "
2253 SELECT 1
2254 FROM room_participants
2255 WHERE room_participants.user_id = $1
2256 ",
2257 )
2258 .bind(user_id)
2259 .fetch_optional(&mut tx)
2260 .await?
2261 .is_some())
2262 })
2263 .await
2264 }
2265
2266 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
2267 self.transact(|mut tx| async move {
2268 let (id_a, id_b) = if user_id_1 < user_id_2 {
2269 (user_id_1, user_id_2)
2270 } else {
2271 (user_id_2, user_id_1)
2272 };
2273
2274 let query = "
2275 SELECT 1 FROM contacts
2276 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
2277 LIMIT 1
2278 ";
2279 Ok(sqlx::query_scalar::<_, i32>(query)
2280 .bind(id_a.0)
2281 .bind(id_b.0)
2282 .fetch_optional(&mut tx)
2283 .await?
2284 .is_some())
2285 })
2286 .await
2287 }
2288
2289 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
2290 self.transact(|mut tx| async move {
2291 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
2292 (sender_id, receiver_id, true)
2293 } else {
2294 (receiver_id, sender_id, false)
2295 };
2296 let query = "
2297 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
2298 VALUES ($1, $2, $3, FALSE, TRUE)
2299 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
2300 SET
2301 accepted = TRUE,
2302 should_notify = FALSE
2303 WHERE
2304 NOT contacts.accepted AND
2305 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
2306 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
2307 ";
2308 let result = sqlx::query(query)
2309 .bind(id_a.0)
2310 .bind(id_b.0)
2311 .bind(a_to_b)
2312 .execute(&mut tx)
2313 .await?;
2314
2315 if result.rows_affected() == 1 {
2316 tx.commit().await?;
2317 Ok(())
2318 } else {
2319 Err(anyhow!("contact already requested"))?
2320 }
2321 }).await
2322 }
2323
2324 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2325 self.transact(|mut tx| async move {
2326 let (id_a, id_b) = if responder_id < requester_id {
2327 (responder_id, requester_id)
2328 } else {
2329 (requester_id, responder_id)
2330 };
2331 let query = "
2332 DELETE FROM contacts
2333 WHERE user_id_a = $1 AND user_id_b = $2;
2334 ";
2335 let result = sqlx::query(query)
2336 .bind(id_a.0)
2337 .bind(id_b.0)
2338 .execute(&mut tx)
2339 .await?;
2340
2341 if result.rows_affected() == 1 {
2342 tx.commit().await?;
2343 Ok(())
2344 } else {
2345 Err(anyhow!("no such contact"))?
2346 }
2347 })
2348 .await
2349 }
2350
2351 pub async fn dismiss_contact_notification(
2352 &self,
2353 user_id: UserId,
2354 contact_user_id: UserId,
2355 ) -> Result<()> {
2356 self.transact(|mut tx| async move {
2357 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
2358 (user_id, contact_user_id, true)
2359 } else {
2360 (contact_user_id, user_id, false)
2361 };
2362
2363 let query = "
2364 UPDATE contacts
2365 SET should_notify = FALSE
2366 WHERE
2367 user_id_a = $1 AND user_id_b = $2 AND
2368 (
2369 (a_to_b = $3 AND accepted) OR
2370 (a_to_b != $3 AND NOT accepted)
2371 );
2372 ";
2373
2374 let result = sqlx::query(query)
2375 .bind(id_a.0)
2376 .bind(id_b.0)
2377 .bind(a_to_b)
2378 .execute(&mut tx)
2379 .await?;
2380
2381 if result.rows_affected() == 0 {
2382 Err(anyhow!("no such contact request"))?
2383 } else {
2384 tx.commit().await?;
2385 Ok(())
2386 }
2387 })
2388 .await
2389 }
2390
2391 pub async fn respond_to_contact_request(
2392 &self,
2393 responder_id: UserId,
2394 requester_id: UserId,
2395 accept: bool,
2396 ) -> Result<()> {
2397 self.transact(|mut tx| async move {
2398 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
2399 (responder_id, requester_id, false)
2400 } else {
2401 (requester_id, responder_id, true)
2402 };
2403 let result = if accept {
2404 let query = "
2405 UPDATE contacts
2406 SET accepted = TRUE, should_notify = TRUE
2407 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
2408 ";
2409 sqlx::query(query)
2410 .bind(id_a.0)
2411 .bind(id_b.0)
2412 .bind(a_to_b)
2413 .execute(&mut tx)
2414 .await?
2415 } else {
2416 let query = "
2417 DELETE FROM contacts
2418 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
2419 ";
2420 sqlx::query(query)
2421 .bind(id_a.0)
2422 .bind(id_b.0)
2423 .bind(a_to_b)
2424 .execute(&mut tx)
2425 .await?
2426 };
2427 if result.rows_affected() == 1 {
2428 tx.commit().await?;
2429 Ok(())
2430 } else {
2431 Err(anyhow!("no such contact request"))?
2432 }
2433 })
2434 .await
2435 }
2436
2437 // access tokens
2438
2439 pub async fn create_access_token_hash(
2440 &self,
2441 user_id: UserId,
2442 access_token_hash: &str,
2443 max_access_token_count: usize,
2444 ) -> Result<()> {
2445 self.transact(|tx| async {
2446 let mut tx = tx;
2447 let insert_query = "
2448 INSERT INTO access_tokens (user_id, hash)
2449 VALUES ($1, $2);
2450 ";
2451 let cleanup_query = "
2452 DELETE FROM access_tokens
2453 WHERE id IN (
2454 SELECT id from access_tokens
2455 WHERE user_id = $1
2456 ORDER BY id DESC
2457 LIMIT 10000
2458 OFFSET $3
2459 )
2460 ";
2461
2462 sqlx::query(insert_query)
2463 .bind(user_id.0)
2464 .bind(access_token_hash)
2465 .execute(&mut tx)
2466 .await?;
2467 sqlx::query(cleanup_query)
2468 .bind(user_id.0)
2469 .bind(access_token_hash)
2470 .bind(max_access_token_count as i32)
2471 .execute(&mut tx)
2472 .await?;
2473 Ok(tx.commit().await?)
2474 })
2475 .await
2476 }
2477
2478 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
2479 self.transact(|mut tx| async move {
2480 let query = "
2481 SELECT hash
2482 FROM access_tokens
2483 WHERE user_id = $1
2484 ORDER BY id DESC
2485 ";
2486 Ok(sqlx::query_scalar(query)
2487 .bind(user_id.0)
2488 .fetch_all(&mut tx)
2489 .await?)
2490 })
2491 .await
2492 }
2493
2494 async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
2495 where
2496 F: Send + Fn(sqlx::Transaction<'static, D>) -> Fut,
2497 Fut: Send + Future<Output = Result<T>>,
2498 {
2499 let body = async {
2500 loop {
2501 let tx = self.begin_transaction().await?;
2502 match f(tx).await {
2503 Ok(result) => return Ok(result),
2504 Err(error) => match error {
2505 Error::Database(error)
2506 if error
2507 .as_database_error()
2508 .and_then(|error| error.code())
2509 .as_deref()
2510 == Some("hey") =>
2511 {
2512 // Retry (don't break the loop)
2513 }
2514 error @ _ => return Err(error),
2515 },
2516 }
2517 }
2518 };
2519
2520 #[cfg(test)]
2521 {
2522 if let Some(background) = self.background.as_ref() {
2523 background.simulate_random_delay().await;
2524 }
2525
2526 let result = self.runtime.as_ref().unwrap().block_on(body);
2527
2528 if let Some(background) = self.background.as_ref() {
2529 background.simulate_random_delay().await;
2530 }
2531
2532 result
2533 }
2534
2535 #[cfg(not(test))]
2536 {
2537 body.await
2538 }
2539 }
2540}
2541
// Declares a newtype id `$name(pub i32)`. It is stored transparently in the
// database (`#[sqlx(transparent)]`) and serialized transparently with serde,
// and it gets protobuf conversions plus a `Display` impl that prints the
// inner integer.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            // Largest representable id. `#[allow(unused)]` presumably because
            // not every generated id type uses every helper.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // Converts a protobuf `u64` id. NOTE(review): `as i32` silently
            // truncates values above `i32::MAX`; confirm ids always fit.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // Converts to the protobuf `u64` representation. NOTE(review): a
            // negative inner value would wrap to a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            // Formats exactly like the inner `i32`.
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }
    };
}
2584
// Id of a `User`.
id_type!(UserId);
/// A user record decoded from a database row via `FromRow`.
///
/// NOTE(review): presumably a row of the `users` table; the queries that load
/// it are outside this view. `Serialize` emits fields in declaration order,
/// so reordering fields changes the serialized output.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    pub invite_code: Option<String>,
    pub invite_count: i32,
    pub connected_once: bool,
}
2597
// Newtype id for rooms (see `id_type!`).
id_type!(RoomId);
/// A room record as loaded from the database via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Room {
    pub id: RoomId,
    /// NOTE(review): presumably incremented whenever the room's state
    /// changes — confirm against the update queries.
    pub version: i32,
    /// Name of the LiveKit room associated with this room.
    pub live_kit_room: String,
}
2605
// Newtype id for projects (see `id_type!`).
id_type!(ProjectId);
/// An assembled view of a project: its collaborators, worktrees (keyed by
/// id, in sorted order) and language servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<WorktreeId, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
2612
// Newtype id for replicas (see `id_type!`).
id_type!(ReplicaId);
/// A single participant in a project, as loaded via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
pub struct ProjectCollaborator {
    pub project_id: ProjectId,
    /// Raw connection id, stored as an `i32` rather than `rpc::ConnectionId`.
    pub connection_id: i32,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    /// Whether this collaborator is the project's host.
    pub is_host: bool,
}
2622
// Newtype id for worktrees (see `id_type!`).
id_type!(WorktreeId);
/// A raw worktree row as loaded via `FromRow`, without entries or
/// diagnostics (compare `Worktree`).
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeRow {
    pub id: WorktreeId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    /// Scan id as stored in the database (`i64`); `Worktree` exposes it as `u64`.
    pub scan_id: i64,
    pub is_complete: bool,
}
2633
/// A worktree together with its entries and diagnostic summaries, in their
/// protobuf representations.
pub struct Worktree {
    pub id: WorktreeId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    /// Scan id; `WorktreeRow` stores the same value as an `i64`.
    pub scan_id: u64,
    pub is_complete: bool,
}
2644
/// A single file-system entry of a worktree, as loaded via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeEntry {
    /// NOTE(review): entry identifier within the worktree — confirm its origin.
    id: i64,
    worktree_id: WorktreeId,
    is_dir: bool,
    /// Path of the entry; presumably relative to the worktree root — confirm.
    path: String,
    inode: i64,
    /// Modification time, split into whole seconds and a nanosecond part.
    mtime_seconds: i64,
    mtime_nanos: i32,
    is_symlink: bool,
    is_ignored: bool,
}
2657
/// Per-path diagnostic counts reported by one language server, as loaded via
/// `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeDiagnosticSummary {
    worktree_id: WorktreeId,
    path: String,
    /// Raw language-server id (not the `LanguageServerId` newtype).
    language_server_id: i64,
    error_count: i32,
    warning_count: i32,
}
2666
// Newtype id for language servers (see `id_type!`).
id_type!(LanguageServerId);
/// A language server row as loaded via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct LanguageServer {
    id: LanguageServerId,
    name: String,
}
2673
/// Information about a project that a connection has left.
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    /// NOTE(review): presumably the connections still in the project that
    /// should be notified of the departure — confirm against callers.
    pub connection_ids: Vec<ConnectionId>,
}
2680
/// Result of leaving a room: the room's resulting state, the projects that
/// were left along with it (keyed by project id), and the users whose
/// pending calls were canceled.
pub struct LeftRoom {
    pub room: proto::Room,
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
2686
/// A contact relationship, seen from one user's point of view.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// An accepted contact.
    Accepted {
        user_id: UserId,
        should_notify: bool,
        /// NOTE(review): presumably whether the contact is currently busy
        /// (e.g. in a call) — confirm.
        busy: bool,
    },
    /// A request this user sent that has not been answered.
    Outgoing {
        user_id: UserId,
    },
    /// A request this user received that has not been answered.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
2702
2703impl Contact {
2704 pub fn user_id(&self) -> UserId {
2705 match self {
2706 Contact::Accepted { user_id, .. } => *user_id,
2707 Contact::Outgoing { user_id } => *user_id,
2708 Contact::Incoming { user_id, .. } => *user_id,
2709 }
2710 }
2711}
2712
/// A pending contact request addressed to the current user.
///
/// The derived `Ord` compares `requester_id` first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
2718
/// A signup submission, deserialized from an incoming payload.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    /// Platforms the signer-up selected.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    /// NOTE(review): presumably an identifier for the device that submitted
    /// the signup — confirm.
    pub device_id: Option<String>,
}
2729
/// Aggregate counts over the waitlist.
///
/// Every field is `#[sqlx(default)]`, so a query may omit any of these
/// columns and the missing ones fall back to `0`.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
2743
/// An invite: the invitee's email address and the confirmation code
/// associated with it.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
2749
/// Parameters for creating a new user account.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    /// Initial value for the new user's `invite_count`.
    pub invite_count: i32,
}
2756
/// Outcome of creating a new user account.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// The user whose invite led to this account, if any.
    pub inviting_user_id: Option<UserId>,
    /// NOTE(review): presumably the `device_id` recorded on the originating
    /// signup — confirm.
    pub signup_device_id: Option<String>,
}
2764
2765fn random_invite_code() -> String {
2766 nanoid::nanoid!(16)
2767}
2768
2769fn random_email_confirmation_code() -> String {
2770 nanoid::nanoid!(64)
2771}
2772
2773#[cfg(test)]
2774pub use test::*;
2775
2776#[cfg(test)]
2777mod test {
2778 use super::*;
2779 use gpui::executor::Background;
2780 use lazy_static::lazy_static;
2781 use parking_lot::Mutex;
2782 use rand::prelude::*;
2783 use sqlx::migrate::MigrateDatabase;
2784 use std::sync::Arc;
2785
2786 pub struct SqliteTestDb {
2787 pub db: Option<Arc<Db<sqlx::Sqlite>>>,
2788 pub conn: sqlx::sqlite::SqliteConnection,
2789 }
2790
2791 pub struct PostgresTestDb {
2792 pub db: Option<Arc<Db<sqlx::Postgres>>>,
2793 pub url: String,
2794 }
2795
2796 impl SqliteTestDb {
2797 pub fn new(background: Arc<Background>) -> Self {
2798 let mut rng = StdRng::from_entropy();
2799 let url = format!("file:zed-test-{}?mode=memory", rng.gen::<u128>());
2800 let runtime = tokio::runtime::Builder::new_current_thread()
2801 .enable_io()
2802 .enable_time()
2803 .build()
2804 .unwrap();
2805
2806 let (mut db, conn) = runtime.block_on(async {
2807 let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
2808 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
2809 db.migrate(migrations_path.as_ref(), false).await.unwrap();
2810 let conn = db.pool.acquire().await.unwrap().detach();
2811 (db, conn)
2812 });
2813
2814 db.background = Some(background);
2815 db.runtime = Some(runtime);
2816
2817 Self {
2818 db: Some(Arc::new(db)),
2819 conn,
2820 }
2821 }
2822
2823 pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
2824 self.db.as_ref().unwrap()
2825 }
2826 }
2827
2828 impl PostgresTestDb {
2829 pub fn new(background: Arc<Background>) -> Self {
2830 lazy_static! {
2831 static ref LOCK: Mutex<()> = Mutex::new(());
2832 }
2833
2834 let _guard = LOCK.lock();
2835 let mut rng = StdRng::from_entropy();
2836 let url = format!(
2837 "postgres://postgres@localhost/zed-test-{}",
2838 rng.gen::<u128>()
2839 );
2840 let runtime = tokio::runtime::Builder::new_current_thread()
2841 .enable_io()
2842 .enable_time()
2843 .build()
2844 .unwrap();
2845
2846 let mut db = runtime.block_on(async {
2847 sqlx::Postgres::create_database(&url)
2848 .await
2849 .expect("failed to create test db");
2850 let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
2851 let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
2852 db.migrate(Path::new(migrations_path), false).await.unwrap();
2853 db
2854 });
2855
2856 db.background = Some(background);
2857 db.runtime = Some(runtime);
2858
2859 Self {
2860 db: Some(Arc::new(db)),
2861 url,
2862 }
2863 }
2864
2865 pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
2866 self.db.as_ref().unwrap()
2867 }
2868 }
2869
2870 impl Drop for PostgresTestDb {
2871 fn drop(&mut self) {
2872 let db = self.db.take().unwrap();
2873 db.teardown(&self.url);
2874 }
2875 }
2876}