1use crate::{Error, Result};
2use anyhow::anyhow;
3use axum::http::StatusCode;
4use collections::{BTreeMap, HashMap, HashSet};
5use dashmap::DashMap;
6use futures::{future::BoxFuture, FutureExt, StreamExt};
7use rpc::{proto, ConnectionId};
8use serde::{Deserialize, Serialize};
9use sqlx::{
10 migrate::{Migrate as _, Migration, MigrationSource},
11 types::Uuid,
12 FromRow,
13};
14use std::{
15 future::Future,
16 marker::PhantomData,
17 ops::{Deref, DerefMut},
18 path::Path,
19 rc::Rc,
20 sync::Arc,
21 time::Duration,
22};
23use time::{OffsetDateTime, PrimitiveDateTime};
24use tokio::sync::{Mutex, OwnedMutexGuard};
25
26#[cfg(test)]
27pub type DefaultDb = Db<sqlx::Sqlite>;
28
29#[cfg(not(test))]
30pub type DefaultDb = Db<sqlx::Postgres>;
31
/// Database handle, generic over the `sqlx` backend `D`.
///
/// Holds the connection pool plus `rooms`, a map of per-room async mutexes that
/// `commit_room_transaction` locks while committing changes to a room.
pub struct Db<D: sqlx::Database> {
    pool: sqlx::Pool<D>,
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    // Test-only fields. `runtime` is what `teardown` blocks on; how these are populated is not
    // visible in this section (both constructors here leave them `None`).
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
40
/// Result of a room-mutating transaction, bundled with the owned guard of that room's lock.
///
/// The room lock stays held until this value is dropped. Derefs to the wrapped `data`.
pub struct RoomGuard<T> {
    data: T,
    _guard: OwnedMutexGuard<()>,
    // `Rc<()>` is neither `Send` nor `Sync`, so this marker makes the guard `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
46
47impl<T> Deref for RoomGuard<T> {
48 type Target = T;
49
50 fn deref(&self) -> &T {
51 &self.data
52 }
53}
54
55impl<T> DerefMut for RoomGuard<T> {
56 fn deref_mut(&mut self) -> &mut T {
57 &mut self.data
58 }
59}
60
61pub trait BeginTransaction: Send + Sync {
62 type Database: sqlx::Database;
63
64 fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, Self::Database>>>;
65}
66
67// In Postgres, serializable transactions are opt-in
68impl BeginTransaction for Db<sqlx::Postgres> {
69 type Database = sqlx::Postgres;
70
71 fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Postgres>>> {
72 async move {
73 let mut tx = self.pool.begin().await?;
74 sqlx::Executor::execute(&mut tx, "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;")
75 .await?;
76 Ok(tx)
77 }
78 .boxed()
79 }
80}
81
// SQLite transactions are inherently serializable, so a plain `BEGIN` is enough here.
#[cfg(test)]
impl BeginTransaction for Db<sqlx::Sqlite> {
    type Database = sqlx::Sqlite;

    fn begin_transaction(
        &self,
    ) -> BoxFuture<'_, Result<sqlx::Transaction<'static, sqlx::Sqlite>>> {
        async move {
            let transaction = self.pool.begin().await?;
            Ok(transaction)
        }
        .boxed()
    }
}
91
/// Backend-agnostic access to a query result's affected-row count, so generic code over
/// `D::QueryResult` can read it.
pub trait RowsAffected {
    /// Number of rows affected by the executed statement.
    fn rows_affected(&self) -> u64;
}
95
#[cfg(test)]
impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
    /// Forwards to the inherent `SqliteQueryResult::rows_affected`; inherent items take
    /// precedence over trait items, so this does not recurse.
    fn rows_affected(&self) -> u64 {
        sqlx::sqlite::SqliteQueryResult::rows_affected(self)
    }
}
102
103impl RowsAffected for sqlx::postgres::PgQueryResult {
104 fn rows_affected(&self) -> u64 {
105 self.rows_affected()
106 }
107}
108
#[cfg(test)]
impl Db<sqlx::Sqlite> {
    /// Opens an SQLite-backed database for tests.
    ///
    /// The database is created if missing, shared-cache mode is enabled, and the pool keeps at
    /// least two connections open (at most `max_connections`).
    ///
    /// # Errors
    /// Returns an error if `url` is not a valid SQLite connection string or the pool fails to
    /// connect.
    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
        use std::str::FromStr as _;
        // Report a malformed URL to the caller instead of panicking.
        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
            .create_if_missing(true)
            .shared_cache(true);
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(2)
            .max_connections(max_connections)
            .connect_with(options)
            .await?;
        Ok(Self {
            pool,
            rooms: Default::default(),
            background: None,
            runtime: None,
        })
    }

