1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::{BTreeMap, HashMap, HashSet};
5use futures::{future::BoxFuture, FutureExt, StreamExt};
6use rpc::{proto, ConnectionId};
7use serde::{Deserialize, Serialize};
8use sqlx::{
9 migrate::{Migrate as _, Migration, MigrationSource},
10 types::Uuid,
11 FromRow,
12};
13use std::{future::Future, path::Path, time::Duration};
14use time::{OffsetDateTime, PrimitiveDateTime};
15
/// Database backend used by the server in test builds: SQLite.
#[cfg(test)]
pub type DefaultDb = Db<sqlx::Sqlite>;

/// Database backend used by the server outside of test builds: Postgres.
#[cfg(not(test))]
pub type DefaultDb = Db<sqlx::Postgres>;
21
/// A connection pool for the sqlx backend `D`, plus test-only executor state.
pub struct Db<D: sqlx::Database> {
    /// Pool from which connections and transactions are acquired.
    pool: sqlx::Pool<D>,
    // Test-only background executor handle. The constructors in this file set it
    // to `None`; presumably it is filled in by test harness code elsewhere.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    // Test-only runtime; `teardown` blocks on it to run async cleanup.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
29
/// Opens a database transaction with serializable semantics.
///
/// Implemented per backend: Postgres must opt in to serializable isolation
/// explicitly, while SQLite transactions already behave serializably.
pub trait BeginTransaction: Send + Sync {
    /// The sqlx backend the transaction runs on.
    type Database: sqlx::Database;