    /// Fetches the users whose ids appear in `ids` (no ordering is requested).
    ///
    /// The ids are bound as a single JSON array and expanded with `json_each`.
    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
        self.transact(|tx| async {
            let mut tx = tx;
            let query = "
                SELECT users.*
                FROM users
                WHERE users.id IN (SELECT value from json_each($1))
            ";
            Ok(sqlx::query_as(query)
                .bind(&serde_json::json!(ids))
                .fetch_all(&mut tx)
                .await?)
        })
        .await
    }

    /// Returns the metrics id stored for user `id`.
    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
        self.transact(|mut tx| async move {
            let query = "
                SELECT metrics_id
                FROM users
                WHERE id = $1
            ";
            Ok(sqlx::query_scalar(query)
                .bind(id)
                .fetch_one(&mut tx)
                .await?)
        })
        .await
    }

    /// Inserts a user, or returns the existing row if `params.github_login` is already taken.
    ///
    /// Unlike the Postgres variant, the metrics id is generated here as a random UUID string.
    /// On a login conflict the existing row's id and metrics id are returned unchanged.
    pub async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult> {
        self.transact(|mut tx| async {
            let query = "
                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id, metrics_id
            ";

            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
                .bind(email_address)
                .bind(&params.github_login)
                .bind(&params.github_user_id)
                .bind(admin)
                .bind(Uuid::new_v4().to_string())
                .fetch_one(&mut tx)
                .await?;
            tx.commit().await?;
            Ok(NewUserResult {
                user_id,
                metrics_id,
                signup_device_id: None,
                inviting_user_id: None,
            })
        })
        .await
    }

    /// Not supported by the SQLite test backend; panics via `unimplemented!()`.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics via `unimplemented!()`.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics via `unimplemented!()`.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics via `unimplemented!()`.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }

    /// Not supported by the SQLite test backend; panics via `unimplemented!()`.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
}
223
224impl Db<sqlx::Postgres> {
225 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
226 let pool = sqlx::postgres::PgPoolOptions::new()
227 .max_connections(max_connections)
228 .connect(url)
229 .await?;
230 Ok(Self {
231 pool,
232 rooms: DashMap::with_capacity(16384),
233 #[cfg(test)]
234 background: None,
235 #[cfg(test)]
236 runtime: None,
237 })
238 }
239
240 #[cfg(test)]
241 pub fn teardown(&self, url: &str) {
242 self.runtime.as_ref().unwrap().block_on(async {
243 use util::ResultExt;
244 let query = "
245 SELECT pg_terminate_backend(pg_stat_activity.pid)
246 FROM pg_stat_activity
247 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
248 ";
249 sqlx::query(query).execute(&self.pool).await.log_err();
250 self.pool.close().await;
251 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
252 .await
253 .log_err();
254 })
255 }
256
257 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
258 self.transact(|tx| async {
259 let mut tx = tx;
260 let like_string = Self::fuzzy_like_string(name_query);
261 let query = "
262 SELECT users.*
263 FROM users
264 WHERE github_login ILIKE $1
265 ORDER BY github_login <-> $2
266 LIMIT $3
267 ";
268 Ok(sqlx::query_as(query)
269 .bind(like_string)
270 .bind(name_query)
271 .bind(limit as i32)
272 .fetch_all(&mut tx)
273 .await?)
274 })
275 .await
276 }
277
278 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
279 let ids = ids.iter().map(|id| id.0).collect::<Vec<_>>();
280 self.transact(|tx| async {
281 let mut tx = tx;
282 let query = "
283 SELECT users.*
284 FROM users
285 WHERE users.id = ANY ($1)
286 ";
287 Ok(sqlx::query_as(query).bind(&ids).fetch_all(&mut tx).await?)
288 })
289 .await
290 }
291
292 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
293 self.transact(|mut tx| async move {
294 let query = "
295 SELECT metrics_id::text
296 FROM users
297 WHERE id = $1
298 ";
299 Ok(sqlx::query_scalar(query)
300 .bind(id)
301 .fetch_one(&mut tx)
302 .await?)
303 })
304 .await
305 }
306
307 pub async fn create_user(
308 &self,
309 email_address: &str,
310 admin: bool,
311 params: NewUserParams,
312 ) -> Result<NewUserResult> {
313 self.transact(|mut tx| async {
314 let query = "
315 INSERT INTO users (email_address, github_login, github_user_id, admin)
316 VALUES ($1, $2, $3, $4)
317 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
318 RETURNING id, metrics_id::text
319 ";
320
321 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
322 .bind(email_address)
323 .bind(¶ms.github_login)
324 .bind(params.github_user_id)
325 .bind(admin)
326 .fetch_one(&mut tx)
327 .await?;
328 tx.commit().await?;
329
330 Ok(NewUserResult {
331 user_id,
332 metrics_id,
333 signup_device_id: None,
334 inviting_user_id: None,
335 })
336 })
337 .await
338 }
339
340 pub async fn create_user_from_invite(
341 &self,
342 invite: &Invite,
343 user: NewUserParams,
344 ) -> Result<Option<NewUserResult>> {
345 self.transact(|mut tx| async {
346 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
347 i32,
348 Option<UserId>,
349 Option<UserId>,
350 Option<String>,
351 ) = sqlx::query_as(
352 "
353 SELECT id, user_id, inviting_user_id, device_id
354 FROM signups
355 WHERE
356 email_address = $1 AND
357 email_confirmation_code = $2
358 ",
359 )
360 .bind(&invite.email_address)
361 .bind(&invite.email_confirmation_code)
362 .fetch_optional(&mut tx)
363 .await?
364 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
365
366 if existing_user_id.is_some() {
367 return Ok(None);
368 }
369
370 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
371 "
372 INSERT INTO users
373 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
374 VALUES
375 ($1, $2, $3, FALSE, $4, $5)
376 ON CONFLICT (github_login) DO UPDATE SET
377 email_address = excluded.email_address,
378 github_user_id = excluded.github_user_id,
379 admin = excluded.admin
380 RETURNING id, metrics_id::text
381 ",
382 )
383 .bind(&invite.email_address)
384 .bind(&user.github_login)
385 .bind(&user.github_user_id)
386 .bind(&user.invite_count)
387 .bind(random_invite_code())
388 .fetch_one(&mut tx)
389 .await?;
390
391 sqlx::query(
392 "
393 UPDATE signups
394 SET user_id = $1
395 WHERE id = $2
396 ",
397 )
398 .bind(&user_id)
399 .bind(&signup_id)
400 .execute(&mut tx)
401 .await?;
402
403 if let Some(inviting_user_id) = inviting_user_id {
404 let id: Option<UserId> = sqlx::query_scalar(
405 "
406 UPDATE users
407 SET invite_count = invite_count - 1
408 WHERE id = $1 AND invite_count > 0
409 RETURNING id
410 ",
411 )
412 .bind(&inviting_user_id)
413 .fetch_optional(&mut tx)
414 .await?;
415
416 if id.is_none() {
417 Err(Error::Http(
418 StatusCode::UNAUTHORIZED,
419 "no invites remaining".to_string(),
420 ))?;
421 }
422
423 sqlx::query(
424 "
425 INSERT INTO contacts
426 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
427 VALUES
428 ($1, $2, TRUE, TRUE, TRUE)
429 ON CONFLICT DO NOTHING
430 ",
431 )
432 .bind(inviting_user_id)
433 .bind(user_id)
434 .execute(&mut tx)
435 .await?;
436 }
437
438 tx.commit().await?;
439 Ok(Some(NewUserResult {
440 user_id,
441 metrics_id,
442 inviting_user_id,
443 signup_device_id,
444 }))
445 })
446 .await
447 }
448
449 pub async fn create_signup(&self, signup: Signup) -> Result<()> {
450 self.transact(|mut tx| async {
451 sqlx::query(
452 "
453 INSERT INTO signups
454 (
455 email_address,
456 email_confirmation_code,
457 email_confirmation_sent,
458 platform_linux,
459 platform_mac,
460 platform_windows,
461 platform_unknown,
462 editor_features,
463 programming_languages,
464 device_id
465 )
466 VALUES
467 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
468 RETURNING id
469 ",
470 )
471 .bind(&signup.email_address)
472 .bind(&random_email_confirmation_code())
473 .bind(&signup.platform_linux)
474 .bind(&signup.platform_mac)
475 .bind(&signup.platform_windows)
476 .bind(&signup.editor_features)
477 .bind(&signup.programming_languages)
478 .bind(&signup.device_id)
479 .execute(&mut tx)
480 .await?;
481 tx.commit().await?;
482 Ok(())
483 })
484 .await
485 }
486
487 pub async fn create_invite_from_code(
488 &self,
489 code: &str,
490 email_address: &str,
491 device_id: Option<&str>,
492 ) -> Result<Invite> {
493 self.transact(|mut tx| async {
494 let existing_user: Option<UserId> = sqlx::query_scalar(
495 "
496 SELECT id
497 FROM users
498 WHERE email_address = $1
499 ",
500 )
501 .bind(email_address)
502 .fetch_optional(&mut tx)
503 .await?;
504 if existing_user.is_some() {
505 Err(anyhow!("email address is already in use"))?;
506 }
507
508 let row: Option<(UserId, i32)> = sqlx::query_as(
509 "
510 SELECT id, invite_count
511 FROM users
512 WHERE invite_code = $1
513 ",
514 )
515 .bind(code)
516 .fetch_optional(&mut tx)
517 .await?;
518
519 let (inviter_id, invite_count) = match row {
520 Some(row) => row,
521 None => Err(Error::Http(
522 StatusCode::NOT_FOUND,
523 "invite code not found".to_string(),
524 ))?,
525 };
526
527 if invite_count == 0 {
528 Err(Error::Http(
529 StatusCode::UNAUTHORIZED,
530 "no invites remaining".to_string(),
531 ))?;
532 }
533
534 let email_confirmation_code: String = sqlx::query_scalar(
535 "
536 INSERT INTO signups
537 (
538 email_address,
539 email_confirmation_code,
540 email_confirmation_sent,
541 inviting_user_id,
542 platform_linux,
543 platform_mac,
544 platform_windows,
545 platform_unknown,
546 device_id
547 )
548 VALUES
549 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
550 ON CONFLICT (email_address)
551 DO UPDATE SET
552 inviting_user_id = excluded.inviting_user_id
553 RETURNING email_confirmation_code
554 ",
555 )
556 .bind(&email_address)
557 .bind(&random_email_confirmation_code())
558 .bind(&inviter_id)
559 .bind(&device_id)
560 .fetch_one(&mut tx)
561 .await?;
562
563 tx.commit().await?;
564
565 Ok(Invite {
566 email_address: email_address.into(),
567 email_confirmation_code,
568 })
569 })
570 .await
571 }
572
573 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
574 self.transact(|mut tx| async {
575 let emails = invites
576 .iter()
577 .map(|s| s.email_address.as_str())
578 .collect::<Vec<_>>();
579 sqlx::query(
580 "
581 UPDATE signups
582 SET email_confirmation_sent = TRUE
583 WHERE email_address = ANY ($1)
584 ",
585 )
586 .bind(&emails)
587 .execute(&mut tx)
588 .await?;
589 tx.commit().await?;
590 Ok(())
591 })
592 .await
593 }
594}
595
596impl<D> Db<D>
597where
598 Self: BeginTransaction<Database = D>,
599 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
600 D::Connection: sqlx::migrate::Migrate,
601 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
602 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
603 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
604 D::QueryResult: RowsAffected,
605 String: sqlx::Type<D>,
606 i32: sqlx::Type<D>,
607 i64: sqlx::Type<D>,
608 bool: sqlx::Type<D>,
609 str: sqlx::Type<D>,
610 Uuid: sqlx::Type<D>,
611 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
612 OffsetDateTime: sqlx::Type<D>,
613 PrimitiveDateTime: sqlx::Type<D>,
614 usize: sqlx::ColumnIndex<D::Row>,
615 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
616 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
617 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
618 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
619 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
620 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
621 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
622 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
623 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
624 for<'a> Option<ProjectId>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
625 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
626 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
627 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
628{
629 pub async fn migrate(
630 &self,
631 migrations_path: &Path,
632 ignore_checksum_mismatch: bool,
633 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
634 let migrations = MigrationSource::resolve(migrations_path)
635 .await
636 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
637
638 let mut conn = self.pool.acquire().await?;
639
640 conn.ensure_migrations_table().await?;
641 let applied_migrations: HashMap<_, _> = conn
642 .list_applied_migrations()
643 .await?
644 .into_iter()
645 .map(|m| (m.version, m))
646 .collect();
647
648 let mut new_migrations = Vec::new();
649 for migration in migrations {
650 match applied_migrations.get(&migration.version) {
651 Some(applied_migration) => {
652 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
653 {
654 Err(anyhow!(
655 "checksum mismatch for applied migration {}",
656 migration.description
657 ))?;
658 }
659 }
660 None => {
661 let elapsed = conn.apply(&migration).await?;
662 new_migrations.push((migration, elapsed));
663 }
664 }
665 }
666
667 Ok(new_migrations)
668 }
669
670 pub fn fuzzy_like_string(string: &str) -> String {
671 let mut result = String::with_capacity(string.len() * 2 + 1);
672 for c in string.chars() {
673 if c.is_alphanumeric() {
674 result.push('%');
675 result.push(c);
676 }
677 }
678 result.push('%');
679 result
680 }
681
682 // users
683
684 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
685 self.transact(|tx| async {
686 let mut tx = tx;
687 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
688 Ok(sqlx::query_as(query)
689 .bind(limit as i32)
690 .bind((page * limit) as i32)
691 .fetch_all(&mut tx)
692 .await?)
693 })
694 .await
695 }
696
697 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
698 self.transact(|tx| async {
699 let mut tx = tx;
700 let query = "
701 SELECT users.*
702 FROM users
703 WHERE id = $1
704 LIMIT 1
705 ";
706 Ok(sqlx::query_as(query)
707 .bind(&id)
708 .fetch_optional(&mut tx)
709 .await?)
710 })
711 .await
712 }
713
714 pub async fn get_users_with_no_invites(
715 &self,
716 invited_by_another_user: bool,
717 ) -> Result<Vec<User>> {
718 self.transact(|tx| async {
719 let mut tx = tx;
720 let query = format!(
721 "
722 SELECT users.*
723 FROM users
724 WHERE invite_count = 0
725 AND inviter_id IS{} NULL
726 ",
727 if invited_by_another_user { " NOT" } else { "" }
728 );
729
730 Ok(sqlx::query_as(&query).fetch_all(&mut tx).await?)
731 })
732 .await
733 }
734
735 pub async fn get_user_by_github_account(
736 &self,
737 github_login: &str,
738 github_user_id: Option<i32>,
739 ) -> Result<Option<User>> {
740 self.transact(|tx| async {
741 let mut tx = tx;
742 if let Some(github_user_id) = github_user_id {
743 let mut user = sqlx::query_as::<_, User>(
744 "
745 UPDATE users
746 SET github_login = $1
747 WHERE github_user_id = $2
748 RETURNING *
749 ",
750 )
751 .bind(github_login)
752 .bind(github_user_id)
753 .fetch_optional(&mut tx)
754 .await?;
755
756 if user.is_none() {
757 user = sqlx::query_as::<_, User>(
758 "
759 UPDATE users
760 SET github_user_id = $1
761 WHERE github_login = $2
762 RETURNING *
763 ",
764 )
765 .bind(github_user_id)
766 .bind(github_login)
767 .fetch_optional(&mut tx)
768 .await?;
769 }
770
771 Ok(user)
772 } else {
773 let user = sqlx::query_as(
774 "
775 SELECT * FROM users
776 WHERE github_login = $1
777 LIMIT 1
778 ",
779 )
780 .bind(github_login)
781 .fetch_optional(&mut tx)
782 .await?;
783 Ok(user)
784 }
785 })
786 .await
787 }
788
789 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
790 self.transact(|mut tx| async {
791 let query = "UPDATE users SET admin = $1 WHERE id = $2";
792 sqlx::query(query)
793 .bind(is_admin)
794 .bind(id.0)
795 .execute(&mut tx)
796 .await?;
797 tx.commit().await?;
798 Ok(())
799 })
800 .await
801 }
802
803 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
804 self.transact(|mut tx| async move {
805 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
806 sqlx::query(query)
807 .bind(connected_once)
808 .bind(id.0)
809 .execute(&mut tx)
810 .await?;
811 tx.commit().await?;
812 Ok(())
813 })
814 .await
815 }
816
817 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
818 self.transact(|mut tx| async move {
819 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
820 sqlx::query(query)
821 .bind(id.0)
822 .execute(&mut tx)
823 .await
824 .map(drop)?;
825 let query = "DELETE FROM users WHERE id = $1;";
826 sqlx::query(query).bind(id.0).execute(&mut tx).await?;
827 tx.commit().await?;
828 Ok(())
829 })
830 .await
831 }
832
833 // signups
834
835 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
836 self.transact(|mut tx| async move {
837 Ok(sqlx::query_as(
838 "
839 SELECT
840 COUNT(*) as count,
841 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
842 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
843 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
844 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
845 FROM (
846 SELECT *
847 FROM signups
848 WHERE
849 NOT email_confirmation_sent
850 ) AS unsent
851 ",
852 )
853 .fetch_one(&mut tx)
854 .await?)
855 })
856 .await
857 }
858
859 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
860 self.transact(|mut tx| async move {
861 Ok(sqlx::query_as(
862 "
863 SELECT
864 email_address, email_confirmation_code
865 FROM signups
866 WHERE
867 NOT email_confirmation_sent AND
868 (platform_mac OR platform_unknown)
869 LIMIT $1
870 ",
871 )
872 .bind(count as i32)
873 .fetch_all(&mut tx)
874 .await?)
875 })
876 .await
877 }
878
879 // invite codes
880
881 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
882 self.transact(|mut tx| async move {
883 if count > 0 {
884 sqlx::query(
885 "
886 UPDATE users
887 SET invite_code = $1
888 WHERE id = $2 AND invite_code IS NULL
889 ",
890 )
891 .bind(random_invite_code())
892 .bind(id)
893 .execute(&mut tx)
894 .await?;
895 }
896
897 sqlx::query(
898 "
899 UPDATE users
900 SET invite_count = $1
901 WHERE id = $2
902 ",
903 )
904 .bind(count as i32)
905 .bind(id)
906 .execute(&mut tx)
907 .await?;
908 tx.commit().await?;
909 Ok(())
910 })
911 .await
912 }
913
914 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
915 self.transact(|mut tx| async move {
916 let result: Option<(String, i32)> = sqlx::query_as(
917 "
918 SELECT invite_code, invite_count
919 FROM users
920 WHERE id = $1 AND invite_code IS NOT NULL
921 ",
922 )
923 .bind(id)
924 .fetch_optional(&mut tx)
925 .await?;
926 if let Some((code, count)) = result {
927 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
928 } else {
929 Ok(None)
930 }
931 })
932 .await
933 }
934
935 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
936 self.transact(|tx| async {
937 let mut tx = tx;
938 sqlx::query_as(
939 "
940 SELECT *
941 FROM users
942 WHERE invite_code = $1
943 ",
944 )
945 .bind(code)
946 .fetch_optional(&mut tx)
947 .await?
948 .ok_or_else(|| {
949 Error::Http(
950 StatusCode::NOT_FOUND,
951 "that invite code does not exist".to_string(),
952 )
953 })
954 })
955 .await
956 }
957
958 async fn commit_room_transaction<'a, T>(
959 &'a self,
960 room_id: RoomId,
961 tx: sqlx::Transaction<'static, D>,
962 data: T,
963 ) -> Result<RoomGuard<T>> {
964 let lock = self.rooms.entry(room_id).or_default().clone();
965 let _guard = lock.lock_owned().await;
966 tx.commit().await?;
967 Ok(RoomGuard {
968 data,
969 _guard,
970 _not_send: PhantomData,
971 })
972 }
973
974 pub async fn create_room(
975 &self,
976 user_id: UserId,
977 connection_id: ConnectionId,
978 live_kit_room: &str,
979 ) -> Result<RoomGuard<proto::Room>> {
980 self.transact(|mut tx| async move {
981 let room_id = sqlx::query_scalar(
982 "
983 INSERT INTO rooms (live_kit_room)
984 VALUES ($1)
985 RETURNING id
986 ",
987 )
988 .bind(&live_kit_room)
989 .fetch_one(&mut tx)
990 .await
991 .map(RoomId)?;
992
993 sqlx::query(
994 "
995 INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id)
996 VALUES ($1, $2, $3, $4, $5)
997 ",
998 )
999 .bind(room_id)
1000 .bind(user_id)
1001 .bind(connection_id.0 as i32)
1002 .bind(user_id)
1003 .bind(connection_id.0 as i32)
1004 .execute(&mut tx)
1005 .await?;
1006
1007 let room = self.get_room(room_id, &mut tx).await?;
1008 self.commit_room_transaction(room_id, tx, room).await
1009 }).await
1010 }
1011
1012 pub async fn call(
1013 &self,
1014 room_id: RoomId,
1015 calling_user_id: UserId,
1016 calling_connection_id: ConnectionId,
1017 called_user_id: UserId,
1018 initial_project_id: Option<ProjectId>,
1019 ) -> Result<RoomGuard<(proto::Room, proto::IncomingCall)>> {
1020 self.transact(|mut tx| async move {
1021 sqlx::query(
1022 "
1023 INSERT INTO room_participants (
1024 room_id,
1025 user_id,
1026 calling_user_id,
1027 calling_connection_id,
1028 initial_project_id
1029 )
1030 VALUES ($1, $2, $3, $4, $5)
1031 ",
1032 )
1033 .bind(room_id)
1034 .bind(called_user_id)
1035 .bind(calling_user_id)
1036 .bind(calling_connection_id.0 as i32)
1037 .bind(initial_project_id)
1038 .execute(&mut tx)
1039 .await?;
1040
1041 let room = self.get_room(room_id, &mut tx).await?;
1042 let incoming_call = Self::build_incoming_call(&room, called_user_id)
1043 .ok_or_else(|| anyhow!("failed to build incoming call"))?;
1044 self.commit_room_transaction(room_id, tx, (room, incoming_call))
1045 .await
1046 })
1047 .await
1048 }
1049
1050 pub async fn incoming_call_for_user(
1051 &self,
1052 user_id: UserId,
1053 ) -> Result<Option<proto::IncomingCall>> {
1054 self.transact(|mut tx| async move {
1055 let room_id = sqlx::query_scalar::<_, RoomId>(
1056 "
1057 SELECT room_id
1058 FROM room_participants
1059 WHERE user_id = $1 AND answering_connection_id IS NULL
1060 ",
1061 )
1062 .bind(user_id)
1063 .fetch_optional(&mut tx)
1064 .await?;
1065
1066 if let Some(room_id) = room_id {
1067 let room = self.get_room(room_id, &mut tx).await?;
1068 Ok(Self::build_incoming_call(&room, user_id))
1069 } else {
1070 Ok(None)
1071 }
1072 })
1073 .await
1074 }
1075
1076 fn build_incoming_call(
1077 room: &proto::Room,
1078 called_user_id: UserId,
1079 ) -> Option<proto::IncomingCall> {
1080 let pending_participant = room
1081 .pending_participants
1082 .iter()
1083 .find(|participant| participant.user_id == called_user_id.to_proto())?;
1084
1085 Some(proto::IncomingCall {
1086 room_id: room.id,
1087 calling_user_id: pending_participant.calling_user_id,
1088 participant_user_ids: room
1089 .participants
1090 .iter()
1091 .map(|participant| participant.user_id)
1092 .collect(),
1093 initial_project: room.participants.iter().find_map(|participant| {
1094 let initial_project_id = pending_participant.initial_project_id?;
1095 participant
1096 .projects
1097 .iter()
1098 .find(|project| project.id == initial_project_id)
1099 .cloned()
1100 }),
1101 })
1102 }
1103
1104 pub async fn call_failed(
1105 &self,
1106 room_id: RoomId,
1107 called_user_id: UserId,
1108 ) -> Result<RoomGuard<proto::Room>> {
1109 self.transact(|mut tx| async move {
1110 sqlx::query(
1111 "
1112 DELETE FROM room_participants
1113 WHERE room_id = $1 AND user_id = $2
1114 ",
1115 )
1116 .bind(room_id)
1117 .bind(called_user_id)
1118 .execute(&mut tx)
1119 .await?;
1120
1121 let room = self.get_room(room_id, &mut tx).await?;
1122 self.commit_room_transaction(room_id, tx, room).await
1123 })
1124 .await
1125 }
1126
1127 pub async fn decline_call(
1128 &self,
1129 expected_room_id: Option<RoomId>,
1130 user_id: UserId,
1131 ) -> Result<RoomGuard<proto::Room>> {
1132 self.transact(|mut tx| async move {
1133 let room_id = sqlx::query_scalar(
1134 "
1135 DELETE FROM room_participants
1136 WHERE user_id = $1 AND answering_connection_id IS NULL
1137 RETURNING room_id
1138 ",
1139 )
1140 .bind(user_id)
1141 .fetch_one(&mut tx)
1142 .await?;
1143 if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1144 return Err(anyhow!("declining call on unexpected room"))?;
1145 }
1146
1147 let room = self.get_room(room_id, &mut tx).await?;
1148 self.commit_room_transaction(room_id, tx, room).await
1149 })
1150 .await
1151 }
1152
1153 pub async fn cancel_call(
1154 &self,
1155 expected_room_id: Option<RoomId>,
1156 calling_connection_id: ConnectionId,
1157 called_user_id: UserId,
1158 ) -> Result<RoomGuard<proto::Room>> {
1159 self.transact(|mut tx| async move {
1160 let room_id = sqlx::query_scalar(
1161 "
1162 DELETE FROM room_participants
1163 WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL
1164 RETURNING room_id
1165 ",
1166 )
1167 .bind(called_user_id)
1168 .bind(calling_connection_id.0 as i32)
1169 .fetch_one(&mut tx)
1170 .await?;
1171 if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1172 return Err(anyhow!("canceling call on unexpected room"))?;
1173 }
1174
1175 let room = self.get_room(room_id, &mut tx).await?;
1176 self.commit_room_transaction(room_id, tx, room).await
1177 }).await
1178 }
1179
1180 pub async fn join_room(
1181 &self,
1182 room_id: RoomId,
1183 user_id: UserId,
1184 connection_id: ConnectionId,
1185 ) -> Result<RoomGuard<proto::Room>> {
1186 self.transact(|mut tx| async move {
1187 sqlx::query(
1188 "
1189 UPDATE room_participants
1190 SET answering_connection_id = $1
1191 WHERE room_id = $2 AND user_id = $3
1192 RETURNING 1
1193 ",
1194 )
1195 .bind(connection_id.0 as i32)
1196 .bind(room_id)
1197 .bind(user_id)
1198 .fetch_one(&mut tx)
1199 .await?;
1200
1201 let room = self.get_room(room_id, &mut tx).await?;
1202 self.commit_room_transaction(room_id, tx, room).await
1203 })
1204 .await
1205 }
1206
1207 pub async fn leave_room(
1208 &self,
1209 connection_id: ConnectionId,
1210 ) -> Result<Option<RoomGuard<LeftRoom>>> {
1211 self.transact(|mut tx| async move {
1212 // Leave room.
1213 let room_id = sqlx::query_scalar::<_, RoomId>(
1214 "
1215 DELETE FROM room_participants
1216 WHERE answering_connection_id = $1
1217 RETURNING room_id
1218 ",
1219 )
1220 .bind(connection_id.0 as i32)
1221 .fetch_optional(&mut tx)
1222 .await?;
1223
1224 if let Some(room_id) = room_id {
1225 // Cancel pending calls initiated by the leaving user.
1226 let canceled_calls_to_user_ids: Vec<UserId> = sqlx::query_scalar(
1227 "
1228 DELETE FROM room_participants
1229 WHERE calling_connection_id = $1 AND answering_connection_id IS NULL
1230 RETURNING user_id
1231 ",
1232 )
1233 .bind(connection_id.0 as i32)
1234 .fetch_all(&mut tx)
1235 .await?;
1236
1237 let project_ids = sqlx::query_scalar::<_, ProjectId>(
1238 "
1239 SELECT project_id
1240 FROM project_collaborators
1241 WHERE connection_id = $1
1242 ",
1243 )
1244 .bind(connection_id.0 as i32)
1245 .fetch_all(&mut tx)
1246 .await?;
1247
1248 // Leave projects.
1249 let mut left_projects = HashMap::default();
1250 if !project_ids.is_empty() {
1251 let mut params = "?,".repeat(project_ids.len());
1252 params.pop();
1253 let query = format!(
1254 "
1255 SELECT *
1256 FROM project_collaborators
1257 WHERE project_id IN ({params})
1258 "
1259 );
1260 let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query);
1261 for project_id in project_ids {
1262 query = query.bind(project_id);
1263 }
1264
1265 let mut project_collaborators = query.fetch(&mut tx);
1266 while let Some(collaborator) = project_collaborators.next().await {
1267 let collaborator = collaborator?;
1268 let left_project =
1269 left_projects
1270 .entry(collaborator.project_id)
1271 .or_insert(LeftProject {
1272 id: collaborator.project_id,
1273 host_user_id: Default::default(),
1274 connection_ids: Default::default(),
1275 host_connection_id: Default::default(),
1276 });
1277
1278 let collaborator_connection_id =
1279 ConnectionId(collaborator.connection_id as u32);
1280 if collaborator_connection_id != connection_id {
1281 left_project.connection_ids.push(collaborator_connection_id);
1282 }
1283
1284 if collaborator.is_host {
1285 left_project.host_user_id = collaborator.user_id;
1286 left_project.host_connection_id =
1287 ConnectionId(collaborator.connection_id as u32);
1288 }
1289 }
1290 }
1291 sqlx::query(
1292 "
1293 DELETE FROM project_collaborators
1294 WHERE connection_id = $1
1295 ",
1296 )
1297 .bind(connection_id.0 as i32)
1298 .execute(&mut tx)
1299 .await?;
1300
1301 // Unshare projects.
1302 sqlx::query(
1303 "
1304 DELETE FROM projects
1305 WHERE room_id = $1 AND host_connection_id = $2
1306 ",
1307 )
1308 .bind(room_id)
1309 .bind(connection_id.0 as i32)
1310 .execute(&mut tx)
1311 .await?;
1312
1313 let room = self.get_room(room_id, &mut tx).await?;
1314 Ok(Some(
1315 self.commit_room_transaction(
1316 room_id,
1317 tx,
1318 LeftRoom {
1319 room,
1320 left_projects,
1321 canceled_calls_to_user_ids,
1322 },
1323 )
1324 .await?,
1325 ))
1326 } else {
1327 Ok(None)
1328 }
1329 })
1330 .await
1331 }
1332
1333 pub async fn update_room_participant_location(
1334 &self,
1335 room_id: RoomId,
1336 connection_id: ConnectionId,
1337 location: proto::ParticipantLocation,
1338 ) -> Result<RoomGuard<proto::Room>> {
1339 self.transact(|tx| async {
1340 let mut tx = tx;
1341 let location_kind;
1342 let location_project_id;
1343 match location
1344 .variant
1345 .as_ref()
1346 .ok_or_else(|| anyhow!("invalid location"))?
1347 {
1348 proto::participant_location::Variant::SharedProject(project) => {
1349 location_kind = 0;
1350 location_project_id = Some(ProjectId::from_proto(project.id));
1351 }
1352 proto::participant_location::Variant::UnsharedProject(_) => {
1353 location_kind = 1;
1354 location_project_id = None;
1355 }
1356 proto::participant_location::Variant::External(_) => {
1357 location_kind = 2;
1358 location_project_id = None;
1359 }
1360 }
1361
1362 sqlx::query(
1363 "
1364 UPDATE room_participants
1365 SET location_kind = $1, location_project_id = $2
1366 WHERE room_id = $3 AND answering_connection_id = $4
1367 RETURNING 1
1368 ",
1369 )
1370 .bind(location_kind)
1371 .bind(location_project_id)
1372 .bind(room_id)
1373 .bind(connection_id.0 as i32)
1374 .fetch_one(&mut tx)
1375 .await?;
1376
1377 let room = self.get_room(room_id, &mut tx).await?;
1378 self.commit_room_transaction(room_id, tx, room).await
1379 })
1380 .await
1381 }
1382
1383 async fn get_guest_connection_ids(
1384 &self,
1385 project_id: ProjectId,
1386 tx: &mut sqlx::Transaction<'_, D>,
1387 ) -> Result<Vec<ConnectionId>> {
1388 let mut guest_connection_ids = Vec::new();
1389 let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>(
1390 "
1391 SELECT connection_id
1392 FROM project_collaborators
1393 WHERE project_id = $1 AND is_host = FALSE
1394 ",
1395 )
1396 .bind(project_id)
1397 .fetch(tx);
1398 while let Some(connection_id) = db_guest_connection_ids.next().await {
1399 guest_connection_ids.push(ConnectionId(connection_id? as u32));
1400 }
1401 Ok(guest_connection_ids)
1402 }
1403
    /// Builds the `proto::Room` snapshot for `room_id` from the rows visible in `tx`.
    ///
    /// Participants with an `answering_connection_id` are returned in
    /// `participants` (peer id = that connection id). Those without one are
    /// returned in `pending_participants`. Every project of the room is attached
    /// to the participant whose answering connection id equals the project's
    /// `host_connection_id`, along with its worktree root names.
    ///
    /// Errors if the room row does not exist (`fetch_one`) or if any query fails.
    async fn get_room(
        &self,
        room_id: RoomId,
        tx: &mut sqlx::Transaction<'_, D>,
    ) -> Result<proto::Room> {
        let room: Room = sqlx::query_as(
            "
            SELECT *
            FROM rooms
            WHERE id = $1
            ",
        )
        .bind(room_id)
        .fetch_one(&mut *tx)
        .await?;

        // One row per participant of the room, answered or still pending.
        let mut db_participants =
            sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<ProjectId>, UserId, Option<ProjectId>)>(
                "
                SELECT user_id, answering_connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
                FROM room_participants
                WHERE room_id = $1
                ",
            )
            .bind(room_id)
            .fetch(&mut *tx);

        // Answered participants, keyed by answering connection id so that
        // projects can be matched to their host connection below.
        let mut participants = HashMap::default();
        let mut pending_participants = Vec::new();
        while let Some(participant) = db_participants.next().await {
            let (
                user_id,
                answering_connection_id,
                location_kind,
                location_project_id,
                calling_user_id,
                initial_project_id,
            ) = participant?;
            if let Some(answering_connection_id) = answering_connection_id {
                // Decode the stored location:
                // - kind 0 with a project id is a shared project;
                // - kind 1 is an unshared project;
                // - anything else (including NULL, or kind 0 without a project id)
                //   is reported as external.
                let location = match (location_kind, location_project_id) {
                    (Some(0), Some(project_id)) => {
                        Some(proto::participant_location::Variant::SharedProject(
                            proto::participant_location::SharedProject {
                                id: project_id.to_proto(),
                            },
                        ))
                    }
                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
                        Default::default(),
                    )),
                    _ => Some(proto::participant_location::Variant::External(
                        Default::default(),
                    )),
                };
                participants.insert(
                    answering_connection_id,
                    proto::Participant {
                        user_id: user_id.to_proto(),
                        peer_id: answering_connection_id as u32,
                        projects: Default::default(),
                        location: Some(proto::ParticipantLocation { variant: location }),
                    },
                );
            } else {
                // No answering connection yet: this is still a pending call.
                pending_participants.push(proto::PendingParticipant {
                    user_id: user_id.to_proto(),
                    calling_user_id: calling_user_id.to_proto(),
                    initial_project_id: initial_project_id.map(|id| id.to_proto()),
                });
            }
        }
        // End the stream so its borrow of `tx` is released before the next query.
        drop(db_participants);

        // One row per (project, worktree) pair. Because of the LEFT JOIN, a project
        // without worktrees still yields one row, with a NULL root name.
        let mut rows = sqlx::query_as::<_, (i32, ProjectId, Option<String>)>(
            "
            SELECT host_connection_id, projects.id, worktrees.root_name
            FROM projects
            LEFT JOIN worktrees ON projects.id = worktrees.project_id
            WHERE room_id = $1
            ",
        )
        .bind(room_id)
        .fetch(&mut *tx);

        while let Some(row) = rows.next().await {
            let (connection_id, project_id, worktree_root_name) = row?;
            // Projects whose host connection is not an answered participant are skipped.
            if let Some(participant) = participants.get_mut(&connection_id) {
                // Reuse the entry created by an earlier row for the same project.
                let project = if let Some(project) = participant
                    .projects
                    .iter_mut()
                    .find(|project| project.id == project_id.to_proto())
                {
                    project
                } else {
                    participant.projects.push(proto::ParticipantProject {
                        id: project_id.to_proto(),
                        worktree_root_names: Default::default(),
                    });
                    participant.projects.last_mut().unwrap()
                };
                // Extending with an `Option` appends nothing for a NULL root name.
                project.worktree_root_names.extend(worktree_root_name);
            }
        }

        Ok(proto::Room {
            id: room.id.to_proto(),
            live_kit_room: room.live_kit_room,
            // Order follows the map's iteration order; it is not sorted here.
            participants: participants.into_values().collect(),
            pending_participants,
        })
    }
1515
1516 // projects
1517
1518 pub async fn project_count_excluding_admins(&self) -> Result<usize> {
1519 self.transact(|mut tx| async move {
1520 Ok(sqlx::query_scalar::<_, i32>(
1521 "
1522 SELECT COUNT(*)
1523 FROM projects, users
1524 WHERE projects.host_user_id = users.id AND users.admin IS FALSE
1525 ",
1526 )
1527 .fetch_one(&mut tx)
1528 .await? as usize)
1529 })
1530 .await
1531 }
1532
1533 pub async fn share_project(
1534 &self,
1535 expected_room_id: RoomId,
1536 connection_id: ConnectionId,
1537 worktrees: &[proto::WorktreeMetadata],
1538 ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
1539 self.transact(|mut tx| async move {
1540 let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
1541 "
1542 SELECT room_id, user_id
1543 FROM room_participants
1544 WHERE answering_connection_id = $1
1545 ",
1546 )
1547 .bind(connection_id.0 as i32)
1548 .fetch_one(&mut tx)
1549 .await?;
1550 if room_id != expected_room_id {
1551 return Err(anyhow!("shared project on unexpected room"))?;
1552 }
1553
1554 let project_id: ProjectId = sqlx::query_scalar(
1555 "
1556 INSERT INTO projects (room_id, host_user_id, host_connection_id)
1557 VALUES ($1, $2, $3)
1558 RETURNING id
1559 ",
1560 )
1561 .bind(room_id)
1562 .bind(user_id)
1563 .bind(connection_id.0 as i32)
1564 .fetch_one(&mut tx)
1565 .await?;
1566
1567 if !worktrees.is_empty() {
1568 let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
1569 params.pop();
1570 let query = format!(
1571 "
1572 INSERT INTO worktrees (
1573 project_id,
1574 id,
1575 root_name,
1576 abs_path,
1577 visible,
1578 scan_id,
1579 is_complete
1580 )
1581 VALUES {params}
1582 "
1583 );
1584
1585 let mut query = sqlx::query(&query);
1586 for worktree in worktrees {
1587 query = query
1588 .bind(project_id)
1589 .bind(worktree.id as i32)
1590 .bind(&worktree.root_name)
1591 .bind(&worktree.abs_path)
1592 .bind(worktree.visible)
1593 .bind(0)
1594 .bind(false);
1595 }
1596 query.execute(&mut tx).await?;
1597 }
1598
1599 sqlx::query(
1600 "
1601 INSERT INTO project_collaborators (
1602 project_id,
1603 connection_id,
1604 user_id,
1605 replica_id,
1606 is_host
1607 )
1608 VALUES ($1, $2, $3, $4, $5)
1609 ",
1610 )
1611 .bind(project_id)
1612 .bind(connection_id.0 as i32)
1613 .bind(user_id)
1614 .bind(0)
1615 .bind(true)
1616 .execute(&mut tx)
1617 .await?;
1618
1619 let room = self.get_room(room_id, &mut tx).await?;
1620 self.commit_room_transaction(room_id, tx, (project_id, room))
1621 .await
1622 })
1623 .await
1624 }
1625
1626 pub async fn unshare_project(
1627 &self,
1628 project_id: ProjectId,
1629 connection_id: ConnectionId,
1630 ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
1631 self.transact(|mut tx| async move {
1632 let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1633 let room_id: RoomId = sqlx::query_scalar(
1634 "
1635 DELETE FROM projects
1636 WHERE id = $1 AND host_connection_id = $2
1637 RETURNING room_id
1638 ",
1639 )
1640 .bind(project_id)
1641 .bind(connection_id.0 as i32)
1642 .fetch_one(&mut tx)
1643 .await?;
1644 let room = self.get_room(room_id, &mut tx).await?;
1645 self.commit_room_transaction(room_id, tx, (room, guest_connection_ids))
1646 .await
1647 })
1648 .await
1649 }
1650
1651 pub async fn update_project(
1652 &self,
1653 project_id: ProjectId,
1654 connection_id: ConnectionId,
1655 worktrees: &[proto::WorktreeMetadata],
1656 ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
1657 self.transact(|mut tx| async move {
1658 let room_id: RoomId = sqlx::query_scalar(
1659 "
1660 SELECT room_id
1661 FROM projects
1662 WHERE id = $1 AND host_connection_id = $2
1663 ",
1664 )
1665 .bind(project_id)
1666 .bind(connection_id.0 as i32)
1667 .fetch_one(&mut tx)
1668 .await?;
1669
1670 if !worktrees.is_empty() {
1671 let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
1672 params.pop();
1673 let query = format!(
1674 "
1675 INSERT INTO worktrees (
1676 project_id,
1677 id,
1678 root_name,
1679 abs_path,
1680 visible,
1681 scan_id,
1682 is_complete
1683 )
1684 VALUES {params}
1685 ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name
1686 "
1687 );
1688
1689 let mut query = sqlx::query(&query);
1690 for worktree in worktrees {
1691 query = query
1692 .bind(project_id)
1693 .bind(worktree.id as i32)
1694 .bind(&worktree.root_name)
1695 .bind(&worktree.abs_path)
1696 .bind(worktree.visible)
1697 .bind(0)
1698 .bind(false)
1699 }
1700 query.execute(&mut tx).await?;
1701 }
1702
1703 let mut params = "?,".repeat(worktrees.len());
1704 if !worktrees.is_empty() {
1705 params.pop();
1706 }
1707 let query = format!(
1708 "
1709 DELETE FROM worktrees
1710 WHERE project_id = ? AND id NOT IN ({params})
1711 ",
1712 );
1713
1714 let mut query = sqlx::query(&query).bind(project_id);
1715 for worktree in worktrees {
1716 query = query.bind(WorktreeId(worktree.id as i32));
1717 }
1718 query.execute(&mut tx).await?;
1719
1720 let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1721 let room = self.get_room(room_id, &mut tx).await?;
1722 self.commit_room_transaction(room_id, tx, (room, guest_connection_ids))
1723 .await
1724 })
1725 .await
1726 }
1727
1728 pub async fn update_worktree(
1729 &self,
1730 update: &proto::UpdateWorktree,
1731 connection_id: ConnectionId,
1732 ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1733 self.transact(|mut tx| async move {
1734 let project_id = ProjectId::from_proto(update.project_id);
1735 let worktree_id = WorktreeId::from_proto(update.worktree_id);
1736
1737 // Ensure the update comes from the host.
1738 let room_id: RoomId = sqlx::query_scalar(
1739 "
1740 SELECT room_id
1741 FROM projects
1742 WHERE id = $1 AND host_connection_id = $2
1743 ",
1744 )
1745 .bind(project_id)
1746 .bind(connection_id.0 as i32)
1747 .fetch_one(&mut tx)
1748 .await?;
1749
1750 // Update metadata.
1751 sqlx::query(
1752 "
1753 UPDATE worktrees
1754 SET
1755 root_name = $1,
1756 scan_id = $2,
1757 is_complete = $3,
1758 abs_path = $4
1759 WHERE project_id = $5 AND id = $6
1760 RETURNING 1
1761 ",
1762 )
1763 .bind(&update.root_name)
1764 .bind(update.scan_id as i64)
1765 .bind(update.is_last_update)
1766 .bind(&update.abs_path)
1767 .bind(project_id)
1768 .bind(worktree_id)
1769 .fetch_one(&mut tx)
1770 .await?;
1771
1772 if !update.updated_entries.is_empty() {
1773 let mut params =
1774 "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),".repeat(update.updated_entries.len());
1775 params.pop();
1776
1777 let query = format!(
1778 "
1779 INSERT INTO worktree_entries (
1780 project_id,
1781 worktree_id,
1782 id,
1783 is_dir,
1784 path,
1785 inode,
1786 mtime_seconds,
1787 mtime_nanos,
1788 is_symlink,
1789 is_ignored
1790 )
1791 VALUES {params}
1792 ON CONFLICT (project_id, worktree_id, id) DO UPDATE SET
1793 is_dir = excluded.is_dir,
1794 path = excluded.path,
1795 inode = excluded.inode,
1796 mtime_seconds = excluded.mtime_seconds,
1797 mtime_nanos = excluded.mtime_nanos,
1798 is_symlink = excluded.is_symlink,
1799 is_ignored = excluded.is_ignored
1800 "
1801 );
1802 let mut query = sqlx::query(&query);
1803 for entry in &update.updated_entries {
1804 let mtime = entry.mtime.clone().unwrap_or_default();
1805 query = query
1806 .bind(project_id)
1807 .bind(worktree_id)
1808 .bind(entry.id as i64)
1809 .bind(entry.is_dir)
1810 .bind(&entry.path)
1811 .bind(entry.inode as i64)
1812 .bind(mtime.seconds as i64)
1813 .bind(mtime.nanos as i32)
1814 .bind(entry.is_symlink)
1815 .bind(entry.is_ignored);
1816 }
1817 query.execute(&mut tx).await?;
1818 }
1819
1820 if !update.removed_entries.is_empty() {
1821 let mut params = "?,".repeat(update.removed_entries.len());
1822 params.pop();
1823 let query = format!(
1824 "
1825 DELETE FROM worktree_entries
1826 WHERE project_id = ? AND worktree_id = ? AND id IN ({params})
1827 "
1828 );
1829
1830 let mut query = sqlx::query(&query).bind(project_id).bind(worktree_id);
1831 for entry_id in &update.removed_entries {
1832 query = query.bind(*entry_id as i64);
1833 }
1834 query.execute(&mut tx).await?;
1835 }
1836
1837 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1838 self.commit_room_transaction(room_id, tx, connection_ids)
1839 .await
1840 })
1841 .await
1842 }
1843
1844 pub async fn update_diagnostic_summary(
1845 &self,
1846 update: &proto::UpdateDiagnosticSummary,
1847 connection_id: ConnectionId,
1848 ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1849 self.transact(|mut tx| async {
1850 let project_id = ProjectId::from_proto(update.project_id);
1851 let worktree_id = WorktreeId::from_proto(update.worktree_id);
1852 let summary = update
1853 .summary
1854 .as_ref()
1855 .ok_or_else(|| anyhow!("invalid summary"))?;
1856
1857 // Ensure the update comes from the host.
1858 let room_id: RoomId = sqlx::query_scalar(
1859 "
1860 SELECT room_id
1861 FROM projects
1862 WHERE id = $1 AND host_connection_id = $2
1863 ",
1864 )
1865 .bind(project_id)
1866 .bind(connection_id.0 as i32)
1867 .fetch_one(&mut tx)
1868 .await?;
1869
1870 // Update summary.
1871 sqlx::query(
1872 "
1873 INSERT INTO worktree_diagnostic_summaries (
1874 project_id,
1875 worktree_id,
1876 path,
1877 language_server_id,
1878 error_count,
1879 warning_count
1880 )
1881 VALUES ($1, $2, $3, $4, $5, $6)
1882 ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET
1883 language_server_id = excluded.language_server_id,
1884 error_count = excluded.error_count,
1885 warning_count = excluded.warning_count
1886 ",
1887 )
1888 .bind(project_id)
1889 .bind(worktree_id)
1890 .bind(&summary.path)
1891 .bind(summary.language_server_id as i64)
1892 .bind(summary.error_count as i32)
1893 .bind(summary.warning_count as i32)
1894 .execute(&mut tx)
1895 .await?;
1896
1897 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1898 self.commit_room_transaction(room_id, tx, connection_ids)
1899 .await
1900 })
1901 .await
1902 }
1903
1904 pub async fn start_language_server(
1905 &self,
1906 update: &proto::StartLanguageServer,
1907 connection_id: ConnectionId,
1908 ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1909 self.transact(|mut tx| async {
1910 let project_id = ProjectId::from_proto(update.project_id);
1911 let server = update
1912 .server
1913 .as_ref()
1914 .ok_or_else(|| anyhow!("invalid language server"))?;
1915
1916 // Ensure the update comes from the host.
1917 let room_id: RoomId = sqlx::query_scalar(
1918 "
1919 SELECT room_id
1920 FROM projects
1921 WHERE id = $1 AND host_connection_id = $2
1922 ",
1923 )
1924 .bind(project_id)
1925 .bind(connection_id.0 as i32)
1926 .fetch_one(&mut tx)
1927 .await?;
1928
1929 // Add the newly-started language server.
1930 sqlx::query(
1931 "
1932 INSERT INTO language_servers (project_id, id, name)
1933 VALUES ($1, $2, $3)
1934 ON CONFLICT (project_id, id) DO UPDATE SET
1935 name = excluded.name
1936 ",
1937 )
1938 .bind(project_id)
1939 .bind(server.id as i64)
1940 .bind(&server.name)
1941 .execute(&mut tx)
1942 .await?;
1943
1944 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1945 self.commit_room_transaction(room_id, tx, connection_ids)
1946 .await
1947 })
1948 .await
1949 }
1950
    /// Adds the participant answering on `connection_id` as a guest collaborator
    /// of `project_id`, and returns the project state plus the assigned replica id.
    ///
    /// Errors (via `fetch_one`) if the connection is not an answered participant
    /// of a room, or if the project was not shared in that room.
    pub async fn join_project(
        &self,
        project_id: ProjectId,
        connection_id: ConnectionId,
    ) -> Result<RoomGuard<(Project, ReplicaId)>> {
        self.transact(|mut tx| async move {
            let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
                "
                SELECT room_id, user_id
                FROM room_participants
                WHERE answering_connection_id = $1
                ",
            )
            .bind(connection_id.0 as i32)
            .fetch_one(&mut tx)
            .await?;

            // Ensure project id was shared on this room.
            sqlx::query(
                "
                SELECT 1
                FROM projects
                WHERE id = $1 AND room_id = $2
                ",
            )
            .bind(project_id)
            .bind(room_id)
            .fetch_one(&mut tx)
            .await?;

            let mut collaborators = sqlx::query_as::<_, ProjectCollaborator>(
                "
                SELECT *
                FROM project_collaborators
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;
            // Pick the lowest replica id >= 1 that no current collaborator uses.
            // The search starts at 1, presumably because the host is inserted
            // with replica id 0 when the project is shared.
            let replica_ids = collaborators
                .iter()
                .map(|c| c.replica_id)
                .collect::<HashSet<_>>();
            let mut replica_id = ReplicaId(1);
            while replica_ids.contains(&replica_id) {
                replica_id.0 += 1;
            }
            let new_collaborator = ProjectCollaborator {
                project_id,
                connection_id: connection_id.0 as i32,
                user_id,
                replica_id,
                is_host: false,
            };

            sqlx::query(
                "
                INSERT INTO project_collaborators (
                    project_id,
                    connection_id,
                    user_id,
                    replica_id,
                    is_host
                )
                VALUES ($1, $2, $3, $4, $5)
                ",
            )
            .bind(new_collaborator.project_id)
            .bind(new_collaborator.connection_id)
            .bind(new_collaborator.user_id)
            .bind(new_collaborator.replica_id)
            .bind(new_collaborator.is_host)
            .execute(&mut tx)
            .await?;
            // The returned collaborator list includes the newcomer.
            collaborators.push(new_collaborator);

            // Worktrees keyed by id. Entries and summaries are attached below.
            let worktree_rows = sqlx::query_as::<_, WorktreeRow>(
                "
                SELECT *
                FROM worktrees
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;
            let mut worktrees = worktree_rows
                .into_iter()
                .map(|worktree_row| {
                    (
                        worktree_row.id,
                        Worktree {
                            id: worktree_row.id,
                            abs_path: worktree_row.abs_path,
                            root_name: worktree_row.root_name,
                            visible: worktree_row.visible,
                            entries: Default::default(),
                            diagnostic_summaries: Default::default(),
                            scan_id: worktree_row.scan_id as u64,
                            is_complete: worktree_row.is_complete,
                        },
                    )
                })
                .collect::<BTreeMap<_, _>>();

            // Populate worktree entries. Rows whose worktree id is not in
            // `worktrees` are ignored. The inner scope ends the stream's borrow of `tx`.
            {
                let mut entries = sqlx::query_as::<_, WorktreeEntry>(
                    "
                    SELECT *
                    FROM worktree_entries
                    WHERE project_id = $1
                    ",
                )
                .bind(project_id)
                .fetch(&mut tx);
                while let Some(entry) = entries.next().await {
                    let entry = entry?;
                    if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) {
                        worktree.entries.push(proto::Entry {
                            id: entry.id as u64,
                            is_dir: entry.is_dir,
                            path: entry.path,
                            inode: entry.inode as u64,
                            mtime: Some(proto::Timestamp {
                                seconds: entry.mtime_seconds as u64,
                                nanos: entry.mtime_nanos as u32,
                            }),
                            is_symlink: entry.is_symlink,
                            is_ignored: entry.is_ignored,
                        });
                    }
                }
            }

            // Populate worktree diagnostic summaries, with the same rules as entries.
            {
                let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>(
                    "
                    SELECT *
                    FROM worktree_diagnostic_summaries
                    WHERE project_id = $1
                    ",
                )
                .bind(project_id)
                .fetch(&mut tx);
                while let Some(summary) = summaries.next().await {
                    let summary = summary?;
                    if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) {
                        worktree
                            .diagnostic_summaries
                            .push(proto::DiagnosticSummary {
                                path: summary.path,
                                language_server_id: summary.language_server_id as u64,
                                error_count: summary.error_count as u32,
                                warning_count: summary.warning_count as u32,
                            });
                    }
                }
            }

            // Populate language servers.
            let language_servers = sqlx::query_as::<_, LanguageServer>(
                "
                SELECT *
                FROM language_servers
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;

            // Note: `replica_id as ReplicaId` is a no-op cast; `replica_id` already has that type.
            self.commit_room_transaction(
                room_id,
                tx,
                (
                    Project {
                        collaborators,
                        worktrees,
                        language_servers: language_servers
                            .into_iter()
                            .map(|language_server| proto::LanguageServer {
                                id: language_server.id.to_proto(),
                                name: language_server.name,
                            })
                            .collect(),
                    },
                    replica_id as ReplicaId,
                ),
            )
            .await
        })
        .await
    }