    /// Begins a new transaction on a pooled connection.
    fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, Self::Database>>>;
}
35
36// In Postgres, serializable transactions are opt-in
37impl BeginTransaction for Db<sqlx::Postgres> {
38 type Database = sqlx::Postgres;
39
40 fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Postgres>>> {
41 async move {
42 let mut tx = self.pool.begin().await?;
43 sqlx::Executor::execute(&mut tx, "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;")
44 .await?;
45 Ok(tx)
46 }
47 .boxed()
48 }
49}
50
51// In Sqlite, transactions are inherently serializable.
52impl BeginTransaction for Db<sqlx::Sqlite> {
53 type Database = sqlx::Sqlite;
54
55 fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Sqlite>>> {
56 async move { Ok(self.pool.begin().await?) }.boxed()
57 }
58}
59
/// Backend-agnostic access to the number of rows a statement touched, so the
/// generic `Db<D>` code can bound on `D::QueryResult: RowsAffected`.
pub trait RowsAffected {
    /// Number of rows affected by the executed statement.
    fn rows_affected(&self) -> u64;
}
63
#[cfg(test)]
impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    fn rows_affected(&self) -> u64 {
        // Forward to the inherent method. Inherent items win over trait items
        // in path resolution, so this is not a recursive call.
        sqlx::sqlite::SqliteQueryResult::rows_affected(self)
    }
}
70
71impl RowsAffected for sqlx::postgres::PgQueryResult {
72 fn rows_affected(&self) -> u64 {
73 self.rows_affected()
74 }
75}
76
#[cfg(test)]
impl Db<sqlx::Sqlite> {
    /// Opens a SQLite pool at `url`, creating the database if it is missing and
    /// enabling shared-cache mode.
    ///
    /// # Errors
    ///
    /// Returns an error if `url` cannot be parsed as SQLite connect options or
    /// if the pool fails to connect.
    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
        use std::str::FromStr as _;
        // Report a malformed URL through `Result` instead of panicking.
        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
            .create_if_missing(true)
            .shared_cache(true);
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(2)
            .max_connections(max_connections)
            .connect_with(options)
            .await?;
        Ok(Self {
            pool,
            background: None,
            runtime: None,
        })
    }

    /// Fetches every user whose id is in `ids`.
    ///
    /// The ids are bound as a single JSON array and expanded with `json_each`,
    /// because this backend cannot use the `= ANY($1)` array form the Postgres
    /// implementation relies on.
    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
        self.transact(|tx| async {
            let mut tx = tx;
            let query = "
                SELECT users.*
                FROM users
                WHERE users.id IN (SELECT value from json_each($1))
            ";
            Ok(sqlx::query_as(query)
                .bind(&serde_json::json!(ids))
                .fetch_all(&mut tx)
                .await?)
        })
        .await
    }

    /// Returns the stored `metrics_id` of user `id`.
    ///
    /// On this backend the value is a UUID string generated by `create_user`.
    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
        self.transact(|mut tx| async move {
            let query = "
                SELECT metrics_id
                FROM users
                WHERE id = $1
            ";
            Ok(sqlx::query_scalar(query)
                .bind(id)
                .fetch_one(&mut tx)
                .await?)
        })
        .await
    }

    /// Inserts a user, or returns the existing row if `github_login` is taken.
    ///
    /// The `ON CONFLICT ... DO UPDATE` is a no-op update whose purpose is to
    /// make `RETURNING` yield the existing row's id and metrics id on conflict.
    pub async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult> {
        self.transact(|mut tx| async {
            let query = "
                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id, metrics_id
            ";

            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
                .bind(email_address)
                .bind(&params.github_login)
                .bind(&params.github_user_id)
                .bind(admin)
                // SQLite has no server-side UUID default here, so generate one.
                .bind(Uuid::new_v4().to_string())
                .fetch_one(&mut tx)
                .await?;
            tx.commit().await?;
            Ok(NewUserResult {
                user_id,
                metrics_id,
                signup_device_id: None,
                inviting_user_id: None,
            })
        })
        .await
    }

    /// Not supported on the SQLite test backend; panics if called.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend; panics if called.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend; panics if called.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend; panics if called.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }

    /// Not supported on the SQLite test backend; panics if called.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
}
190
191impl Db<sqlx::Postgres> {
192 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
193 let pool = sqlx::postgres::PgPoolOptions::new()
194 .max_connections(max_connections)
195 .connect(url)
196 .await?;
197 Ok(Self {
198 pool,
199 #[cfg(test)]
200 background: None,
201 #[cfg(test)]
202 runtime: None,
203 })
204 }
205
206 #[cfg(test)]
207 pub fn teardown(&self, url: &str) {
208 self.runtime.as_ref().unwrap().block_on(async {
209 use util::ResultExt;
210 let query = "
211 SELECT pg_terminate_backend(pg_stat_activity.pid)
212 FROM pg_stat_activity
213 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
214 ";
215 sqlx::query(query).execute(&self.pool).await.log_err();
216 self.pool.close().await;
217 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
218 .await
219 .log_err();
220 })
221 }
222
223 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
224 self.transact(|tx| async {
225 let mut tx = tx;
226 let like_string = Self::fuzzy_like_string(name_query);
227 let query = "
228 SELECT users.*
229 FROM users
230 WHERE github_login ILIKE $1
231 ORDER BY github_login <-> $2
232 LIMIT $3
233 ";
234 Ok(sqlx::query_as(query)
235 .bind(like_string)
236 .bind(name_query)
237 .bind(limit as i32)
238 .fetch_all(&mut tx)
239 .await?)
240 })
241 .await
242 }
243
244 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
245 let ids = ids.iter().map(|id| id.0).collect::<Vec<_>>();
246 self.transact(|tx| async {
247 let mut tx = tx;
248 let query = "
249 SELECT users.*
250 FROM users
251 WHERE users.id = ANY ($1)
252 ";
253 Ok(sqlx::query_as(query).bind(&ids).fetch_all(&mut tx).await?)
254 })
255 .await
256 }
257
258 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
259 self.transact(|mut tx| async move {
260 let query = "
261 SELECT metrics_id::text
262 FROM users
263 WHERE id = $1
264 ";
265 Ok(sqlx::query_scalar(query)
266 .bind(id)
267 .fetch_one(&mut tx)
268 .await?)
269 })
270 .await
271 }
272
273 pub async fn create_user(
274 &self,
275 email_address: &str,
276 admin: bool,
277 params: NewUserParams,
278 ) -> Result<NewUserResult> {
279 self.transact(|mut tx| async {
280 let query = "
281 INSERT INTO users (email_address, github_login, github_user_id, admin)
282 VALUES ($1, $2, $3, $4)
283 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
284 RETURNING id, metrics_id::text
285 ";
286
287 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
288 .bind(email_address)
289 .bind(¶ms.github_login)
290 .bind(params.github_user_id)
291 .bind(admin)
292 .fetch_one(&mut tx)
293 .await?;
294 tx.commit().await?;
295
296 Ok(NewUserResult {
297 user_id,
298 metrics_id,
299 signup_device_id: None,
300 inviting_user_id: None,
301 })
302 })
303 .await
304 }
305
306 pub async fn create_user_from_invite(
307 &self,
308 invite: &Invite,
309 user: NewUserParams,
310 ) -> Result<Option<NewUserResult>> {
311 self.transact(|mut tx| async {
312 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
313 i32,
314 Option<UserId>,
315 Option<UserId>,
316 Option<String>,
317 ) = sqlx::query_as(
318 "
319 SELECT id, user_id, inviting_user_id, device_id
320 FROM signups
321 WHERE
322 email_address = $1 AND
323 email_confirmation_code = $2
324 ",
325 )
326 .bind(&invite.email_address)
327 .bind(&invite.email_confirmation_code)
328 .fetch_optional(&mut tx)
329 .await?
330 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
331
332 if existing_user_id.is_some() {
333 return Ok(None);
334 }
335
336 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
337 "
338 INSERT INTO users
339 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
340 VALUES
341 ($1, $2, $3, FALSE, $4, $5)
342 ON CONFLICT (github_login) DO UPDATE SET
343 email_address = excluded.email_address,
344 github_user_id = excluded.github_user_id,
345 admin = excluded.admin
346 RETURNING id, metrics_id::text
347 ",
348 )
349 .bind(&invite.email_address)
350 .bind(&user.github_login)
351 .bind(&user.github_user_id)
352 .bind(&user.invite_count)
353 .bind(random_invite_code())
354 .fetch_one(&mut tx)
355 .await?;
356
357 sqlx::query(
358 "
359 UPDATE signups
360 SET user_id = $1
361 WHERE id = $2
362 ",
363 )
364 .bind(&user_id)
365 .bind(&signup_id)
366 .execute(&mut tx)
367 .await?;
368
369 if let Some(inviting_user_id) = inviting_user_id {
370 let id: Option<UserId> = sqlx::query_scalar(
371 "
372 UPDATE users
373 SET invite_count = invite_count - 1
374 WHERE id = $1 AND invite_count > 0
375 RETURNING id
376 ",
377 )
378 .bind(&inviting_user_id)
379 .fetch_optional(&mut tx)
380 .await?;
381
382 if id.is_none() {
383 Err(Error::Http(
384 StatusCode::UNAUTHORIZED,
385 "no invites remaining".to_string(),
386 ))?;
387 }
388
389 sqlx::query(
390 "
391 INSERT INTO contacts
392 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
393 VALUES
394 ($1, $2, TRUE, TRUE, TRUE)
395 ON CONFLICT DO NOTHING
396 ",
397 )
398 .bind(inviting_user_id)
399 .bind(user_id)
400 .execute(&mut tx)
401 .await?;
402 }
403
404 tx.commit().await?;
405 Ok(Some(NewUserResult {
406 user_id,
407 metrics_id,
408 inviting_user_id,
409 signup_device_id,
410 }))
411 })
412 .await
413 }
414
415 pub async fn create_signup(&self, signup: Signup) -> Result<()> {
416 self.transact(|mut tx| async {
417 sqlx::query(
418 "
419 INSERT INTO signups
420 (
421 email_address,
422 email_confirmation_code,
423 email_confirmation_sent,
424 platform_linux,
425 platform_mac,
426 platform_windows,
427 platform_unknown,
428 editor_features,
429 programming_languages,
430 device_id
431 )
432 VALUES
433 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
434 RETURNING id
435 ",
436 )
437 .bind(&signup.email_address)
438 .bind(&random_email_confirmation_code())
439 .bind(&signup.platform_linux)
440 .bind(&signup.platform_mac)
441 .bind(&signup.platform_windows)
442 .bind(&signup.editor_features)
443 .bind(&signup.programming_languages)
444 .bind(&signup.device_id)
445 .execute(&mut tx)
446 .await?;
447 tx.commit().await?;
448 Ok(())
449 })
450 .await
451 }
452
453 pub async fn create_invite_from_code(
454 &self,
455 code: &str,
456 email_address: &str,
457 device_id: Option<&str>,
458 ) -> Result<Invite> {
459 self.transact(|mut tx| async {
460 let existing_user: Option<UserId> = sqlx::query_scalar(
461 "
462 SELECT id
463 FROM users
464 WHERE email_address = $1
465 ",
466 )
467 .bind(email_address)
468 .fetch_optional(&mut tx)
469 .await?;
470 if existing_user.is_some() {
471 Err(anyhow!("email address is already in use"))?;
472 }
473
474 let row: Option<(UserId, i32)> = sqlx::query_as(
475 "
476 SELECT id, invite_count
477 FROM users
478 WHERE invite_code = $1
479 ",
480 )
481 .bind(code)
482 .fetch_optional(&mut tx)
483 .await?;
484
485 let (inviter_id, invite_count) = match row {
486 Some(row) => row,
487 None => Err(Error::Http(
488 StatusCode::NOT_FOUND,
489 "invite code not found".to_string(),
490 ))?,
491 };
492
493 if invite_count == 0 {
494 Err(Error::Http(
495 StatusCode::UNAUTHORIZED,
496 "no invites remaining".to_string(),
497 ))?;
498 }
499
500 let email_confirmation_code: String = sqlx::query_scalar(
501 "
502 INSERT INTO signups
503 (
504 email_address,
505 email_confirmation_code,
506 email_confirmation_sent,
507 inviting_user_id,
508 platform_linux,
509 platform_mac,
510 platform_windows,
511 platform_unknown,
512 device_id
513 )
514 VALUES
515 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
516 ON CONFLICT (email_address)
517 DO UPDATE SET
518 inviting_user_id = excluded.inviting_user_id
519 RETURNING email_confirmation_code
520 ",
521 )
522 .bind(&email_address)
523 .bind(&random_email_confirmation_code())
524 .bind(&inviter_id)
525 .bind(&device_id)
526 .fetch_one(&mut tx)
527 .await?;
528
529 tx.commit().await?;
530
531 Ok(Invite {
532 email_address: email_address.into(),
533 email_confirmation_code,
534 })
535 })
536 .await
537 }
538
539 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
540 self.transact(|mut tx| async {
541 let emails = invites
542 .iter()
543 .map(|s| s.email_address.as_str())
544 .collect::<Vec<_>>();
545 sqlx::query(
546 "
547 UPDATE signups
548 SET email_confirmation_sent = TRUE
549 WHERE email_address = ANY ($1)
550 ",
551 )
552 .bind(&emails)
553 .execute(&mut tx)
554 .await?;
555 tx.commit().await?;
556 Ok(())
557 })
558 .await
559 }
560}
561
562impl<D> Db<D>
563where
564 Self: BeginTransaction<Database = D>,
565 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
566 D::Connection: sqlx::migrate::Migrate,
567 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
568 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
569 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
570 D::QueryResult: RowsAffected,
571 String: sqlx::Type<D>,
572 i32: sqlx::Type<D>,
573 i64: sqlx::Type<D>,
574 bool: sqlx::Type<D>,
575 str: sqlx::Type<D>,
576 Uuid: sqlx::Type<D>,
577 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
578 OffsetDateTime: sqlx::Type<D>,
579 PrimitiveDateTime: sqlx::Type<D>,
580 usize: sqlx::ColumnIndex<D::Row>,
581 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
582 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
583 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
584 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
585 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
586 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
587 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
588 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
589 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
590 for<'a> Option<ProjectId>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
591 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
592 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
593 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
594{
595 pub async fn migrate(
596 &self,
597 migrations_path: &Path,
598 ignore_checksum_mismatch: bool,
599 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
600 let migrations = MigrationSource::resolve(migrations_path)
601 .await
602 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
603
604 let mut conn = self.pool.acquire().await?;
605
606 conn.ensure_migrations_table().await?;
607 let applied_migrations: HashMap<_, _> = conn
608 .list_applied_migrations()
609 .await?
610 .into_iter()
611 .map(|m| (m.version, m))
612 .collect();
613
614 let mut new_migrations = Vec::new();
615 for migration in migrations {
616 match applied_migrations.get(&migration.version) {
617 Some(applied_migration) => {
618 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
619 {
620 Err(anyhow!(
621 "checksum mismatch for applied migration {}",
622 migration.description
623 ))?;
624 }
625 }
626 None => {
627 let elapsed = conn.apply(&migration).await?;
628 new_migrations.push((migration, elapsed));
629 }
630 }
631 }
632
633 Ok(new_migrations)
634 }
635
636 pub fn fuzzy_like_string(string: &str) -> String {
637 let mut result = String::with_capacity(string.len() * 2 + 1);
638 for c in string.chars() {
639 if c.is_alphanumeric() {
640 result.push('%');
641 result.push(c);
642 }
643 }
644 result.push('%');
645 result
646 }
647
648 // users
649
650 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
651 self.transact(|tx| async {
652 let mut tx = tx;
653 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
654 Ok(sqlx::query_as(query)
655 .bind(limit as i32)
656 .bind((page * limit) as i32)
657 .fetch_all(&mut tx)
658 .await?)
659 })
660 .await
661 }
662
663 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
664 self.transact(|tx| async {
665 let mut tx = tx;
666 let query = "
667 SELECT users.*
668 FROM users
669 WHERE id = $1
670 LIMIT 1
671 ";
672 Ok(sqlx::query_as(query)
673 .bind(&id)
674 .fetch_optional(&mut tx)
675 .await?)
676 })
677 .await
678 }
679
680 pub async fn get_users_with_no_invites(
681 &self,
682 invited_by_another_user: bool,
683 ) -> Result<Vec<User>> {
684 self.transact(|tx| async {
685 let mut tx = tx;
686 let query = format!(
687 "
688 SELECT users.*
689 FROM users
690 WHERE invite_count = 0
691 AND inviter_id IS{} NULL
692 ",
693 if invited_by_another_user { " NOT" } else { "" }
694 );
695
696 Ok(sqlx::query_as(&query).fetch_all(&mut tx).await?)
697 })
698 .await
699 }
700
701 pub async fn get_user_by_github_account(
702 &self,
703 github_login: &str,
704 github_user_id: Option<i32>,
705 ) -> Result<Option<User>> {
706 self.transact(|tx| async {
707 let mut tx = tx;
708 if let Some(github_user_id) = github_user_id {
709 let mut user = sqlx::query_as::<_, User>(
710 "
711 UPDATE users
712 SET github_login = $1
713 WHERE github_user_id = $2
714 RETURNING *
715 ",
716 )
717 .bind(github_login)
718 .bind(github_user_id)
719 .fetch_optional(&mut tx)
720 .await?;
721
722 if user.is_none() {
723 user = sqlx::query_as::<_, User>(
724 "
725 UPDATE users
726 SET github_user_id = $1
727 WHERE github_login = $2
728 RETURNING *
729 ",
730 )
731 .bind(github_user_id)
732 .bind(github_login)
733 .fetch_optional(&mut tx)
734 .await?;
735 }
736
737 Ok(user)
738 } else {
739 let user = sqlx::query_as(
740 "
741 SELECT * FROM users
742 WHERE github_login = $1
743 LIMIT 1
744 ",
745 )
746 .bind(github_login)
747 .fetch_optional(&mut tx)
748 .await?;
749 Ok(user)
750 }
751 })
752 .await
753 }
754
755 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
756 self.transact(|mut tx| async {
757 let query = "UPDATE users SET admin = $1 WHERE id = $2";
758 sqlx::query(query)
759 .bind(is_admin)
760 .bind(id.0)
761 .execute(&mut tx)
762 .await?;
763 tx.commit().await?;
764 Ok(())
765 })
766 .await
767 }
768
769 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
770 self.transact(|mut tx| async move {
771 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
772 sqlx::query(query)
773 .bind(connected_once)
774 .bind(id.0)
775 .execute(&mut tx)
776 .await?;
777 tx.commit().await?;
778 Ok(())
779 })
780 .await
781 }
782
783 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
784 self.transact(|mut tx| async move {
785 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
786 sqlx::query(query)
787 .bind(id.0)
788 .execute(&mut tx)
789 .await
790 .map(drop)?;
791 let query = "DELETE FROM users WHERE id = $1;";
792 sqlx::query(query).bind(id.0).execute(&mut tx).await?;
793 tx.commit().await?;
794 Ok(())
795 })
796 .await
797 }
798
799 // signups
800
801 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
802 self.transact(|mut tx| async move {
803 Ok(sqlx::query_as(
804 "
805 SELECT
806 COUNT(*) as count,
807 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
808 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
809 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
810 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
811 FROM (
812 SELECT *
813 FROM signups
814 WHERE
815 NOT email_confirmation_sent
816 ) AS unsent
817 ",
818 )
819 .fetch_one(&mut tx)
820 .await?)
821 })
822 .await
823 }
824
825 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
826 self.transact(|mut tx| async move {
827 Ok(sqlx::query_as(
828 "
829 SELECT
830 email_address, email_confirmation_code
831 FROM signups
832 WHERE
833 NOT email_confirmation_sent AND
834 (platform_mac OR platform_unknown)
835 LIMIT $1
836 ",
837 )
838 .bind(count as i32)
839 .fetch_all(&mut tx)
840 .await?)
841 })
842 .await
843 }
844
845 // invite codes
846
847 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
848 self.transact(|mut tx| async move {
849 if count > 0 {
850 sqlx::query(
851 "
852 UPDATE users
853 SET invite_code = $1
854 WHERE id = $2 AND invite_code IS NULL
855 ",
856 )
857 .bind(random_invite_code())
858 .bind(id)
859 .execute(&mut tx)
860 .await?;
861 }
862
863 sqlx::query(
864 "
865 UPDATE users
866 SET invite_count = $1
867 WHERE id = $2
868 ",
869 )
870 .bind(count as i32)
871 .bind(id)
872 .execute(&mut tx)
873 .await?;
874 tx.commit().await?;
875 Ok(())
876 })
877 .await
878 }
879
880 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
881 self.transact(|mut tx| async move {
882 let result: Option<(String, i32)> = sqlx::query_as(
883 "
884 SELECT invite_code, invite_count
885 FROM users
886 WHERE id = $1 AND invite_code IS NOT NULL
887 ",
888 )
889 .bind(id)
890 .fetch_optional(&mut tx)
891 .await?;
892 if let Some((code, count)) = result {
893 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
894 } else {
895 Ok(None)
896 }
897 })
898 .await
899 }
900
901 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
902 self.transact(|tx| async {
903 let mut tx = tx;
904 sqlx::query_as(
905 "
906 SELECT *
907 FROM users
908 WHERE invite_code = $1
909 ",
910 )
911 .bind(code)
912 .fetch_optional(&mut tx)
913 .await?
914 .ok_or_else(|| {
915 Error::Http(
916 StatusCode::NOT_FOUND,
917 "that invite code does not exist".to_string(),
918 )
919 })
920 })
921 .await
922 }
923
924 pub async fn create_room(
925 &self,
926 user_id: UserId,
927 connection_id: ConnectionId,
928 ) -> Result<proto::Room> {
929 self.transact(|mut tx| async move {
930 let live_kit_room = nanoid::nanoid!(30);
931 let room_id = sqlx::query_scalar(
932 "
933 INSERT INTO rooms (live_kit_room, version)
934 VALUES ($1, $2)
935 RETURNING id
936 ",
937 )
938 .bind(&live_kit_room)
939 .bind(0)
940 .fetch_one(&mut tx)
941 .await
942 .map(RoomId)?;
943
944 sqlx::query(
945 "
946 INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id)
947 VALUES ($1, $2, $3, $4, $5)
948 ",
949 )
950 .bind(room_id)
951 .bind(user_id)
952 .bind(connection_id.0 as i32)
953 .bind(user_id)
954 .bind(connection_id.0 as i32)
955 .execute(&mut tx)
956 .await?;
957
958 self.commit_room_transaction(room_id, tx).await
959 }).await
960 }
961
962 pub async fn call(
963 &self,
964 room_id: RoomId,
965 calling_user_id: UserId,
966 calling_connection_id: ConnectionId,
967 called_user_id: UserId,
968 initial_project_id: Option<ProjectId>,
969 ) -> Result<(proto::Room, proto::IncomingCall)> {
970 self.transact(|mut tx| async move {
971 sqlx::query(
972 "
973 INSERT INTO room_participants (room_id, user_id, calling_user_id, calling_connection_id, initial_project_id)
974 VALUES ($1, $2, $3, $4, $5)
975 ",
976 )
977 .bind(room_id)
978 .bind(called_user_id)
979 .bind(calling_user_id)
980 .bind(calling_connection_id.0 as i32)
981 .bind(initial_project_id)
982 .execute(&mut tx)
983 .await?;
984
985 let room = self.commit_room_transaction(room_id, tx).await?;
986 let incoming_call = Self::build_incoming_call(&room, called_user_id)
987 .ok_or_else(|| anyhow!("failed to build incoming call"))?;
988 Ok((room, incoming_call))
989 }).await
990 }
991
992 pub async fn incoming_call_for_user(
993 &self,
994 user_id: UserId,
995 ) -> Result<Option<proto::IncomingCall>> {
996 self.transact(|mut tx| async move {
997 let room_id = sqlx::query_scalar::<_, RoomId>(
998 "
999 SELECT room_id
1000 FROM room_participants
1001 WHERE user_id = $1 AND answering_connection_id IS NULL
1002 ",
1003 )
1004 .bind(user_id)
1005 .fetch_optional(&mut tx)
1006 .await?;
1007
1008 if let Some(room_id) = room_id {
1009 let room = self.get_room(room_id, &mut tx).await?;
1010 Ok(Self::build_incoming_call(&room, user_id))
1011 } else {
1012 Ok(None)
1013 }
1014 })
1015 .await
1016 }
1017
1018 fn build_incoming_call(
1019 room: &proto::Room,
1020 called_user_id: UserId,
1021 ) -> Option<proto::IncomingCall> {
1022 let pending_participant = room
1023 .pending_participants
1024 .iter()
1025 .find(|participant| participant.user_id == called_user_id.to_proto())?;
1026
1027 Some(proto::IncomingCall {
1028 room_id: room.id,
1029 calling_user_id: pending_participant.calling_user_id,
1030 participant_user_ids: room
1031 .participants
1032 .iter()
1033 .map(|participant| participant.user_id)
1034 .collect(),
1035 initial_project: room.participants.iter().find_map(|participant| {
1036 let initial_project_id = pending_participant.initial_project_id?;
1037 participant
1038 .projects
1039 .iter()
1040 .find(|project| project.id == initial_project_id)
1041 .cloned()
1042 }),
1043 })
1044 }
1045
1046 pub async fn call_failed(
1047 &self,
1048 room_id: RoomId,
1049 called_user_id: UserId,
1050 ) -> Result<proto::Room> {
1051 self.transact(|mut tx| async move {
1052 sqlx::query(
1053 "
1054 DELETE FROM room_participants
1055 WHERE room_id = $1 AND user_id = $2
1056 ",
1057 )
1058 .bind(room_id)
1059 .bind(called_user_id)
1060 .execute(&mut tx)
1061 .await?;
1062
1063 self.commit_room_transaction(room_id, tx).await
1064 })
1065 .await
1066 }
1067
1068 pub async fn decline_call(
1069 &self,
1070 expected_room_id: Option<RoomId>,
1071 user_id: UserId,
1072 ) -> Result<proto::Room> {
1073 self.transact(|mut tx| async move {
1074 let room_id = sqlx::query_scalar(
1075 "
1076 DELETE FROM room_participants
1077 WHERE user_id = $1 AND answering_connection_id IS NULL
1078 RETURNING room_id
1079 ",
1080 )
1081 .bind(user_id)
1082 .fetch_one(&mut tx)
1083 .await?;
1084 if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1085 return Err(anyhow!("declining call on unexpected room"))?;
1086 }
1087
1088 self.commit_room_transaction(room_id, tx).await
1089 })
1090 .await
1091 }
1092
1093 pub async fn cancel_call(
1094 &self,
1095 expected_room_id: Option<RoomId>,
1096 calling_connection_id: ConnectionId,
1097 called_user_id: UserId,
1098 ) -> Result<proto::Room> {
1099 self.transact(|mut tx| async move {
1100 let room_id = sqlx::query_scalar(
1101 "
1102 DELETE FROM room_participants
1103 WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL
1104 RETURNING room_id
1105 ",
1106 )
1107 .bind(called_user_id)
1108 .bind(calling_connection_id.0 as i32)
1109 .fetch_one(&mut tx)
1110 .await?;
1111 if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1112 return Err(anyhow!("canceling call on unexpected room"))?;
1113 }
1114
1115 self.commit_room_transaction(room_id, tx).await
1116 }).await
1117 }
1118
1119 pub async fn join_room(
1120 &self,
1121 room_id: RoomId,
1122 user_id: UserId,
1123 connection_id: ConnectionId,
1124 ) -> Result<proto::Room> {
1125 self.transact(|mut tx| async move {
1126 sqlx::query(
1127 "
1128 UPDATE room_participants
1129 SET answering_connection_id = $1
1130 WHERE room_id = $2 AND user_id = $3
1131 RETURNING 1
1132 ",
1133 )
1134 .bind(connection_id.0 as i32)
1135 .bind(room_id)
1136 .bind(user_id)
1137 .fetch_one(&mut tx)
1138 .await?;
1139 self.commit_room_transaction(room_id, tx).await
1140 })
1141 .await
1142 }
1143
1144 pub async fn leave_room_for_connection(
1145 &self,
1146 connection_id: ConnectionId,
1147 ) -> Result<Option<LeftRoom>> {
1148 self.transact(|mut tx| async move {
1149 // Leave room.
1150 let room_id = sqlx::query_scalar::<_, RoomId>(
1151 "
1152 DELETE FROM room_participants
1153 WHERE answering_connection_id = $1
1154 RETURNING room_id
1155 ",
1156 )
1157 .bind(connection_id.0 as i32)
1158 .fetch_optional(&mut tx)
1159 .await?;
1160
1161 if let Some(room_id) = room_id {
1162 // Cancel pending calls initiated by the leaving user.
1163 let canceled_calls_to_user_ids: Vec<UserId> = sqlx::query_scalar(
1164 "
1165 DELETE FROM room_participants
1166 WHERE calling_connection_id = $1 AND answering_connection_id IS NULL
1167 RETURNING user_id
1168 ",
1169 )
1170 .bind(connection_id.0 as i32)
1171 .fetch_all(&mut tx)
1172 .await?;
1173
1174 let project_ids = sqlx::query_scalar::<_, ProjectId>(
1175 "
1176 SELECT project_id
1177 FROM project_collaborators
1178 WHERE connection_id = $1
1179 ",
1180 )
1181 .bind(connection_id.0 as i32)
1182 .fetch_all(&mut tx)
1183 .await?;
1184
1185 // Leave projects.
1186 let mut left_projects = HashMap::default();
1187 if !project_ids.is_empty() {
1188 let mut params = "?,".repeat(project_ids.len());
1189 params.pop();
1190 let query = format!(
1191 "
1192 SELECT *
1193 FROM project_collaborators
1194 WHERE project_id IN ({params})
1195 "
1196 );
1197 let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query);
1198 for project_id in project_ids {
1199 query = query.bind(project_id);
1200 }
1201
1202 let mut project_collaborators = query.fetch(&mut tx);
1203 while let Some(collaborator) = project_collaborators.next().await {
1204 let collaborator = collaborator?;
1205 let left_project =
1206 left_projects
1207 .entry(collaborator.project_id)
1208 .or_insert(LeftProject {
1209 id: collaborator.project_id,
1210 host_user_id: Default::default(),
1211 connection_ids: Default::default(),
1212 host_connection_id: Default::default(),
1213 });
1214
1215 let collaborator_connection_id =
1216 ConnectionId(collaborator.connection_id as u32);
1217 if collaborator_connection_id != connection_id {
1218 left_project.connection_ids.push(collaborator_connection_id);
1219 }
1220
1221 if collaborator.is_host {
1222 left_project.host_user_id = collaborator.user_id;
1223 left_project.host_connection_id =
1224 ConnectionId(collaborator.connection_id as u32);
1225 }
1226 }
1227 }
1228 sqlx::query(
1229 "
1230 DELETE FROM project_collaborators
1231 WHERE connection_id = $1
1232 ",
1233 )
1234 .bind(connection_id.0 as i32)
1235 .execute(&mut tx)
1236 .await?;
1237
1238 // Unshare projects.
1239 sqlx::query(
1240 "
1241 DELETE FROM projects
1242 WHERE room_id = $1 AND host_connection_id = $2
1243 ",
1244 )
1245 .bind(room_id)
1246 .bind(connection_id.0 as i32)
1247 .execute(&mut tx)
1248 .await?;
1249
1250 let room = self.commit_room_transaction(room_id, tx).await?;
1251 Ok(Some(LeftRoom {
1252 room,
1253 left_projects,
1254 canceled_calls_to_user_ids,
1255 }))
1256 } else {
1257 Ok(None)
1258 }
1259 })
1260 .await
1261 }
1262
1263 pub async fn update_room_participant_location(
1264 &self,
1265 room_id: RoomId,
1266 connection_id: ConnectionId,
1267 location: proto::ParticipantLocation,
1268 ) -> Result<proto::Room> {
1269 self.transact(|tx| async {
1270 let mut tx = tx;
1271 let location_kind;
1272 let location_project_id;
1273 match location
1274 .variant
1275 .as_ref()
1276 .ok_or_else(|| anyhow!("invalid location"))?
1277 {
1278 proto::participant_location::Variant::SharedProject(project) => {
1279 location_kind = 0;
1280 location_project_id = Some(ProjectId::from_proto(project.id));
1281 }
1282 proto::participant_location::Variant::UnsharedProject(_) => {
1283 location_kind = 1;
1284 location_project_id = None;
1285 }
1286 proto::participant_location::Variant::External(_) => {
1287 location_kind = 2;
1288 location_project_id = None;
1289 }
1290 }
1291
1292 sqlx::query(
1293 "
1294 UPDATE room_participants
1295 SET location_kind = $1, location_project_id = $2
1296 WHERE room_id = $3 AND answering_connection_id = $4
1297 RETURNING 1
1298 ",
1299 )
1300 .bind(location_kind)
1301 .bind(location_project_id)
1302 .bind(room_id)
1303 .bind(connection_id.0 as i32)
1304 .fetch_one(&mut tx)
1305 .await?;
1306
1307 self.commit_room_transaction(room_id, tx).await
1308 })
1309 .await
1310 }
1311
1312 async fn commit_room_transaction(
1313 &self,
1314 room_id: RoomId,
1315 mut tx: sqlx::Transaction<'_, D>,
1316 ) -> Result<proto::Room> {
1317 sqlx::query(
1318 "
1319 UPDATE rooms
1320 SET version = version + 1
1321 WHERE id = $1
1322 ",
1323 )
1324 .bind(room_id)
1325 .execute(&mut tx)
1326 .await?;
1327 let room = self.get_room(room_id, &mut tx).await?;
1328 tx.commit().await?;
1329
1330 Ok(room)
1331 }
1332
1333 async fn get_guest_connection_ids(
1334 &self,
1335 project_id: ProjectId,
1336 tx: &mut sqlx::Transaction<'_, D>,
1337 ) -> Result<Vec<ConnectionId>> {
1338 let mut guest_connection_ids = Vec::new();
1339 let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>(
1340 "
1341 SELECT connection_id
1342 FROM project_collaborators
1343 WHERE project_id = $1 AND is_host = FALSE
1344 ",
1345 )
1346 .bind(project_id)
1347 .fetch(tx);
1348 while let Some(connection_id) = db_guest_connection_ids.next().await {
1349 guest_connection_ids.push(ConnectionId(connection_id? as u32));
1350 }
1351 Ok(guest_connection_ids)
1352 }
1353
    /// Loads the full state of room `room_id` within `tx` and converts it to its
    /// protobuf form.
    ///
    /// `room_participants` rows with an `answering_connection_id` become
    /// `participants`. Rows without one become `pending_participants`. Each answering
    /// participant is then given the projects it hosts in this room, together with
    /// the root names of those projects' worktrees. The order of `participants` in
    /// the result is unspecified because it comes from a hash map.
    ///
    /// Errors if the room does not exist or if any query or row decoding fails.
    async fn get_room(
        &self,
        room_id: RoomId,
        tx: &mut sqlx::Transaction<'_, D>,
    ) -> Result<proto::Room> {
        let room: Room = sqlx::query_as(
            "
            SELECT *
            FROM rooms
            WHERE id = $1
            ",
        )
        .bind(room_id)
        .fetch_one(&mut *tx)
        .await?;

        let mut db_participants =
            sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<ProjectId>, UserId, Option<ProjectId>)>(
            "
            SELECT user_id, answering_connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
            FROM room_participants
            WHERE room_id = $1
            ",
        )
        .bind(room_id)
        .fetch(&mut *tx);

        // Answering participants, keyed by their raw (i32) connection id so that
        // hosted projects can be attached to them below.
        let mut participants = HashMap::default();
        let mut pending_participants = Vec::new();
        while let Some(participant) = db_participants.next().await {
            let (
                user_id,
                answering_connection_id,
                location_kind,
                location_project_id,
                calling_user_id,
                initial_project_id,
            ) = participant?;
            if let Some(answering_connection_id) = answering_connection_id {
                // Decode `location_kind`: 0 = shared project (requires a project id),
                // 1 = unshared project. Anything else, including NULL or a shared kind
                // without a project id, is reported as external.
                let location = match (location_kind, location_project_id) {
                    (Some(0), Some(project_id)) => {
                        Some(proto::participant_location::Variant::SharedProject(
                            proto::participant_location::SharedProject {
                                id: project_id.to_proto(),
                            },
                        ))
                    }
                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
                        Default::default(),
                    )),
                    _ => Some(proto::participant_location::Variant::External(
                        Default::default(),
                    )),
                };
                participants.insert(
                    answering_connection_id,
                    proto::Participant {
                        user_id: user_id.to_proto(),
                        peer_id: answering_connection_id as u32,
                        projects: Default::default(),
                        location: Some(proto::ParticipantLocation { variant: location }),
                    },
                );
            } else {
                // No answering connection: reported as pending (presumably called but
                // not yet joined — TODO confirm against the call flow).
                pending_participants.push(proto::PendingParticipant {
                    user_id: user_id.to_proto(),
                    calling_user_id: calling_user_id.to_proto(),
                    initial_project_id: initial_project_id.map(|id| id.to_proto()),
                });
            }
        }
        // Release the stream's borrow of `tx` before issuing the next query.
        drop(db_participants);

        // One row per (project, worktree) pair. Because of the LEFT JOIN, a project
        // without worktrees still appears once, with a NULL root name.
        let mut rows = sqlx::query_as::<_, (i32, ProjectId, Option<String>)>(
            "
            SELECT host_connection_id, projects.id, worktrees.root_name
            FROM projects
            LEFT JOIN worktrees ON projects.id = worktrees.project_id
            WHERE room_id = $1
            ",
        )
        .bind(room_id)
        .fetch(&mut *tx);

        while let Some(row) = rows.next().await {
            let (connection_id, project_id, worktree_root_name) = row?;
            // Projects whose host is not an answering participant are skipped.
            if let Some(participant) = participants.get_mut(&connection_id) {
                // Find this project in the host's list, appending it the first time.
                let project = if let Some(project) = participant
                    .projects
                    .iter_mut()
                    .find(|project| project.id == project_id.to_proto())
                {
                    project
                } else {
                    participant.projects.push(proto::ParticipantProject {
                        id: project_id.to_proto(),
                        worktree_root_names: Default::default(),
                    });
                    participant.projects.last_mut().unwrap()
                };
                // `Option` extends as zero or one items, so a NULL root name adds nothing.
                project.worktree_root_names.extend(worktree_root_name);
            }
        }

        Ok(proto::Room {
            id: room.id.to_proto(),
            version: room.version as u64,
            live_kit_room: room.live_kit_room,
            participants: participants.into_values().collect(),
            pending_participants,
        })
    }