2147
2148 pub async fn leave_project(
2149 &self,
2150 project_id: ProjectId,
2151 connection_id: ConnectionId,
2152 ) -> Result<RoomGuard<LeftProject>> {
2153 self.transact(|mut tx| async move {
2154 let result = sqlx::query(
2155 "
2156 DELETE FROM project_collaborators
2157 WHERE project_id = $1 AND connection_id = $2
2158 ",
2159 )
2160 .bind(project_id)
2161 .bind(connection_id.0 as i32)
2162 .execute(&mut tx)
2163 .await?;
2164
2165 if result.rows_affected() == 0 {
2166 Err(anyhow!("not a collaborator on this project"))?;
2167 }
2168
2169 let connection_ids = sqlx::query_scalar::<_, i32>(
2170 "
2171 SELECT connection_id
2172 FROM project_collaborators
2173 WHERE project_id = $1
2174 ",
2175 )
2176 .bind(project_id)
2177 .fetch_all(&mut tx)
2178 .await?
2179 .into_iter()
2180 .map(|id| ConnectionId(id as u32))
2181 .collect();
2182
2183 let (room_id, host_user_id, host_connection_id) =
2184 sqlx::query_as::<_, (RoomId, i32, i32)>(
2185 "
2186 SELECT room_id, host_user_id, host_connection_id
2187 FROM projects
2188 WHERE id = $1
2189 ",
2190 )
2191 .bind(project_id)
2192 .fetch_one(&mut tx)
2193 .await?;
2194
2195 self.commit_room_transaction(
2196 room_id,
2197 tx,
2198 LeftProject {
2199 id: project_id,
2200 host_user_id: UserId(host_user_id),
2201 host_connection_id: ConnectionId(host_connection_id as u32),
2202 connection_ids,
2203 },
2204 )
2205 .await
2206 })
2207 .await
2208 }
2209
2210 pub async fn project_collaborators(
2211 &self,
2212 project_id: ProjectId,
2213 connection_id: ConnectionId,
2214 ) -> Result<Vec<ProjectCollaborator>> {
2215 self.transact(|mut tx| async move {
2216 let collaborators = sqlx::query_as::<_, ProjectCollaborator>(
2217 "
2218 SELECT *
2219 FROM project_collaborators
2220 WHERE project_id = $1
2221 ",
2222 )
2223 .bind(project_id)
2224 .fetch_all(&mut tx)
2225 .await?;
2226
2227 if collaborators
2228 .iter()
2229 .any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
2230 {
2231 Ok(collaborators)
2232 } else {
2233 Err(anyhow!("no such project"))?
2234 }
2235 })
2236 .await
2237 }
2238
2239 pub async fn project_connection_ids(
2240 &self,
2241 project_id: ProjectId,
2242 connection_id: ConnectionId,
2243 ) -> Result<HashSet<ConnectionId>> {
2244 self.transact(|mut tx| async move {
2245 let connection_ids = sqlx::query_scalar::<_, i32>(
2246 "
2247 SELECT connection_id
2248 FROM project_collaborators
2249 WHERE project_id = $1
2250 ",
2251 )
2252 .bind(project_id)
2253 .fetch_all(&mut tx)
2254 .await?;
2255
2256 if connection_ids.contains(&(connection_id.0 as i32)) {
2257 Ok(connection_ids
2258 .into_iter()
2259 .map(|connection_id| ConnectionId(connection_id as u32))
2260 .collect())
2261 } else {
2262 Err(anyhow!("no such project"))?
2263 }
2264 })
2265 .await
2266 }
2267
2268 // contacts
2269
2270 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
2271 self.transact(|mut tx| async move {
2272 let query = "
2273 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify, (room_participants.id IS NOT NULL) as busy
2274 FROM contacts
2275 LEFT JOIN room_participants ON room_participants.user_id = $1
2276 WHERE user_id_a = $1 OR user_id_b = $1;
2277 ";
2278
2279 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool, bool)>(query)
2280 .bind(user_id)
2281 .fetch(&mut tx);
2282
2283 let mut contacts = Vec::new();
2284 while let Some(row) = rows.next().await {
2285 let (user_id_a, user_id_b, a_to_b, accepted, should_notify, busy) = row?;
2286 if user_id_a == user_id {
2287 if accepted {
2288 contacts.push(Contact::Accepted {
2289 user_id: user_id_b,
2290 should_notify: should_notify && a_to_b,
2291 busy
2292 });
2293 } else if a_to_b {
2294 contacts.push(Contact::Outgoing { user_id: user_id_b })
2295 } else {
2296 contacts.push(Contact::Incoming {
2297 user_id: user_id_b,
2298 should_notify,
2299 });
2300 }
2301 } else if accepted {
2302 contacts.push(Contact::Accepted {
2303 user_id: user_id_a,
2304 should_notify: should_notify && !a_to_b,
2305 busy
2306 });
2307 } else if a_to_b {
2308 contacts.push(Contact::Incoming {
2309 user_id: user_id_a,
2310 should_notify,
2311 });
2312 } else {
2313 contacts.push(Contact::Outgoing { user_id: user_id_a });
2314 }
2315 }
2316
2317 contacts.sort_unstable_by_key(|contact| contact.user_id());
2318
2319 Ok(contacts)
2320 })
2321 .await
2322 }
2323
2324 pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
2325 self.transact(|mut tx| async move {
2326 Ok(sqlx::query_scalar::<_, i32>(
2327 "
2328 SELECT 1
2329 FROM room_participants
2330 WHERE room_participants.user_id = $1
2331 ",
2332 )
2333 .bind(user_id)
2334 .fetch_optional(&mut tx)
2335 .await?
2336 .is_some())
2337 })
2338 .await
2339 }
2340
2341 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
2342 self.transact(|mut tx| async move {
2343 let (id_a, id_b) = if user_id_1 < user_id_2 {
2344 (user_id_1, user_id_2)
2345 } else {
2346 (user_id_2, user_id_1)
2347 };
2348
2349 let query = "
2350 SELECT 1 FROM contacts
2351 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
2352 LIMIT 1
2353 ";
2354 Ok(sqlx::query_scalar::<_, i32>(query)
2355 .bind(id_a.0)
2356 .bind(id_b.0)
2357 .fetch_optional(&mut tx)
2358 .await?
2359 .is_some())
2360 })
2361 .await
2362 }
2363
2364 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
2365 self.transact(|mut tx| async move {
2366 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
2367 (sender_id, receiver_id, true)
2368 } else {
2369 (receiver_id, sender_id, false)
2370 };
2371 let query = "
2372 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
2373 VALUES ($1, $2, $3, FALSE, TRUE)
2374 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
2375 SET
2376 accepted = TRUE,
2377 should_notify = FALSE
2378 WHERE
2379 NOT contacts.accepted AND
2380 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
2381 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
2382 ";
2383 let result = sqlx::query(query)
2384 .bind(id_a.0)
2385 .bind(id_b.0)
2386 .bind(a_to_b)
2387 .execute(&mut tx)
2388 .await?;
2389
2390 if result.rows_affected() == 1 {
2391 tx.commit().await?;
2392 Ok(())
2393 } else {
2394 Err(anyhow!("contact already requested"))?
2395 }
2396 }).await
2397 }
2398
2399 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2400 self.transact(|mut tx| async move {
2401 let (id_a, id_b) = if responder_id < requester_id {
2402 (responder_id, requester_id)
2403 } else {
2404 (requester_id, responder_id)
2405 };
2406 let query = "
2407 DELETE FROM contacts
2408 WHERE user_id_a = $1 AND user_id_b = $2;
2409 ";
2410 let result = sqlx::query(query)
2411 .bind(id_a.0)
2412 .bind(id_b.0)
2413 .execute(&mut tx)
2414 .await?;
2415
2416 if result.rows_affected() == 1 {
2417 tx.commit().await?;
2418 Ok(())
2419 } else {
2420 Err(anyhow!("no such contact"))?
2421 }
2422 })
2423 .await
2424 }
2425
2426 pub async fn dismiss_contact_notification(
2427 &self,
2428 user_id: UserId,
2429 contact_user_id: UserId,
2430 ) -> Result<()> {
2431 self.transact(|mut tx| async move {
2432 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
2433 (user_id, contact_user_id, true)
2434 } else {
2435 (contact_user_id, user_id, false)
2436 };
2437
2438 let query = "
2439 UPDATE contacts
2440 SET should_notify = FALSE
2441 WHERE
2442 user_id_a = $1 AND user_id_b = $2 AND
2443 (
2444 (a_to_b = $3 AND accepted) OR
2445 (a_to_b != $3 AND NOT accepted)
2446 );
2447 ";
2448
2449 let result = sqlx::query(query)
2450 .bind(id_a.0)
2451 .bind(id_b.0)
2452 .bind(a_to_b)
2453 .execute(&mut tx)
2454 .await?;
2455
2456 if result.rows_affected() == 0 {
2457 Err(anyhow!("no such contact request"))?
2458 } else {
2459 tx.commit().await?;
2460 Ok(())
2461 }
2462 })
2463 .await
2464 }
2465
2466 pub async fn respond_to_contact_request(
2467 &self,
2468 responder_id: UserId,
2469 requester_id: UserId,
2470 accept: bool,
2471 ) -> Result<()> {
2472 self.transact(|mut tx| async move {
2473 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
2474 (responder_id, requester_id, false)
2475 } else {
2476 (requester_id, responder_id, true)
2477 };
2478 let result = if accept {
2479 let query = "
2480 UPDATE contacts
2481 SET accepted = TRUE, should_notify = TRUE
2482 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
2483 ";
2484 sqlx::query(query)
2485 .bind(id_a.0)
2486 .bind(id_b.0)
2487 .bind(a_to_b)
2488 .execute(&mut tx)
2489 .await?
2490 } else {
2491 let query = "
2492 DELETE FROM contacts
2493 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
2494 ";
2495 sqlx::query(query)
2496 .bind(id_a.0)
2497 .bind(id_b.0)
2498 .bind(a_to_b)
2499 .execute(&mut tx)
2500 .await?
2501 };
2502 if result.rows_affected() == 1 {
2503 tx.commit().await?;
2504 Ok(())
2505 } else {
2506 Err(anyhow!("no such contact request"))?
2507 }
2508 })
2509 .await
2510 }
2511
2512 // access tokens
2513
2514 pub async fn create_access_token_hash(
2515 &self,
2516 user_id: UserId,
2517 access_token_hash: &str,
2518 max_access_token_count: usize,
2519 ) -> Result<()> {
2520 self.transact(|tx| async {
2521 let mut tx = tx;
2522 let insert_query = "
2523 INSERT INTO access_tokens (user_id, hash)
2524 VALUES ($1, $2);
2525 ";
2526 let cleanup_query = "
2527 DELETE FROM access_tokens
2528 WHERE id IN (
2529 SELECT id from access_tokens
2530 WHERE user_id = $1
2531 ORDER BY id DESC
2532 LIMIT 10000
2533 OFFSET $3
2534 )
2535 ";
2536
2537 sqlx::query(insert_query)
2538 .bind(user_id.0)
2539 .bind(access_token_hash)
2540 .execute(&mut tx)
2541 .await?;
2542 sqlx::query(cleanup_query)
2543 .bind(user_id.0)
2544 .bind(access_token_hash)
2545 .bind(max_access_token_count as i32)
2546 .execute(&mut tx)
2547 .await?;
2548 Ok(tx.commit().await?)
2549 })
2550 .await
2551 }
2552
2553 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
2554 self.transact(|mut tx| async move {
2555 let query = "
2556 SELECT hash
2557 FROM access_tokens
2558 WHERE user_id = $1
2559 ORDER BY id DESC
2560 ";
2561 Ok(sqlx::query_scalar(query)
2562 .bind(user_id.0)
2563 .fetch_all(&mut tx)
2564 .await?)
2565 })
2566 .await
2567 }
2568
2569 async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
2570 where
2571 F: Send + Fn(sqlx::Transaction<'static, D>) -> Fut,
2572 Fut: Send + Future<Output = Result<T>>,
2573 {
2574 let body = async {
2575 loop {
2576 let tx = self.begin_transaction().await?;
2577 match f(tx).await {
2578 Ok(result) => return Ok(result),
2579 Err(error) => match error {
2580 Error::Database(error)
2581 if error
2582 .as_database_error()
2583 .and_then(|error| error.code())
2584 .as_deref()
2585 == Some("40001") =>
2586 {
2587 // Retry (don't break the loop)
2588 }
2589 error @ _ => return Err(error),
2590 },
2591 }
2592 }
2593 };
2594
2595 #[cfg(test)]
2596 {
2597 if let Some(background) = self.background.as_ref() {
2598 background.simulate_random_delay().await;
2599 }
2600
2601 let result = self.runtime.as_ref().unwrap().block_on(body);
2602
2603 // if let Some(background) = self.background.as_ref() {
2604 // background.simulate_random_delay().await;
2605 // }
2606
2607 result
2608 }
2609
2610 #[cfg(not(test))]
2611 {
2612 body.await
2613 }
2614 }
2615}
2616
/// Declares `pub struct $name(pub i32)`, a copyable newtype ID.
///
/// The generated type is transparent to both sqlx and serde, so it is stored
/// and serialized exactly like the wrapped `i32`. It is displayed as that
/// integer and converts to and from the `u64` used in protobuf messages.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }

        impl $name {
            /// The ID wrapping `i32::MAX`.
            #[allow(unused)]
            pub const MAX: Self = $name(i32::MAX);

            /// Builds an ID from its protobuf form. Uses a truncating `as` cast.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                $name(raw)
            }

            /// Converts to the protobuf form. Uses a sign-extending `as` cast.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                let $name(raw) = self;
                raw as u64
            }
        }
    };
}
2659
// `UserId`: i32 newtype ID generated by `id_type!`.
id_type!(UserId);
/// A user record, decodable from a query row via `FromRow` and serializable with serde.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    pub github_login: String,
    /// Optional; `None` when no GitHub user id is recorded.
    pub github_user_id: Option<i32>,
    /// Optional; `None` when no email address is recorded.
    pub email_address: Option<String>,
    pub admin: bool,
    /// Optional; `None` when the user has no invite code.
    pub invite_code: Option<String>,
    pub invite_count: i32,
    pub connected_once: bool,
}
2672
// `RoomId`: i32 newtype ID generated by `id_type!`. It also keys the per-room mutex map on `Db`.
id_type!(RoomId);
/// A room record, decodable from a query row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Room {
    pub id: RoomId,
    /// Name of the associated LiveKit room (judging by the field name).
    pub live_kit_room: String,
}
2679
// `ProjectId`: i32 newtype ID generated by `id_type!`.
id_type!(ProjectId);
/// An assembled view of a project: its collaborators, worktrees and language servers.
///
/// Unlike the row types in this file, this does not derive `FromRow`. Presumably
/// it is built from several queries. TODO confirm at the construction site.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    /// Worktrees keyed (and therefore ordered) by their `WorktreeId`.
    pub worktrees: BTreeMap<WorktreeId, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
2686
// `ReplicaId`: i32 newtype ID generated by `id_type!`.
id_type!(ReplicaId);
/// A participant in a project, decodable from a query row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
pub struct ProjectCollaborator {
    pub project_id: ProjectId,
    /// Stored as a raw `i32`, whereas `LeftProject` uses `rpc::ConnectionId`.
    pub connection_id: i32,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    /// Whether this collaborator is the project's host.
    pub is_host: bool,
}
2696
// `WorktreeId`: i32 newtype ID generated by `id_type!`.
id_type!(WorktreeId);
/// Row-level representation of a worktree, decodable via `FromRow`.
///
/// Unlike the public `Worktree`, it carries no entries or diagnostic
/// summaries, and `scan_id` is signed here (`i64`) rather than `u64`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeRow {
    pub id: WorktreeId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub scan_id: i64,
    pub is_complete: bool,
}
2707
/// A worktree together with its entries and diagnostic summaries (as protobuf values).
///
/// Presumably assembled from a `WorktreeRow` plus separate entry and summary
/// queries. TODO confirm at the construction site.
pub struct Worktree {
    pub id: WorktreeId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    /// Unsigned here, unlike the `i64` in `WorktreeRow`.
    pub scan_id: u64,
    pub is_complete: bool,
}
2718
/// A single file-system entry of a worktree, decodable from a query row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeEntry {
    id: i64,
    worktree_id: WorktreeId,
    is_dir: bool,
    path: String,
    inode: i64,
    // Modification time split into whole seconds plus a nanosecond part
    // (judging by the field names).
    mtime_seconds: i64,
    mtime_nanos: i32,
    is_symlink: bool,
    is_ignored: bool,
}
2731
/// Per-path diagnostic counts for one language server within a worktree,
/// decodable from a query row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeDiagnosticSummary {
    worktree_id: WorktreeId,
    path: String,
    language_server_id: i64,
    error_count: i32,
    warning_count: i32,
}
2740
// `LanguageServerId`: i32 newtype ID generated by `id_type!`.
id_type!(LanguageServerId);
/// A language server record (id and name), decodable from a query row via `FromRow`.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct LanguageServer {
    id: LanguageServerId,
    name: String,
}
2747
/// Information about a project that a connection has left.
///
/// Presumably returned so the caller can notify the host and the remaining
/// connections. TODO confirm against callers.
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    pub connection_ids: Vec<ConnectionId>,
}
2754
/// Result of leaving a room: the room's state (as a protobuf message), the
/// projects that were left as a consequence (keyed by `ProjectId`), and the
/// users whose pending calls were canceled.
pub struct LeftRoom {
    pub room: proto::Room,
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
2760
/// A contact relationship as seen from one user's perspective.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// A mutually accepted contact.
    Accepted {
        user_id: UserId,
        should_notify: bool,
        /// Presumably whether the contact is currently in a call. TODO confirm.
        busy: bool,
    },
    /// A pending request that this user sent to `user_id` (judging by the name).
    Outgoing {
        user_id: UserId,
    },
    /// A pending request that `user_id` sent to this user (judging by the name).
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
2776
2777impl Contact {
2778 pub fn user_id(&self) -> UserId {
2779 match self {
2780 Contact::Accepted { user_id, .. } => *user_id,
2781 Contact::Outgoing { user_id } => *user_id,
2782 Contact::Incoming { user_id, .. } => *user_id,
2783 }
2784 }
2785}
2786
/// A pending contact request received from `requester_id`.
///
/// The derived ordering compares `requester_id` first, then `should_notify`
/// (field declaration order).
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
2792
/// Signup form data, deserialized from incoming payloads (derives `Deserialize` only).
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    // Platform flags; more than one may be set.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    /// Optional; `None` when no device id was supplied.
    pub device_id: Option<String>,
}
2803
/// Aggregate waitlist counts, overall and per platform.
///
/// Every field is `#[sqlx(default)]`, so a column missing from the query row
/// decodes as `0` instead of failing.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
2817
/// An invite addressed to an email, with its confirmation code.
///
/// Codes of this kind are produced by `random_email_confirmation_code` elsewhere in this file.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
2823
/// Parameters supplied when creating a new user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    /// Initial number of invites granted to the user (judging by the name).
    pub invite_count: i32,
}
2830
/// Outcome of creating a new user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// `None` when the user was not invited by another user.
    pub inviting_user_id: Option<UserId>,
    /// `None` when no device id is associated with the signup.
    pub signup_device_id: Option<String>,
}
2838
2839fn random_invite_code() -> String {
2840 nanoid::nanoid!(16)
2841}
2842
2843fn random_email_confirmation_code() -> String {
2844 nanoid::nanoid!(64)
2845}
2846
// In test builds, re-export the test-database helpers from this module.
#[cfg(test)]
pub use test::*;