1466
1467 // projects
1468
1469 pub async fn share_project(
1470 &self,
1471 expected_room_id: RoomId,
1472 connection_id: ConnectionId,
1473 worktrees: &[proto::WorktreeMetadata],
1474 ) -> Result<(ProjectId, proto::Room)> {
1475 self.transact(|mut tx| async move {
1476 let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
1477 "
1478 SELECT room_id, user_id
1479 FROM room_participants
1480 WHERE answering_connection_id = $1
1481 ",
1482 )
1483 .bind(connection_id.0 as i32)
1484 .fetch_one(&mut tx)
1485 .await?;
1486 if room_id != expected_room_id {
1487 return Err(anyhow!("shared project on unexpected room"))?;
1488 }
1489
1490 let project_id: ProjectId = sqlx::query_scalar(
1491 "
1492 INSERT INTO projects (room_id, host_user_id, host_connection_id)
1493 VALUES ($1, $2, $3)
1494 RETURNING id
1495 ",
1496 )
1497 .bind(room_id)
1498 .bind(user_id)
1499 .bind(connection_id.0 as i32)
1500 .fetch_one(&mut tx)
1501 .await
1502 .unwrap();
1503
1504 if !worktrees.is_empty() {
1505 let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
1506 params.pop();
1507 let query = format!(
1508 "
1509 INSERT INTO worktrees (
1510 project_id,
1511 id,
1512 root_name,
1513 abs_path,
1514 visible,
1515 scan_id,
1516 is_complete
1517 )
1518 VALUES {params}
1519 "
1520 );
1521
1522 let mut query = sqlx::query(&query);
1523 for worktree in worktrees {
1524 query = query
1525 .bind(project_id)
1526 .bind(worktree.id as i32)
1527 .bind(&worktree.root_name)
1528 .bind(&worktree.abs_path)
1529 .bind(worktree.visible)
1530 .bind(0)
1531 .bind(false);
1532 }
1533 query.execute(&mut tx).await.unwrap();
1534 }
1535
1536 sqlx::query(
1537 "
1538 INSERT INTO project_collaborators (
1539 project_id,
1540 connection_id,
1541 user_id,
1542 replica_id,
1543 is_host
1544 )
1545 VALUES ($1, $2, $3, $4, $5)
1546 ",
1547 )
1548 .bind(project_id)
1549 .bind(connection_id.0 as i32)
1550 .bind(user_id)
1551 .bind(0)
1552 .bind(true)
1553 .execute(&mut tx)
1554 .await
1555 .unwrap();
1556
1557 let room = self.commit_room_transaction(room_id, tx).await?;
1558 Ok((project_id, room))
1559 })
1560 .await
1561 }
1562
1563 pub async fn unshare_project(
1564 &self,
1565 project_id: ProjectId,
1566 connection_id: ConnectionId,
1567 ) -> Result<(proto::Room, Vec<ConnectionId>)> {
1568 self.transact(|mut tx| async move {
1569 let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1570 let room_id: RoomId = sqlx::query_scalar(
1571 "
1572 DELETE FROM projects
1573 WHERE id = $1 AND host_connection_id = $2
1574 RETURNING room_id
1575 ",
1576 )
1577 .bind(project_id)
1578 .bind(connection_id.0 as i32)
1579 .fetch_one(&mut tx)
1580 .await?;
1581 let room = self.commit_room_transaction(room_id, tx).await?;
1582
1583 Ok((room, guest_connection_ids))
1584 })
1585 .await
1586 }
1587
1588 pub async fn update_project(
1589 &self,
1590 project_id: ProjectId,
1591 connection_id: ConnectionId,
1592 worktrees: &[proto::WorktreeMetadata],
1593 ) -> Result<(proto::Room, Vec<ConnectionId>)> {
1594 self.transact(|mut tx| async move {
1595 let room_id: RoomId = sqlx::query_scalar(
1596 "
1597 SELECT room_id
1598 FROM projects
1599 WHERE id = $1 AND host_connection_id = $2
1600 ",
1601 )
1602 .bind(project_id)
1603 .bind(connection_id.0 as i32)
1604 .fetch_one(&mut tx)
1605 .await?;
1606
1607 if !worktrees.is_empty() {
1608 let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
1609 params.pop();
1610 let query = format!(
1611 "
1612 INSERT INTO worktrees (
1613 project_id,
1614 id,
1615 root_name,
1616 abs_path,
1617 visible,
1618 scan_id,
1619 is_complete
1620 )
1621 VALUES {params}
1622 ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name
1623 "
1624 );
1625
1626 let mut query = sqlx::query(&query);
1627 for worktree in worktrees {
1628 query = query
1629 .bind(project_id)
1630 .bind(worktree.id as i32)
1631 .bind(&worktree.root_name)
1632 .bind(&worktree.abs_path)
1633 .bind(worktree.visible)
1634 .bind(0)
1635 .bind(false)
1636 }
1637 query.execute(&mut tx).await?;
1638 }
1639
1640 let mut params = "?,".repeat(worktrees.len());
1641 if !worktrees.is_empty() {
1642 params.pop();
1643 }
1644 let query = format!(
1645 "
1646 DELETE FROM worktrees
1647 WHERE project_id = ? AND id NOT IN ({params})
1648 ",
1649 );
1650
1651 let mut query = sqlx::query(&query).bind(project_id);
1652 for worktree in worktrees {
1653 query = query.bind(WorktreeId(worktree.id as i32));
1654 }
1655 query.execute(&mut tx).await?;
1656
1657 let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1658 let room = self.commit_room_transaction(room_id, tx).await?;
1659
1660 Ok((room, guest_connection_ids))
1661 })
1662 .await
1663 }
1664
1665 pub async fn update_worktree(
1666 &self,
1667 update: &proto::UpdateWorktree,
1668 connection_id: ConnectionId,
1669 ) -> Result<Vec<ConnectionId>> {
1670 self.transact(|mut tx| async move {
1671 let project_id = ProjectId::from_proto(update.project_id);
1672 let worktree_id = WorktreeId::from_proto(update.worktree_id);
1673
1674 // Ensure the update comes from the host.
1675 sqlx::query(
1676 "
1677 SELECT 1
1678 FROM projects
1679 WHERE id = $1 AND host_connection_id = $2
1680 ",
1681 )
1682 .bind(project_id)
1683 .bind(connection_id.0 as i32)
1684 .fetch_one(&mut tx)
1685 .await?;
1686
1687 // Update metadata.
1688 sqlx::query(
1689 "
1690 UPDATE worktrees
1691 SET
1692 root_name = $1,
1693 scan_id = $2,
1694 is_complete = $3,
1695 abs_path = $4
1696 WHERE project_id = $5 AND id = $6
1697 RETURNING 1
1698 ",
1699 )
1700 .bind(&update.root_name)
1701 .bind(update.scan_id as i64)
1702 .bind(update.is_last_update)
1703 .bind(&update.abs_path)
1704 .bind(project_id)
1705 .bind(worktree_id)
1706 .fetch_one(&mut tx)
1707 .await?;
1708
1709 if !update.updated_entries.is_empty() {
1710 let mut params =
1711 "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),".repeat(update.updated_entries.len());
1712 params.pop();
1713
1714 let query = format!(
1715 "
1716 INSERT INTO worktree_entries (
1717 project_id,
1718 worktree_id,
1719 id,
1720 is_dir,
1721 path,
1722 inode,
1723 mtime_seconds,
1724 mtime_nanos,
1725 is_symlink,
1726 is_ignored
1727 )
1728 VALUES {params}
1729 ON CONFLICT (project_id, worktree_id, id) DO UPDATE SET
1730 is_dir = excluded.is_dir,
1731 path = excluded.path,
1732 inode = excluded.inode,
1733 mtime_seconds = excluded.mtime_seconds,
1734 mtime_nanos = excluded.mtime_nanos,
1735 is_symlink = excluded.is_symlink,
1736 is_ignored = excluded.is_ignored
1737 "
1738 );
1739 let mut query = sqlx::query(&query);
1740 for entry in &update.updated_entries {
1741 let mtime = entry.mtime.clone().unwrap_or_default();
1742 query = query
1743 .bind(project_id)
1744 .bind(worktree_id)
1745 .bind(entry.id as i64)
1746 .bind(entry.is_dir)
1747 .bind(&entry.path)
1748 .bind(entry.inode as i64)
1749 .bind(mtime.seconds as i64)
1750 .bind(mtime.nanos as i32)
1751 .bind(entry.is_symlink)
1752 .bind(entry.is_ignored);
1753 }
1754 query.execute(&mut tx).await?;
1755 }
1756
1757 if !update.removed_entries.is_empty() {
1758 let mut params = "?,".repeat(update.removed_entries.len());
1759 params.pop();
1760 let query = format!(
1761 "
1762 DELETE FROM worktree_entries
1763 WHERE project_id = ? AND worktree_id = ? AND id IN ({params})
1764 "
1765 );
1766
1767 let mut query = sqlx::query(&query).bind(project_id).bind(worktree_id);
1768 for entry_id in &update.removed_entries {
1769 query = query.bind(*entry_id as i64);
1770 }
1771 query.execute(&mut tx).await?;
1772 }
1773
1774 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1775 tx.commit().await?;
1776 Ok(connection_ids)
1777 })
1778 .await
1779 }
1780
1781 pub async fn update_diagnostic_summary(
1782 &self,
1783 update: &proto::UpdateDiagnosticSummary,
1784 connection_id: ConnectionId,
1785 ) -> Result<Vec<ConnectionId>> {
1786 self.transact(|mut tx| async {
1787 let project_id = ProjectId::from_proto(update.project_id);
1788 let worktree_id = WorktreeId::from_proto(update.worktree_id);
1789 let summary = update
1790 .summary
1791 .as_ref()
1792 .ok_or_else(|| anyhow!("invalid summary"))?;
1793
1794 // Ensure the update comes from the host.
1795 sqlx::query(
1796 "
1797 SELECT 1
1798 FROM projects
1799 WHERE id = $1 AND host_connection_id = $2
1800 ",
1801 )
1802 .bind(project_id)
1803 .bind(connection_id.0 as i32)
1804 .fetch_one(&mut tx)
1805 .await?;
1806
1807 // Update summary.
1808 sqlx::query(
1809 "
1810 INSERT INTO worktree_diagnostic_summaries (
1811 project_id,
1812 worktree_id,
1813 path,
1814 language_server_id,
1815 error_count,
1816 warning_count,
1817 version
1818 )
1819 VALUES ($1, $2, $3, $4, $5, $6, $7)
1820 ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET
1821 language_server_id = excluded.language_server_id,
1822 error_count = excluded.error_count,
1823 warning_count = excluded.warning_count,
1824 version = excluded.version
1825 ",
1826 )
1827 .bind(project_id)
1828 .bind(worktree_id)
1829 .bind(&summary.path)
1830 .bind(summary.language_server_id as i64)
1831 .bind(summary.error_count as i32)
1832 .bind(summary.warning_count as i32)
1833 .bind(summary.version as i32)
1834 .execute(&mut tx)
1835 .await?;
1836
1837 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1838 tx.commit().await?;
1839 Ok(connection_ids)
1840 })
1841 .await
1842 }
1843
1844 pub async fn start_language_server(
1845 &self,
1846 update: &proto::StartLanguageServer,
1847 connection_id: ConnectionId,
1848 ) -> Result<Vec<ConnectionId>> {
1849 self.transact(|mut tx| async {
1850 let project_id = ProjectId::from_proto(update.project_id);
1851 let server = update
1852 .server
1853 .as_ref()
1854 .ok_or_else(|| anyhow!("invalid language server"))?;
1855
1856 // Ensure the update comes from the host.
1857 sqlx::query(
1858 "
1859 SELECT 1
1860 FROM projects
1861 WHERE id = $1 AND host_connection_id = $2
1862 ",
1863 )
1864 .bind(project_id)
1865 .bind(connection_id.0 as i32)
1866 .fetch_one(&mut tx)
1867 .await?;
1868
1869 // Add the newly-started language server.
1870 sqlx::query(
1871 "
1872 INSERT INTO language_servers (project_id, id, name)
1873 VALUES ($1, $2, $3)
1874 ON CONFLICT (project_id, id) DO UPDATE SET
1875 name = excluded.name
1876 ",
1877 )
1878 .bind(project_id)
1879 .bind(server.id as i64)
1880 .bind(&server.name)
1881 .execute(&mut tx)
1882 .await?;
1883
1884 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1885 tx.commit().await?;
1886 Ok(connection_ids)
1887 })
1888 .await
1889 }
1890
    /// Adds `connection_id` as a guest collaborator on `project_id`.
    ///
    /// Returns the project's full state and the replica id assigned to the new guest.
    /// The caller must be an answering participant of the room the project is shared
    /// in. The guest receives the lowest replica id, starting at 1, that no existing
    /// collaborator uses.
    ///
    /// The returned `Project` contains:
    /// - every collaborator, including the new one,
    /// - each worktree with its entries and diagnostic summaries,
    /// - the project's language servers.
    ///
    /// Errors if the connection is not in a room, if the project is not shared in
    /// that room, or if any query fails.
    pub async fn join_project(
        &self,
        project_id: ProjectId,
        connection_id: ConnectionId,
    ) -> Result<(Project, ReplicaId)> {
        self.transact(|mut tx| async move {
            let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
                "
                SELECT room_id, user_id
                FROM room_participants
                WHERE answering_connection_id = $1
                ",
            )
            .bind(connection_id.0 as i32)
            .fetch_one(&mut tx)
            .await?;

            // Ensure project id was shared on this room.
            sqlx::query(
                "
                SELECT 1
                FROM projects
                WHERE id = $1 AND room_id = $2
                ",
            )
            .bind(project_id)
            .bind(room_id)
            .fetch_one(&mut tx)
            .await?;

            let mut collaborators = sqlx::query_as::<_, ProjectCollaborator>(
                "
                SELECT *
                FROM project_collaborators
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;
            // Pick the lowest replica id >= 1 that is not taken. In `share_project`
            // the host is inserted with replica id 0.
            let replica_ids = collaborators
                .iter()
                .map(|c| c.replica_id)
                .collect::<HashSet<_>>();
            let mut replica_id = ReplicaId(1);
            while replica_ids.contains(&replica_id) {
                replica_id.0 += 1;
            }
            let new_collaborator = ProjectCollaborator {
                project_id,
                connection_id: connection_id.0 as i32,
                user_id,
                replica_id,
                is_host: false,
            };

            sqlx::query(
                "
                INSERT INTO project_collaborators (
                    project_id,
                    connection_id,
                    user_id,
                    replica_id,
                    is_host
                )
                VALUES ($1, $2, $3, $4, $5)
                ",
            )
            .bind(new_collaborator.project_id)
            .bind(new_collaborator.connection_id)
            .bind(new_collaborator.user_id)
            .bind(new_collaborator.replica_id)
            .bind(new_collaborator.is_host)
            .execute(&mut tx)
            .await?;
            collaborators.push(new_collaborator);

            let worktree_rows = sqlx::query_as::<_, WorktreeRow>(
                "
                SELECT *
                FROM worktrees
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;
            // Worktrees keyed by id; entries and summaries are filled in below.
            let mut worktrees = worktree_rows
                .into_iter()
                .map(|worktree_row| {
                    (
                        worktree_row.id,
                        Worktree {
                            id: worktree_row.id,
                            abs_path: worktree_row.abs_path,
                            root_name: worktree_row.root_name,
                            visible: worktree_row.visible,
                            entries: Default::default(),
                            diagnostic_summaries: Default::default(),
                            scan_id: worktree_row.scan_id as u64,
                            is_complete: worktree_row.is_complete,
                        },
                    )
                })
                .collect::<BTreeMap<_, _>>();

            // Populate worktree entries. The block scope ends the stream's borrow of `tx`.
            {
                let mut entries = sqlx::query_as::<_, WorktreeEntry>(
                    "
                    SELECT *
                    FROM worktree_entries
                    WHERE project_id = $1
                    ",
                )
                .bind(project_id)
                .fetch(&mut tx);
                while let Some(entry) = entries.next().await {
                    let entry = entry?;
                    // Entries of worktrees not loaded above are ignored.
                    if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) {
                        worktree.entries.push(proto::Entry {
                            id: entry.id as u64,
                            is_dir: entry.is_dir,
                            path: entry.path,
                            inode: entry.inode as u64,
                            mtime: Some(proto::Timestamp {
                                seconds: entry.mtime_seconds as u64,
                                nanos: entry.mtime_nanos as u32,
                            }),
                            is_symlink: entry.is_symlink,
                            is_ignored: entry.is_ignored,
                        });
                    }
                }
            }

            // Populate worktree diagnostic summaries. The block scope ends the stream's borrow of `tx`.
            {
                let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>(
                    "
                    SELECT *
                    FROM worktree_diagnostic_summaries
                    WHERE project_id = $1
                    ",
                )
                .bind(project_id)
                .fetch(&mut tx);
                while let Some(summary) = summaries.next().await {
                    let summary = summary?;
                    if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) {
                        worktree
                            .diagnostic_summaries
                            .push(proto::DiagnosticSummary {
                                path: summary.path,
                                language_server_id: summary.language_server_id as u64,
                                error_count: summary.error_count as u32,
                                warning_count: summary.warning_count as u32,
                                version: summary.version as u32,
                            });
                    }
                }
            }

            // Populate language servers.
            let language_servers = sqlx::query_as::<_, LanguageServer>(
                "
                SELECT *
                FROM language_servers
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;

            tx.commit().await?;
            Ok((
                Project {
                    collaborators,
                    worktrees,
                    language_servers: language_servers
                        .into_iter()
                        .map(|language_server| proto::LanguageServer {
                            id: language_server.id.to_proto(),
                            name: language_server.name,
                        })
                        .collect(),
                },
                // Identity cast: `replica_id` is already a `ReplicaId`.
                replica_id as ReplicaId,
            ))
        })
        .await
    }
2084
2085 pub async fn leave_project(
2086 &self,
2087 project_id: ProjectId,
2088 connection_id: ConnectionId,
2089 ) -> Result<LeftProject> {
2090 self.transact(|mut tx| async move {
2091 let result = sqlx::query(
2092 "
2093 DELETE FROM project_collaborators
2094 WHERE project_id = $1 AND connection_id = $2
2095 ",
2096 )
2097 .bind(project_id)
2098 .bind(connection_id.0 as i32)
2099 .execute(&mut tx)
2100 .await?;
2101
2102 if result.rows_affected() == 0 {
2103 Err(anyhow!("not a collaborator on this project"))?;
2104 }
2105
2106 let connection_ids = sqlx::query_scalar::<_, i32>(
2107 "
2108 SELECT connection_id
2109 FROM project_collaborators
2110 WHERE project_id = $1
2111 ",
2112 )
2113 .bind(project_id)
2114 .fetch_all(&mut tx)
2115 .await?
2116 .into_iter()
2117 .map(|id| ConnectionId(id as u32))
2118 .collect();
2119
2120 let (host_user_id, host_connection_id) = sqlx::query_as::<_, (i32, i32)>(
2121 "
2122 SELECT host_user_id, host_connection_id
2123 FROM projects
2124 WHERE id = $1
2125 ",
2126 )
2127 .bind(project_id)
2128 .fetch_one(&mut tx)
2129 .await?;
2130
2131 tx.commit().await?;
2132
2133 Ok(LeftProject {
2134 id: project_id,
2135 host_user_id: UserId(host_user_id),
2136 host_connection_id: ConnectionId(host_connection_id as u32),
2137 connection_ids,
2138 })
2139 })
2140 .await
2141 }
2142
2143 pub async fn project_collaborators(
2144 &self,
2145 project_id: ProjectId,
2146 connection_id: ConnectionId,
2147 ) -> Result<Vec<ProjectCollaborator>> {
2148 self.transact(|mut tx| async move {
2149 let collaborators = sqlx::query_as::<_, ProjectCollaborator>(
2150 "
2151 SELECT *
2152 FROM project_collaborators
2153 WHERE project_id = $1
2154 ",
2155 )
2156 .bind(project_id)
2157 .fetch_all(&mut tx)
2158 .await?;
2159
2160 if collaborators
2161 .iter()
2162 .any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
2163 {
2164 Ok(collaborators)
2165 } else {
2166 Err(anyhow!("no such project"))?
2167 }
2168 })
2169 .await
2170 }
2171
2172 pub async fn project_connection_ids(
2173 &self,
2174 project_id: ProjectId,
2175 connection_id: ConnectionId,
2176 ) -> Result<HashSet<ConnectionId>> {
2177 self.transact(|mut tx| async move {
2178 let connection_ids = sqlx::query_scalar::<_, i32>(
2179 "
2180 SELECT connection_id
2181 FROM project_collaborators
2182 WHERE project_id = $1
2183 ",
2184 )
2185 .bind(project_id)
2186 .fetch_all(&mut tx)
2187 .await?;
2188
2189 if connection_ids.contains(&(connection_id.0 as i32)) {
2190 Ok(connection_ids
2191 .into_iter()
2192 .map(|connection_id| ConnectionId(connection_id as u32))
2193 .collect())
2194 } else {
2195 Err(anyhow!("no such project"))?
2196 }
2197 })
2198 .await
2199 }
2200
2201 // contacts
2202
2203 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
2204 self.transact(|mut tx| async move {
2205 let query = "
2206 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify, (room_participants.id IS NOT NULL) as busy
2207 FROM contacts
2208 LEFT JOIN room_participants ON room_participants.user_id = $1
2209 WHERE user_id_a = $1 OR user_id_b = $1;
2210 ";
2211
2212 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool, bool)>(query)
2213 .bind(user_id)
2214 .fetch(&mut tx);
2215
2216 let mut contacts = Vec::new();
2217 while let Some(row) = rows.next().await {
2218 let (user_id_a, user_id_b, a_to_b, accepted, should_notify, busy) = row?;
2219 if user_id_a == user_id {
2220 if accepted {
2221 contacts.push(Contact::Accepted {
2222 user_id: user_id_b,
2223 should_notify: should_notify && a_to_b,
2224 busy
2225 });
2226 } else if a_to_b {
2227 contacts.push(Contact::Outgoing { user_id: user_id_b })
2228 } else {
2229 contacts.push(Contact::Incoming {
2230 user_id: user_id_b,
2231 should_notify,
2232 });
2233 }
2234 } else if accepted {
2235 contacts.push(Contact::Accepted {
2236 user_id: user_id_a,
2237 should_notify: should_notify && !a_to_b,
2238 busy
2239 });
2240 } else if a_to_b {
2241 contacts.push(Contact::Incoming {
2242 user_id: user_id_a,
2243 should_notify,
2244 });
2245 } else {
2246 contacts.push(Contact::Outgoing { user_id: user_id_a });
2247 }
2248 }
2249
2250 contacts.sort_unstable_by_key(|contact| contact.user_id());
2251
2252 Ok(contacts)
2253 })
2254 .await
2255 }
2256
2257 pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
2258 self.transact(|mut tx| async move {
2259 Ok(sqlx::query_scalar::<_, i32>(
2260 "
2261 SELECT 1
2262 FROM room_participants
2263 WHERE room_participants.user_id = $1
2264 ",
2265 )
2266 .bind(user_id)
2267 .fetch_optional(&mut tx)
2268 .await?
2269 .is_some())
2270 })
2271 .await
2272 }
2273
2274 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
2275 self.transact(|mut tx| async move {
2276 let (id_a, id_b) = if user_id_1 < user_id_2 {
2277 (user_id_1, user_id_2)
2278 } else {
2279 (user_id_2, user_id_1)
2280 };
2281
2282 let query = "
2283 SELECT 1 FROM contacts
2284 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
2285 LIMIT 1
2286 ";
2287 Ok(sqlx::query_scalar::<_, i32>(query)
2288 .bind(id_a.0)
2289 .bind(id_b.0)
2290 .fetch_optional(&mut tx)
2291 .await?
2292 .is_some())
2293 })
2294 .await
2295 }
2296
2297 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
2298 self.transact(|mut tx| async move {
2299 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
2300 (sender_id, receiver_id, true)
2301 } else {
2302 (receiver_id, sender_id, false)
2303 };
2304 let query = "
2305 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
2306 VALUES ($1, $2, $3, FALSE, TRUE)
2307 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
2308 SET
2309 accepted = TRUE,
2310 should_notify = FALSE
2311 WHERE
2312 NOT contacts.accepted AND
2313 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
2314 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
2315 ";
2316 let result = sqlx::query(query)
2317 .bind(id_a.0)
2318 .bind(id_b.0)
2319 .bind(a_to_b)
2320 .execute(&mut tx)
2321 .await?;
2322
2323 if result.rows_affected() == 1 {
2324 tx.commit().await?;
2325 Ok(())
2326 } else {
2327 Err(anyhow!("contact already requested"))?
2328 }
2329 }).await
2330 }
2331
2332 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2333 self.transact(|mut tx| async move {
2334 let (id_a, id_b) = if responder_id < requester_id {
2335 (responder_id, requester_id)
2336 } else {
2337 (requester_id, responder_id)
2338 };
2339 let query = "
2340 DELETE FROM contacts
2341 WHERE user_id_a = $1 AND user_id_b = $2;
2342 ";
2343 let result = sqlx::query(query)
2344 .bind(id_a.0)
2345 .bind(id_b.0)
2346 .execute(&mut tx)
2347 .await?;
2348
2349 if result.rows_affected() == 1 {
2350 tx.commit().await?;
2351 Ok(())
2352 } else {
2353 Err(anyhow!("no such contact"))?
2354 }
2355 })
2356 .await
2357 }
2358
2359 pub async fn dismiss_contact_notification(
2360 &self,
2361 user_id: UserId,
2362 contact_user_id: UserId,
2363 ) -> Result<()> {
2364 self.transact(|mut tx| async move {
2365 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
2366 (user_id, contact_user_id, true)
2367 } else {
2368 (contact_user_id, user_id, false)
2369 };
2370
2371 let query = "
2372 UPDATE contacts
2373 SET should_notify = FALSE
2374 WHERE
2375 user_id_a = $1 AND user_id_b = $2 AND
2376 (
2377 (a_to_b = $3 AND accepted) OR
2378 (a_to_b != $3 AND NOT accepted)
2379 );
2380 ";
2381
2382 let result = sqlx::query(query)
2383 .bind(id_a.0)
2384 .bind(id_b.0)
2385 .bind(a_to_b)
2386 .execute(&mut tx)
2387 .await?;
2388
2389 if result.rows_affected() == 0 {
2390 Err(anyhow!("no such contact request"))?
2391 } else {
2392 tx.commit().await?;
2393 Ok(())
2394 }
2395 })
2396 .await
2397 }
2398
2399 pub async fn respond_to_contact_request(
2400 &self,
2401 responder_id: UserId,
2402 requester_id: UserId,
2403 accept: bool,
2404 ) -> Result<()> {
2405 self.transact(|mut tx| async move {
2406 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
2407 (responder_id, requester_id, false)
2408 } else {
2409 (requester_id, responder_id, true)
2410 };
2411 let result = if accept {
2412 let query = "
2413 UPDATE contacts
2414 SET accepted = TRUE, should_notify = TRUE
2415 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
2416 ";
2417 sqlx::query(query)
2418 .bind(id_a.0)
2419 .bind(id_b.0)
2420 .bind(a_to_b)
2421 .execute(&mut tx)
2422 .await?
2423 } else {
2424 let query = "
2425 DELETE FROM contacts
2426 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
2427 ";
2428 sqlx::query(query)
2429 .bind(id_a.0)
2430 .bind(id_b.0)
2431 .bind(a_to_b)
2432 .execute(&mut tx)
2433 .await?
2434 };
2435 if result.rows_affected() == 1 {
2436 tx.commit().await?;
2437 Ok(())
2438 } else {
2439 Err(anyhow!("no such contact request"))?
2440 }
2441 })
2442 .await
2443 }
2444
2445 // access tokens
2446
2447 pub async fn create_access_token_hash(
2448 &self,
2449 user_id: UserId,
2450 access_token_hash: &str,
2451 max_access_token_count: usize,
2452 ) -> Result<()> {
2453 self.transact(|tx| async {
2454 let mut tx = tx;
2455 let insert_query = "
2456 INSERT INTO access_tokens (user_id, hash)
2457 VALUES ($1, $2);
2458 ";
2459 let cleanup_query = "
2460 DELETE FROM access_tokens
2461 WHERE id IN (
2462 SELECT id from access_tokens
2463 WHERE user_id = $1
2464 ORDER BY id DESC
2465 LIMIT 10000
2466 OFFSET $3
2467 )
2468 ";
2469
2470 sqlx::query(insert_query)
2471 .bind(user_id.0)
2472 .bind(access_token_hash)
2473 .execute(&mut tx)
2474 .await?;
2475 sqlx::query(cleanup_query)
2476 .bind(user_id.0)
2477 .bind(access_token_hash)
2478 .bind(max_access_token_count as i32)
2479 .execute(&mut tx)
2480 .await?;
2481 Ok(tx.commit().await?)
2482 })
2483 .await
2484 }
2485
2486 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
2487 self.transact(|mut tx| async move {
2488 let query = "
2489 SELECT hash
2490 FROM access_tokens
2491 WHERE user_id = $1
2492 ORDER BY id DESC
2493 ";
2494 Ok(sqlx::query_scalar(query)
2495 .bind(user_id.0)
2496 .fetch_all(&mut tx)
2497 .await?)
2498 })
2499 .await
2500 }
2501
2502 async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
2503 where
2504 F: Send + Fn(sqlx::Transaction<'static, D>) -> Fut,
2505 Fut: Send + Future<Output = Result<T>>,
2506 {
2507 let body = async {
2508 loop {
2509 let tx = self.begin_transaction().await?;
2510 match f(tx).await {
2511 Ok(result) => return Ok(result),
2512 Err(error) => match error {
2513 Error::Database(error)
2514 if error
2515 .as_database_error()
2516 .and_then(|error| error.code())
2517 .as_deref()
2518 == Some("hey") =>
2519 {
2520 // Retry (don't break the loop)
2521 }
2522 error @ _ => return Err(error),
2523 },
2524 }
2525 }
2526 };
2527
2528 #[cfg(test)]
2529 {
2530 if let Some(background) = self.background.as_ref() {
2531 background.simulate_random_delay().await;
2532 }
2533
2534 let result = self.runtime.as_ref().unwrap().block_on(body);
2535
2536 if let Some(background) = self.background.as_ref() {
2537 background.simulate_random_delay().await;
2538 }
2539
2540 result
2541 }
2542
2543 #[cfg(not(test))]
2544 {
2545 body.await
2546 }
2547 }
2548}
2549
/// Declares a public `i32`-backed id newtype called `$name`.
///
/// The generated type is `Copy`, totally ordered, and hashable. It is
/// `#[sqlx(transparent)]` and `#[serde(transparent)]`, so it is stored and
/// serialized exactly like its inner `i32`. `from_proto` and `to_proto` convert
/// to and from the `u64` ids used in protobuf messages with plain `as` casts.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash,
            sqlx::Type, Serialize, Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Converts a protobuf id; values outside the `i32` range are truncated by `as`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts to the protobuf `u64` representation via an `as` cast.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            // Delegates to `i32`'s formatting so width/fill flags are honored.
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }
    };
}
2592
// `UserId`: `i32` newtype identifying a user (generated by `id_type!`).
id_type!(UserId);
/// A user account record.
///
/// `FromRow` lets it be decoded directly from a query row whose columns match the field
/// names. `Serialize` lets it be emitted through serde.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// GitHub's numeric user id; `None` when it is not recorded for this user.
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    /// NOTE(review): presumably grants admin privileges; confirm where this is checked.
    pub admin: bool,
    /// Invite code associated with this user, if any. Presumably the code they share
    /// with invitees; confirm.
    pub invite_code: Option<String>,
    /// NOTE(review): presumably the number of invites this user may still send; confirm.
    pub invite_count: i32,
    /// NOTE(review): presumably set once the user has connected at least once; confirm.
    pub connected_once: bool,
}
2605
// `RoomId`: `i32` newtype identifying a room (generated by `id_type!`).
id_type!(RoomId);
/// A room record, decodable from a query row via `FromRow` and serializable via serde.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Room {
    pub id: RoomId,
    /// NOTE(review): presumably a counter bumped whenever the room changes; confirm
    /// against the queries that write it.
    pub version: i32,
    /// Name of the LiveKit room associated with this room.
    pub live_kit_room: String,
}
2613
// `ProjectId`: `i32` newtype identifying a project (generated by `id_type!`).
id_type!(ProjectId);
/// An assembled view of a project: its collaborators, its worktrees, and its language
/// servers.
///
/// Unlike the row types in this module it does not derive `FromRow`. Presumably it is
/// built from several queries; confirm at the construction site.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    /// Worktrees keyed (and therefore ordered) by their id.
    pub worktrees: BTreeMap<WorktreeId, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