/// Helpers that create throwaway SQLite and Postgres databases for tests.
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A migrated in-memory SQLite test database.
    pub struct SqliteTestDb {
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        // Connection detached from the pool. Presumably held to keep the named
        // in-memory database alive for the lifetime of the test. TODO confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A migrated Postgres test database on localhost. It is torn down on drop.
    pub struct PostgresTestDb {
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        pub url: String,
    }

    impl SqliteTestDb {
        /// Creates a uniquely named in-memory SQLite database, runs the
        /// `migrations.sqlite` migrations, and attaches `background` plus a
        /// dedicated current-thread Tokio runtime to the resulting `Db`.
        ///
        /// # Panics
        ///
        /// Panics if the runtime, the connection pool, the migrations, or
        /// acquiring a connection fails.
        pub fn new(background: Arc<Background>) -> Self {
            let mut rng = StdRng::from_entropy();
            // A random name keeps concurrently running tests apart.
            let url = format!("file:zed-test-{}?mode=memory", rng.gen::<u128>());
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let (mut db, conn) = runtime.block_on(async {
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(migrations_path.as_ref(), false).await.unwrap();
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            // The runtime is moved into the Db only after `block_on` has finished with it.
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the database handle. Panics if it has already been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates a uniquely named Postgres database on localhost, runs the
        /// `migrations` directory against it, and attaches `background` plus a
        /// dedicated current-thread Tokio runtime to the resulting `Db`.
        ///
        /// # Panics
        ///
        /// Panics if the runtime, database creation, the connection pool, or
        /// the migrations fail.
        pub fn new(background: Arc<Background>) -> Self {
            // Process-wide lock. Only one PostgresTestDb is set up at a time.
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the database handle. Panics if it has already been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        // Takes the Db out and calls `teardown` with this database's URL.
        // `teardown` is defined elsewhere and presumably drops the test
        // database. TODO confirm.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            db.teardown(&self.url);
        }
    }
}