2620
// `ReplicaId`: `i32` newtype identifying a collaborator's replica (generated by `id_type!`).
id_type!(ReplicaId);
/// A participant in a project, decodable from a query row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
pub struct ProjectCollaborator {
    pub project_id: ProjectId,
    /// Connection id stored as a raw `i32`.
    ///
    /// NOTE(review): presumably the numeric value of an `rpc::ConnectionId`; confirm how
    /// call sites convert it.
    pub connection_id: i32,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    /// Whether this collaborator is the project's host.
    pub is_host: bool,
}
2630
// `WorktreeId`: `i32` newtype identifying a worktree (generated by `id_type!`).
id_type!(WorktreeId);
/// Worktree columns as decoded from a single query row.
///
/// NOTE(review): presumably used to build a `Worktree` once entries and diagnostic
/// summaries have been loaded separately; confirm.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeRow {
    pub id: WorktreeId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    /// Stored as a signed `i64`, whereas `Worktree::scan_id` is `u64`.
    pub scan_id: i64,
    pub is_complete: bool,
}
2641
/// A worktree together with its file-system entries and its diagnostic summaries.
pub struct Worktree {
    pub id: WorktreeId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    /// Unsigned, unlike `WorktreeRow::scan_id` (`i64`).
    ///
    /// NOTE(review): confirm how the conversion between the two treats negative values.
    pub scan_id: u64,
    pub is_complete: bool,
}
2652
/// A single file-system entry belonging to a worktree, decodable from a query row.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeEntry {
    id: i64,
    worktree_id: WorktreeId,
    is_dir: bool,
    /// NOTE(review): presumably relative to the worktree root; confirm.
    path: String,
    inode: i64,
    /// Modification time, split across two columns. Presumably whole seconds plus a
    /// sub-second nanosecond part; confirm.
    mtime_seconds: i64,
    mtime_nanos: i32,
    is_symlink: bool,
    is_ignored: bool,
}
2665
/// Error and warning counts for one path in a worktree, as reported by one language
/// server. Decodable from a query row.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeDiagnosticSummary {
    worktree_id: WorktreeId,
    path: String,
    language_server_id: i64,
    error_count: i32,
    warning_count: i32,
    /// NOTE(review): presumably a version counter for this summary; confirm its source.
    version: i32,
}
2675
// `LanguageServerId`: `i32` newtype identifying a language server (generated by `id_type!`).
id_type!(LanguageServerId);
/// A language server's id and name, decodable from a query row.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct LanguageServer {
    id: LanguageServerId,
    name: String,
}
2682
/// A project affected by a participant leaving. It appears as a value in
/// `LeftRoom::left_projects`.
pub struct LeftProject {
    pub id: ProjectId,
    /// The project's host user.
    pub host_user_id: UserId,
    /// The host's connection.
    pub host_connection_id: ConnectionId,
    /// NOTE(review): presumably the connections of the collaborators that remain in the
    /// project (e.g. to notify them); confirm.
    pub connection_ids: Vec<ConnectionId>,
}
2689
/// The result of a participant leaving a room.
pub struct LeftRoom {
    /// NOTE(review): presumably the room's state after the participant left; confirm.
    pub room: proto::Room,
    /// Projects affected by the departure, keyed by project id.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    /// NOTE(review): presumably the callees whose pending calls were canceled because of
    /// the departure; confirm.
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
2695
/// A contact relationship with another user, in one of three states. Every variant
/// carries the other user's id.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// An established contact.
    Accepted {
        user_id: UserId,
        /// NOTE(review): presumably whether the user should be notified about this
        /// contact; confirm.
        should_notify: bool,
        /// NOTE(review): presumably whether the contact is currently busy (e.g. on a
        /// call); confirm.
        busy: bool,
    },
    /// A contact request that is still pending. Presumably it was sent by the owning
    /// user; confirm.
    Outgoing {
        user_id: UserId,
    },
    /// A contact request that is still pending. Presumably it was received by the owning
    /// user; confirm.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
2711
2712impl Contact {
2713 pub fn user_id(&self) -> UserId {
2714 match self {
2715 Contact::Accepted { user_id, .. } => *user_id,
2716 Contact::Outgoing { user_id } => *user_id,
2717 Contact::Incoming { user_id, .. } => *user_id,
2718 }
2719 }
2720}
2721
/// A pending contact request received from another user.
///
/// The derived `Ord` compares fields in declaration order: `requester_id` first, then
/// `should_notify`. Reordering the fields changes the sort order.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
2727
/// A waitlist signup submission, deserialized via serde.
///
/// NOTE(review): presumably deserialized from an API request body; confirm.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    /// Platforms the user selected; several may be set at once.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    /// Optional device identifier supplied with the signup.
    pub device_id: Option<String>,
}
2738
/// Aggregate waitlist counts, total and broken down by platform.
///
/// Each field is `#[sqlx(default)]`. When a column is absent from the decoded row, that
/// count falls back to `i64::default()` (0) instead of failing the decode.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    /// NOTE(review): presumably signups that selected no platform; confirm against the
    /// query.
    #[sqlx(default)]
    pub unknown_count: i64,
}
2752
/// An email invite together with the confirmation code used to accept it.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    /// NOTE(review): presumably produced by `random_email_confirmation_code`; confirm.
    pub email_confirmation_code: String,
}
2758
/// Input parameters for creating a new user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    /// Initial invite count for the new user.
    pub invite_count: i32,
}
2765
/// The outcome of creating a new user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// NOTE(review): presumably the user whose invite was redeemed, if any; confirm.
    pub inviting_user_id: Option<UserId>,
    /// NOTE(review): presumably the `device_id` recorded with the user's signup; confirm.
    pub signup_device_id: Option<String>,
}
2773
2774fn random_invite_code() -> String {
2775 nanoid::nanoid!(16)
2776}
2777
2778fn random_email_confirmation_code() -> String {
2779 nanoid::nanoid!(64)
2780}
2781
2782#[cfg(test)]
2783pub use test::*;
2784
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A throwaway, uniquely named in-memory SQLite database for tests.
    pub struct SqliteTestDb {
        /// Handle to the migrated database. It is always `Some` after `new`.
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// A connection detached from the pool and kept for the lifetime of this value.
        ///
        /// NOTE(review): presumably this keeps the in-memory database alive; confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A throwaway, uniquely named Postgres database for tests. It is torn down on drop.
    pub struct PostgresTestDb {
        /// Handle to the migrated database. It is taken out in `Drop`.
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        /// Connection URL of the per-test database.
        pub url: String,
    }

    /// Builds the single-threaded Tokio runtime (I/O and timers enabled) that is stored
    /// in a test `Db`.
    fn test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    impl SqliteTestDb {
        /// Creates a fresh in-memory SQLite database and applies `migrations.sqlite`.
        pub fn new(background: Arc<Background>) -> Self {
            let db_suffix = StdRng::from_entropy().gen::<u128>();
            let url = format!("file:zed-test-{}?mode=memory", db_suffix);
            let runtime = test_runtime();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = Path::new(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/migrations.sqlite"
                ));
                db.migrate(migrations_path, false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the database handle.
        ///
        /// # Panics
        /// Panics if `db` has been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates a fresh Postgres database on localhost and applies `migrations`.
        pub fn new(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // The guard is held until this function returns, so only one test at a
            // time runs the create/migrate sequence below.
            let _guard = LOCK.lock();
            let db_suffix = StdRng::from_entropy().gen::<u128>();
            let url = format!("postgres://postgres@localhost/zed-test-{}", db_suffix);
            let runtime = test_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path =
                    Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
                db.migrate(migrations_path, false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the database handle.
        ///
        /// # Panics
        /// Panics if `db` has been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        /// Takes the database handle and tears down the database at `self.url`.
        fn drop(&mut self) {
            self.db.take().unwrap().teardown(&self.url);
        }
    }
}