1mod schema;
2#[cfg(test)]
3mod tests;
4
5use crate::{Error, Result};
6use anyhow::anyhow;
7use axum::http::StatusCode;
8use collections::{BTreeMap, HashMap, HashSet};
9use dashmap::DashMap;
10use futures::{future::BoxFuture, FutureExt, StreamExt};
11use rpc::{proto, ConnectionId};
12use sea_query::{Expr, Query};
13use sea_query_binder::SqlxBinder;
14use serde::{Deserialize, Serialize};
15use sqlx::{
16 migrate::{Migrate as _, Migration, MigrationSource},
17 types::Uuid,
18 FromRow,
19};
20use std::{
21 future::Future,
22 marker::PhantomData,
23 ops::{Deref, DerefMut},
24 path::Path,
25 rc::Rc,
26 sync::Arc,
27 time::Duration,
28};
29use time::{OffsetDateTime, PrimitiveDateTime};
30use tokio::sync::{Mutex, OwnedMutexGuard};
31
32#[cfg(test)]
33pub type DefaultDb = Db<sqlx::Sqlite>;
34
35#[cfg(not(test))]
36pub type DefaultDb = Db<sqlx::Postgres>;
37
38pub struct Db<D: sqlx::Database> {
39 pool: sqlx::Pool<D>,
40 rooms: DashMap<RoomId, Arc<Mutex<()>>>,
41 #[cfg(test)]
42 background: Option<std::sync::Arc<gpui::executor::Background>>,
43 #[cfg(test)]
44 runtime: Option<tokio::runtime::Runtime>,
45}
46
47pub struct RoomGuard<T> {
48 data: T,
49 _guard: OwnedMutexGuard<()>,
50 _not_send: PhantomData<Rc<()>>,
51}
52
53impl<T> Deref for RoomGuard<T> {
54 type Target = T;
55
56 fn deref(&self) -> &T {
57 &self.data
58 }
59}
60
61impl<T> DerefMut for RoomGuard<T> {
62 fn deref_mut(&mut self) -> &mut T {
63 &mut self.data
64 }
65}
66
67pub trait BeginTransaction: Send + Sync {
68 type Database: sqlx::Database;
69
70 fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, Self::Database>>>;
71}
72
73// In Postgres, serializable transactions are opt-in
74impl BeginTransaction for Db<sqlx::Postgres> {
75 type Database = sqlx::Postgres;
76
77 fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Postgres>>> {
78 async move {
79 let mut tx = self.pool.begin().await?;
80 sqlx::Executor::execute(&mut tx, "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;")
81 .await?;
82 Ok(tx)
83 }
84 .boxed()
85 }
86}
87
88// In Sqlite, transactions are inherently serializable.
89#[cfg(test)]
90impl BeginTransaction for Db<sqlx::Sqlite> {
91 type Database = sqlx::Sqlite;
92
93 fn begin_transaction(&self) -> BoxFuture<Result<sqlx::Transaction<'static, sqlx::Sqlite>>> {
94 async move { Ok(self.pool.begin().await?) }.boxed()
95 }
96}
97
98pub trait BuildQuery {
99 fn build_query<T: SqlxBinder>(&self, query: &T) -> (String, sea_query_binder::SqlxValues);
100}
101
102impl BuildQuery for Db<sqlx::Postgres> {
103 fn build_query<T: SqlxBinder>(&self, query: &T) -> (String, sea_query_binder::SqlxValues) {
104 query.build_sqlx(sea_query::PostgresQueryBuilder)
105 }
106}
107
108#[cfg(test)]
109impl BuildQuery for Db<sqlx::Sqlite> {
110 fn build_query<T: SqlxBinder>(&self, query: &T) -> (String, sea_query_binder::SqlxValues) {
111 query.build_sqlx(sea_query::SqliteQueryBuilder)
112 }
113}
114
115pub trait RowsAffected {
116 fn rows_affected(&self) -> u64;
117}
118
119#[cfg(test)]
120impl RowsAffected for sqlx::sqlite::SqliteQueryResult {
121 fn rows_affected(&self) -> u64 {
122 self.rows_affected()
123 }
124}
125
126impl RowsAffected for sqlx::postgres::PgQueryResult {
127 fn rows_affected(&self) -> u64 {
128 self.rows_affected()
129 }
130}
131
#[cfg(test)]
impl Db<sqlx::Sqlite> {
    /// Opens a pool to the SQLite database at `url`, creating the database if it is missing.
    ///
    /// # Errors
    /// Returns an error when `url` is not valid SQLite connection syntax or the pool cannot
    /// connect.
    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
        use std::str::FromStr as _;
        // Propagate a malformed URL as an error instead of panicking.
        let options = sqlx::sqlite::SqliteConnectOptions::from_str(url)?
            .create_if_missing(true)
            .shared_cache(true);
        let pool = sqlx::sqlite::SqlitePoolOptions::new()
            .min_connections(2)
            .max_connections(max_connections)
            .connect_with(options)
            .await?;
        Ok(Self {
            pool,
            rooms: Default::default(),
            background: None,
            runtime: None,
        })
    }

    /// Fetches every user whose id is contained in `ids`.
    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
        self.transact(|tx| async {
            let mut tx = tx;
            // Unlike the Postgres variant, which binds an array, the ids are sent as one JSON
            // array and expanded with `json_each`.
            let query = "
                SELECT users.*
                FROM users
                WHERE users.id IN (SELECT value from json_each($1))
            ";
            Ok(sqlx::query_as(query)
                .bind(&serde_json::json!(ids))
                .fetch_all(&mut tx)
                .await?)
        })
        .await
    }

    /// Returns the stored `metrics_id` of the user with the given id.
    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
        self.transact(|mut tx| async move {
            let query = "
                SELECT metrics_id
                FROM users
                WHERE id = $1
            ";
            Ok(sqlx::query_scalar(query)
                .bind(id)
                .fetch_one(&mut tx)
                .await?)
        })
        .await
    }

    /// Inserts a user (or touches the existing row with the same GitHub login) and returns its
    /// id and metrics id. A fresh random UUID string is supplied as `metrics_id` on insert.
    pub async fn create_user(
        &self,
        email_address: &str,
        admin: bool,
        params: NewUserParams,
    ) -> Result<NewUserResult> {
        self.transact(|mut tx| async {
            let query = "
                INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id)
                VALUES ($1, $2, $3, $4, $5)
                ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
                RETURNING id, metrics_id
            ";

            let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
                .bind(email_address)
                .bind(&params.github_login)
                .bind(&params.github_user_id)
                .bind(admin)
                .bind(Uuid::new_v4().to_string())
                .fetch_one(&mut tx)
                .await?;
            tx.commit().await?;
            Ok(NewUserResult {
                user_id,
                metrics_id,
                signup_device_id: None,
                inviting_user_id: None,
            })
        })
        .await
    }

    /// Not implemented for the SQLite test backend; panics if called.
    pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result<Vec<User>> {
        unimplemented!()
    }

    /// Not implemented for the SQLite test backend; panics if called.
    pub async fn create_user_from_invite(
        &self,
        _invite: &Invite,
        _user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        unimplemented!()
    }

    /// Not implemented for the SQLite test backend; panics if called.
    pub async fn create_signup(&self, _signup: Signup) -> Result<()> {
        unimplemented!()
    }

    /// Not implemented for the SQLite test backend; panics if called.
    pub async fn create_invite_from_code(
        &self,
        _code: &str,
        _email_address: &str,
        _device_id: Option<&str>,
    ) -> Result<Invite> {
        unimplemented!()
    }

    /// Not implemented for the SQLite test backend; panics if called.
    pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> {
        unimplemented!()
    }
}
246
247impl Db<sqlx::Postgres> {
248 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
249 let pool = sqlx::postgres::PgPoolOptions::new()
250 .max_connections(max_connections)
251 .connect(url)
252 .await?;
253 Ok(Self {
254 pool,
255 rooms: DashMap::with_capacity(16384),
256 #[cfg(test)]
257 background: None,
258 #[cfg(test)]
259 runtime: None,
260 })
261 }
262
263 #[cfg(test)]
264 pub fn teardown(&self, url: &str) {
265 self.runtime.as_ref().unwrap().block_on(async {
266 use util::ResultExt;
267 let query = "
268 SELECT pg_terminate_backend(pg_stat_activity.pid)
269 FROM pg_stat_activity
270 WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
271 ";
272 sqlx::query(query).execute(&self.pool).await.log_err();
273 self.pool.close().await;
274 <sqlx::Sqlite as sqlx::migrate::MigrateDatabase>::drop_database(url)
275 .await
276 .log_err();
277 })
278 }
279
280 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
281 self.transact(|tx| async {
282 let mut tx = tx;
283 let like_string = Self::fuzzy_like_string(name_query);
284 let query = "
285 SELECT users.*
286 FROM users
287 WHERE github_login ILIKE $1
288 ORDER BY github_login <-> $2
289 LIMIT $3
290 ";
291 Ok(sqlx::query_as(query)
292 .bind(like_string)
293 .bind(name_query)
294 .bind(limit as i32)
295 .fetch_all(&mut tx)
296 .await?)
297 })
298 .await
299 }
300
301 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
302 let ids = ids.iter().map(|id| id.0).collect::<Vec<_>>();
303 self.transact(|tx| async {
304 let mut tx = tx;
305 let query = "
306 SELECT users.*
307 FROM users
308 WHERE users.id = ANY ($1)
309 ";
310 Ok(sqlx::query_as(query).bind(&ids).fetch_all(&mut tx).await?)
311 })
312 .await
313 }
314
315 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
316 self.transact(|mut tx| async move {
317 let query = "
318 SELECT metrics_id::text
319 FROM users
320 WHERE id = $1
321 ";
322 Ok(sqlx::query_scalar(query)
323 .bind(id)
324 .fetch_one(&mut tx)
325 .await?)
326 })
327 .await
328 }
329
330 pub async fn create_user(
331 &self,
332 email_address: &str,
333 admin: bool,
334 params: NewUserParams,
335 ) -> Result<NewUserResult> {
336 self.transact(|mut tx| async {
337 let query = "
338 INSERT INTO users (email_address, github_login, github_user_id, admin)
339 VALUES ($1, $2, $3, $4)
340 ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
341 RETURNING id, metrics_id::text
342 ";
343
344 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query)
345 .bind(email_address)
346 .bind(¶ms.github_login)
347 .bind(params.github_user_id)
348 .bind(admin)
349 .fetch_one(&mut tx)
350 .await?;
351 tx.commit().await?;
352
353 Ok(NewUserResult {
354 user_id,
355 metrics_id,
356 signup_device_id: None,
357 inviting_user_id: None,
358 })
359 })
360 .await
361 }
362
363 pub async fn create_user_from_invite(
364 &self,
365 invite: &Invite,
366 user: NewUserParams,
367 ) -> Result<Option<NewUserResult>> {
368 self.transact(|mut tx| async {
369 let (signup_id, existing_user_id, inviting_user_id, signup_device_id): (
370 i32,
371 Option<UserId>,
372 Option<UserId>,
373 Option<String>,
374 ) = sqlx::query_as(
375 "
376 SELECT id, user_id, inviting_user_id, device_id
377 FROM signups
378 WHERE
379 email_address = $1 AND
380 email_confirmation_code = $2
381 ",
382 )
383 .bind(&invite.email_address)
384 .bind(&invite.email_confirmation_code)
385 .fetch_optional(&mut tx)
386 .await?
387 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
388
389 if existing_user_id.is_some() {
390 return Ok(None);
391 }
392
393 let (user_id, metrics_id): (UserId, String) = sqlx::query_as(
394 "
395 INSERT INTO users
396 (email_address, github_login, github_user_id, admin, invite_count, invite_code)
397 VALUES
398 ($1, $2, $3, FALSE, $4, $5)
399 ON CONFLICT (github_login) DO UPDATE SET
400 email_address = excluded.email_address,
401 github_user_id = excluded.github_user_id,
402 admin = excluded.admin
403 RETURNING id, metrics_id::text
404 ",
405 )
406 .bind(&invite.email_address)
407 .bind(&user.github_login)
408 .bind(&user.github_user_id)
409 .bind(&user.invite_count)
410 .bind(random_invite_code())
411 .fetch_one(&mut tx)
412 .await?;
413
414 sqlx::query(
415 "
416 UPDATE signups
417 SET user_id = $1
418 WHERE id = $2
419 ",
420 )
421 .bind(&user_id)
422 .bind(&signup_id)
423 .execute(&mut tx)
424 .await?;
425
426 if let Some(inviting_user_id) = inviting_user_id {
427 let id: Option<UserId> = sqlx::query_scalar(
428 "
429 UPDATE users
430 SET invite_count = invite_count - 1
431 WHERE id = $1 AND invite_count > 0
432 RETURNING id
433 ",
434 )
435 .bind(&inviting_user_id)
436 .fetch_optional(&mut tx)
437 .await?;
438
439 if id.is_none() {
440 Err(Error::Http(
441 StatusCode::UNAUTHORIZED,
442 "no invites remaining".to_string(),
443 ))?;
444 }
445
446 sqlx::query(
447 "
448 INSERT INTO contacts
449 (user_id_a, user_id_b, a_to_b, should_notify, accepted)
450 VALUES
451 ($1, $2, TRUE, TRUE, TRUE)
452 ON CONFLICT DO NOTHING
453 ",
454 )
455 .bind(inviting_user_id)
456 .bind(user_id)
457 .execute(&mut tx)
458 .await?;
459 }
460
461 tx.commit().await?;
462 Ok(Some(NewUserResult {
463 user_id,
464 metrics_id,
465 inviting_user_id,
466 signup_device_id,
467 }))
468 })
469 .await
470 }
471
472 pub async fn create_signup(&self, signup: Signup) -> Result<()> {
473 self.transact(|mut tx| async {
474 sqlx::query(
475 "
476 INSERT INTO signups
477 (
478 email_address,
479 email_confirmation_code,
480 email_confirmation_sent,
481 platform_linux,
482 platform_mac,
483 platform_windows,
484 platform_unknown,
485 editor_features,
486 programming_languages,
487 device_id
488 )
489 VALUES
490 ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8)
491 RETURNING id
492 ",
493 )
494 .bind(&signup.email_address)
495 .bind(&random_email_confirmation_code())
496 .bind(&signup.platform_linux)
497 .bind(&signup.platform_mac)
498 .bind(&signup.platform_windows)
499 .bind(&signup.editor_features)
500 .bind(&signup.programming_languages)
501 .bind(&signup.device_id)
502 .execute(&mut tx)
503 .await?;
504 tx.commit().await?;
505 Ok(())
506 })
507 .await
508 }
509
510 pub async fn create_invite_from_code(
511 &self,
512 code: &str,
513 email_address: &str,
514 device_id: Option<&str>,
515 ) -> Result<Invite> {
516 self.transact(|mut tx| async {
517 let existing_user: Option<UserId> = sqlx::query_scalar(
518 "
519 SELECT id
520 FROM users
521 WHERE email_address = $1
522 ",
523 )
524 .bind(email_address)
525 .fetch_optional(&mut tx)
526 .await?;
527 if existing_user.is_some() {
528 Err(anyhow!("email address is already in use"))?;
529 }
530
531 let row: Option<(UserId, i32)> = sqlx::query_as(
532 "
533 SELECT id, invite_count
534 FROM users
535 WHERE invite_code = $1
536 ",
537 )
538 .bind(code)
539 .fetch_optional(&mut tx)
540 .await?;
541
542 let (inviter_id, invite_count) = match row {
543 Some(row) => row,
544 None => Err(Error::Http(
545 StatusCode::NOT_FOUND,
546 "invite code not found".to_string(),
547 ))?,
548 };
549
550 if invite_count == 0 {
551 Err(Error::Http(
552 StatusCode::UNAUTHORIZED,
553 "no invites remaining".to_string(),
554 ))?;
555 }
556
557 let email_confirmation_code: String = sqlx::query_scalar(
558 "
559 INSERT INTO signups
560 (
561 email_address,
562 email_confirmation_code,
563 email_confirmation_sent,
564 inviting_user_id,
565 platform_linux,
566 platform_mac,
567 platform_windows,
568 platform_unknown,
569 device_id
570 )
571 VALUES
572 ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4)
573 ON CONFLICT (email_address)
574 DO UPDATE SET
575 inviting_user_id = excluded.inviting_user_id
576 RETURNING email_confirmation_code
577 ",
578 )
579 .bind(&email_address)
580 .bind(&random_email_confirmation_code())
581 .bind(&inviter_id)
582 .bind(&device_id)
583 .fetch_one(&mut tx)
584 .await?;
585
586 tx.commit().await?;
587
588 Ok(Invite {
589 email_address: email_address.into(),
590 email_confirmation_code,
591 })
592 })
593 .await
594 }
595
596 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
597 self.transact(|mut tx| async {
598 let emails = invites
599 .iter()
600 .map(|s| s.email_address.as_str())
601 .collect::<Vec<_>>();
602 sqlx::query(
603 "
604 UPDATE signups
605 SET email_confirmation_sent = TRUE
606 WHERE email_address = ANY ($1)
607 ",
608 )
609 .bind(&emails)
610 .execute(&mut tx)
611 .await?;
612 tx.commit().await?;
613 Ok(())
614 })
615 .await
616 }
617}
618
619impl<D> Db<D>
620where
621 Self: BeginTransaction<Database = D> + BuildQuery,
622 D: sqlx::Database + sqlx::migrate::MigrateDatabase,
623 D::Connection: sqlx::migrate::Migrate,
624 for<'a> <D as sqlx::database::HasArguments<'a>>::Arguments: sqlx::IntoArguments<'a, D>,
625 for<'a> sea_query_binder::SqlxValues: sqlx::IntoArguments<'a, D>,
626 for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>,
627 for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>,
628 D::QueryResult: RowsAffected,
629 String: sqlx::Type<D>,
630 i32: sqlx::Type<D>,
631 i64: sqlx::Type<D>,
632 bool: sqlx::Type<D>,
633 str: sqlx::Type<D>,
634 Uuid: sqlx::Type<D>,
635 sqlx::types::Json<serde_json::Value>: sqlx::Type<D>,
636 OffsetDateTime: sqlx::Type<D>,
637 PrimitiveDateTime: sqlx::Type<D>,
638 usize: sqlx::ColumnIndex<D::Row>,
639 for<'a> &'a str: sqlx::ColumnIndex<D::Row>,
640 for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
641 for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
642 for<'a> Option<String>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
643 for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
644 for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
645 for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
646 for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
647 for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
648 for<'a> Option<ProjectId>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
649 for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
650 for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>,
651 for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>,
652{
653 pub async fn migrate(
654 &self,
655 migrations_path: &Path,
656 ignore_checksum_mismatch: bool,
657 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
658 let migrations = MigrationSource::resolve(migrations_path)
659 .await
660 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
661
662 let mut conn = self.pool.acquire().await?;
663
664 conn.ensure_migrations_table().await?;
665 let applied_migrations: HashMap<_, _> = conn
666 .list_applied_migrations()
667 .await?
668 .into_iter()
669 .map(|m| (m.version, m))
670 .collect();
671
672 let mut new_migrations = Vec::new();
673 for migration in migrations {
674 match applied_migrations.get(&migration.version) {
675 Some(applied_migration) => {
676 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
677 {
678 Err(anyhow!(
679 "checksum mismatch for applied migration {}",
680 migration.description
681 ))?;
682 }
683 }
684 None => {
685 let elapsed = conn.apply(&migration).await?;
686 new_migrations.push((migration, elapsed));
687 }
688 }
689 }
690
691 Ok(new_migrations)
692 }
693
694 pub fn fuzzy_like_string(string: &str) -> String {
695 let mut result = String::with_capacity(string.len() * 2 + 1);
696 for c in string.chars() {
697 if c.is_alphanumeric() {
698 result.push('%');
699 result.push(c);
700 }
701 }
702 result.push('%');
703 result
704 }
705
706 // users
707
708 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
709 self.transact(|tx| async {
710 let mut tx = tx;
711 let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2";
712 Ok(sqlx::query_as(query)
713 .bind(limit as i32)
714 .bind((page * limit) as i32)
715 .fetch_all(&mut tx)
716 .await?)
717 })
718 .await
719 }
720
721 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
722 self.transact(|tx| async {
723 let mut tx = tx;
724 let query = "
725 SELECT users.*
726 FROM users
727 WHERE id = $1
728 LIMIT 1
729 ";
730 Ok(sqlx::query_as(query)
731 .bind(&id)
732 .fetch_optional(&mut tx)
733 .await?)
734 })
735 .await
736 }
737
738 pub async fn get_users_with_no_invites(
739 &self,
740 invited_by_another_user: bool,
741 ) -> Result<Vec<User>> {
742 self.transact(|tx| async {
743 let mut tx = tx;
744 let query = format!(
745 "
746 SELECT users.*
747 FROM users
748 WHERE invite_count = 0
749 AND inviter_id IS{} NULL
750 ",
751 if invited_by_another_user { " NOT" } else { "" }
752 );
753
754 Ok(sqlx::query_as(&query).fetch_all(&mut tx).await?)
755 })
756 .await
757 }
758
759 pub async fn get_user_by_github_account(
760 &self,
761 github_login: &str,
762 github_user_id: Option<i32>,
763 ) -> Result<Option<User>> {
764 self.transact(|tx| async {
765 let mut tx = tx;
766 if let Some(github_user_id) = github_user_id {
767 let mut user = sqlx::query_as::<_, User>(
768 "
769 UPDATE users
770 SET github_login = $1
771 WHERE github_user_id = $2
772 RETURNING *
773 ",
774 )
775 .bind(github_login)
776 .bind(github_user_id)
777 .fetch_optional(&mut tx)
778 .await?;
779
780 if user.is_none() {
781 user = sqlx::query_as::<_, User>(
782 "
783 UPDATE users
784 SET github_user_id = $1
785 WHERE github_login = $2
786 RETURNING *
787 ",
788 )
789 .bind(github_user_id)
790 .bind(github_login)
791 .fetch_optional(&mut tx)
792 .await?;
793 }
794
795 Ok(user)
796 } else {
797 let user = sqlx::query_as(
798 "
799 SELECT * FROM users
800 WHERE github_login = $1
801 LIMIT 1
802 ",
803 )
804 .bind(github_login)
805 .fetch_optional(&mut tx)
806 .await?;
807 Ok(user)
808 }
809 })
810 .await
811 }
812
813 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
814 self.transact(|mut tx| async {
815 let query = "UPDATE users SET admin = $1 WHERE id = $2";
816 sqlx::query(query)
817 .bind(is_admin)
818 .bind(id.0)
819 .execute(&mut tx)
820 .await?;
821 tx.commit().await?;
822 Ok(())
823 })
824 .await
825 }
826
827 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
828 self.transact(|mut tx| async move {
829 let query = "UPDATE users SET connected_once = $1 WHERE id = $2";
830 sqlx::query(query)
831 .bind(connected_once)
832 .bind(id.0)
833 .execute(&mut tx)
834 .await?;
835 tx.commit().await?;
836 Ok(())
837 })
838 .await
839 }
840
841 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
842 self.transact(|mut tx| async move {
843 let query = "DELETE FROM access_tokens WHERE user_id = $1;";
844 sqlx::query(query)
845 .bind(id.0)
846 .execute(&mut tx)
847 .await
848 .map(drop)?;
849 let query = "DELETE FROM users WHERE id = $1;";
850 sqlx::query(query).bind(id.0).execute(&mut tx).await?;
851 tx.commit().await?;
852 Ok(())
853 })
854 .await
855 }
856
857 // signups
858
859 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
860 self.transact(|mut tx| async move {
861 Ok(sqlx::query_as(
862 "
863 SELECT
864 COUNT(*) as count,
865 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
866 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
867 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
868 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
869 FROM (
870 SELECT *
871 FROM signups
872 WHERE
873 NOT email_confirmation_sent
874 ) AS unsent
875 ",
876 )
877 .fetch_one(&mut tx)
878 .await?)
879 })
880 .await
881 }
882
883 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
884 self.transact(|mut tx| async move {
885 Ok(sqlx::query_as(
886 "
887 SELECT
888 email_address, email_confirmation_code
889 FROM signups
890 WHERE
891 NOT email_confirmation_sent AND
892 (platform_mac OR platform_unknown)
893 LIMIT $1
894 ",
895 )
896 .bind(count as i32)
897 .fetch_all(&mut tx)
898 .await?)
899 })
900 .await
901 }
902
903 // invite codes
904
905 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
906 self.transact(|mut tx| async move {
907 if count > 0 {
908 sqlx::query(
909 "
910 UPDATE users
911 SET invite_code = $1
912 WHERE id = $2 AND invite_code IS NULL
913 ",
914 )
915 .bind(random_invite_code())
916 .bind(id)
917 .execute(&mut tx)
918 .await?;
919 }
920
921 sqlx::query(
922 "
923 UPDATE users
924 SET invite_count = $1
925 WHERE id = $2
926 ",
927 )
928 .bind(count as i32)
929 .bind(id)
930 .execute(&mut tx)
931 .await?;
932 tx.commit().await?;
933 Ok(())
934 })
935 .await
936 }
937
938 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
939 self.transact(|mut tx| async move {
940 let result: Option<(String, i32)> = sqlx::query_as(
941 "
942 SELECT invite_code, invite_count
943 FROM users
944 WHERE id = $1 AND invite_code IS NOT NULL
945 ",
946 )
947 .bind(id)
948 .fetch_optional(&mut tx)
949 .await?;
950 if let Some((code, count)) = result {
951 Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?)))
952 } else {
953 Ok(None)
954 }
955 })
956 .await
957 }
958
959 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
960 self.transact(|tx| async {
961 let mut tx = tx;
962 sqlx::query_as(
963 "
964 SELECT *
965 FROM users
966 WHERE invite_code = $1
967 ",
968 )
969 .bind(code)
970 .fetch_optional(&mut tx)
971 .await?
972 .ok_or_else(|| {
973 Error::Http(
974 StatusCode::NOT_FOUND,
975 "that invite code does not exist".to_string(),
976 )
977 })
978 })
979 .await
980 }
981
    /// Commits `tx` while holding the per-room lock for `room_id` and returns `data` wrapped in
    /// a [`RoomGuard`] that keeps that lock held until the guard is dropped.
    ///
    /// The lock is acquired *before* the commit. If the commit fails, the `?` returns early and
    /// the lock guard is dropped (released) along with it.
    async fn commit_room_transaction<'a, T>(
        &'a self,
        room_id: RoomId,
        tx: sqlx::Transaction<'static, D>,
        data: T,
    ) -> Result<RoomGuard<T>> {
        // Get (or lazily create) the room's mutex; cloning the Arc means no DashMap entry
        // reference is held across the `.await` below.
        let lock = self.rooms.entry(room_id).or_default().clone();
        let _guard = lock.lock_owned().await;
        tx.commit().await?;
        Ok(RoomGuard {
            data,
            _guard,
            _not_send: PhantomData,
        })
    }
997
998 pub async fn create_room(
999 &self,
1000 user_id: UserId,
1001 connection_id: ConnectionId,
1002 live_kit_room: &str,
1003 ) -> Result<RoomGuard<proto::Room>> {
1004 self.transact(|mut tx| async move {
1005 let room_id = sqlx::query_scalar(
1006 "
1007 INSERT INTO rooms (live_kit_room)
1008 VALUES ($1)
1009 RETURNING id
1010 ",
1011 )
1012 .bind(&live_kit_room)
1013 .fetch_one(&mut tx)
1014 .await
1015 .map(RoomId)?;
1016
1017 sqlx::query(
1018 "
1019 INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id)
1020 VALUES ($1, $2, $3, $4, $5)
1021 ",
1022 )
1023 .bind(room_id)
1024 .bind(user_id)
1025 .bind(connection_id.0 as i32)
1026 .bind(user_id)
1027 .bind(connection_id.0 as i32)
1028 .execute(&mut tx)
1029 .await?;
1030
1031 let room = self.get_room(room_id, &mut tx).await?;
1032 self.commit_room_transaction(room_id, tx, room).await
1033 }).await
1034 }
1035
1036 pub async fn call(
1037 &self,
1038 room_id: RoomId,
1039 calling_user_id: UserId,
1040 calling_connection_id: ConnectionId,
1041 called_user_id: UserId,
1042 initial_project_id: Option<ProjectId>,
1043 ) -> Result<RoomGuard<(proto::Room, proto::IncomingCall)>> {
1044 self.transact(|mut tx| async move {
1045 sqlx::query(
1046 "
1047 INSERT INTO room_participants (
1048 room_id,
1049 user_id,
1050 calling_user_id,
1051 calling_connection_id,
1052 initial_project_id
1053 )
1054 VALUES ($1, $2, $3, $4, $5)
1055 ",
1056 )
1057 .bind(room_id)
1058 .bind(called_user_id)
1059 .bind(calling_user_id)
1060 .bind(calling_connection_id.0 as i32)
1061 .bind(initial_project_id)
1062 .execute(&mut tx)
1063 .await?;
1064
1065 let room = self.get_room(room_id, &mut tx).await?;
1066 let incoming_call = Self::build_incoming_call(&room, called_user_id)
1067 .ok_or_else(|| anyhow!("failed to build incoming call"))?;
1068 self.commit_room_transaction(room_id, tx, (room, incoming_call))
1069 .await
1070 })
1071 .await
1072 }
1073
1074 pub async fn incoming_call_for_user(
1075 &self,
1076 user_id: UserId,
1077 ) -> Result<Option<proto::IncomingCall>> {
1078 self.transact(|mut tx| async move {
1079 let room_id = sqlx::query_scalar::<_, RoomId>(
1080 "
1081 SELECT room_id
1082 FROM room_participants
1083 WHERE user_id = $1 AND answering_connection_id IS NULL
1084 ",
1085 )
1086 .bind(user_id)
1087 .fetch_optional(&mut tx)
1088 .await?;
1089
1090 if let Some(room_id) = room_id {
1091 let room = self.get_room(room_id, &mut tx).await?;
1092 Ok(Self::build_incoming_call(&room, user_id))
1093 } else {
1094 Ok(None)
1095 }
1096 })
1097 .await
1098 }
1099
1100 fn build_incoming_call(
1101 room: &proto::Room,
1102 called_user_id: UserId,
1103 ) -> Option<proto::IncomingCall> {
1104 let pending_participant = room
1105 .pending_participants
1106 .iter()
1107 .find(|participant| participant.user_id == called_user_id.to_proto())?;
1108
1109 Some(proto::IncomingCall {
1110 room_id: room.id,
1111 calling_user_id: pending_participant.calling_user_id,
1112 participant_user_ids: room
1113 .participants
1114 .iter()
1115 .map(|participant| participant.user_id)
1116 .collect(),
1117 initial_project: room.participants.iter().find_map(|participant| {
1118 let initial_project_id = pending_participant.initial_project_id?;
1119 participant
1120 .projects
1121 .iter()
1122 .find(|project| project.id == initial_project_id)
1123 .cloned()
1124 }),
1125 })
1126 }
1127
1128 pub async fn call_failed(
1129 &self,
1130 room_id: RoomId,
1131 called_user_id: UserId,
1132 ) -> Result<RoomGuard<proto::Room>> {
1133 self.transact(|mut tx| async move {
1134 sqlx::query(
1135 "
1136 DELETE FROM room_participants
1137 WHERE room_id = $1 AND user_id = $2
1138 ",
1139 )
1140 .bind(room_id)
1141 .bind(called_user_id)
1142 .execute(&mut tx)
1143 .await?;
1144
1145 let room = self.get_room(room_id, &mut tx).await?;
1146 self.commit_room_transaction(room_id, tx, room).await
1147 })
1148 .await
1149 }
1150
1151 pub async fn decline_call(
1152 &self,
1153 expected_room_id: Option<RoomId>,
1154 user_id: UserId,
1155 ) -> Result<RoomGuard<proto::Room>> {
1156 self.transact(|mut tx| async move {
1157 let room_id = sqlx::query_scalar(
1158 "
1159 DELETE FROM room_participants
1160 WHERE user_id = $1 AND answering_connection_id IS NULL
1161 RETURNING room_id
1162 ",
1163 )
1164 .bind(user_id)
1165 .fetch_one(&mut tx)
1166 .await?;
1167 if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1168 return Err(anyhow!("declining call on unexpected room"))?;
1169 }
1170
1171 let room = self.get_room(room_id, &mut tx).await?;
1172 self.commit_room_transaction(room_id, tx, room).await
1173 })
1174 .await
1175 }
1176
1177 pub async fn cancel_call(
1178 &self,
1179 expected_room_id: Option<RoomId>,
1180 calling_connection_id: ConnectionId,
1181 called_user_id: UserId,
1182 ) -> Result<RoomGuard<proto::Room>> {
1183 self.transact(|mut tx| async move {
1184 let room_id = sqlx::query_scalar(
1185 "
1186 DELETE FROM room_participants
1187 WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL
1188 RETURNING room_id
1189 ",
1190 )
1191 .bind(called_user_id)
1192 .bind(calling_connection_id.0 as i32)
1193 .fetch_one(&mut tx)
1194 .await?;
1195 if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1196 return Err(anyhow!("canceling call on unexpected room"))?;
1197 }
1198
1199 let room = self.get_room(room_id, &mut tx).await?;
1200 self.commit_room_transaction(room_id, tx, room).await
1201 }).await
1202 }
1203
1204 pub async fn join_room(
1205 &self,
1206 room_id: RoomId,
1207 user_id: UserId,
1208 connection_id: ConnectionId,
1209 ) -> Result<RoomGuard<proto::Room>> {
1210 self.transact(|mut tx| async move {
1211 sqlx::query(
1212 "
1213 UPDATE room_participants
1214 SET answering_connection_id = $1
1215 WHERE room_id = $2 AND user_id = $3
1216 RETURNING 1
1217 ",
1218 )
1219 .bind(connection_id.0 as i32)
1220 .bind(room_id)
1221 .bind(user_id)
1222 .fetch_one(&mut tx)
1223 .await?;
1224
1225 let room = self.get_room(room_id, &mut tx).await?;
1226 self.commit_room_transaction(room_id, tx, room).await
1227 })
1228 .await
1229 }
1230
1231 pub async fn leave_room(
1232 &self,
1233 connection_id: ConnectionId,
1234 ) -> Result<Option<RoomGuard<LeftRoom>>> {
1235 self.transact(|mut tx| async move {
1236 // Leave room.
1237 let room_id = sqlx::query_scalar::<_, RoomId>(
1238 "
1239 DELETE FROM room_participants
1240 WHERE answering_connection_id = $1
1241 RETURNING room_id
1242 ",
1243 )
1244 .bind(connection_id.0 as i32)
1245 .fetch_optional(&mut tx)
1246 .await?;
1247
1248 if let Some(room_id) = room_id {
1249 // Cancel pending calls initiated by the leaving user.
1250 let canceled_calls_to_user_ids: Vec<UserId> = sqlx::query_scalar(
1251 "
1252 DELETE FROM room_participants
1253 WHERE calling_connection_id = $1 AND answering_connection_id IS NULL
1254 RETURNING user_id
1255 ",
1256 )
1257 .bind(connection_id.0 as i32)
1258 .fetch_all(&mut tx)
1259 .await?;
1260
1261 let project_ids = sqlx::query_scalar::<_, ProjectId>(
1262 "
1263 SELECT project_id
1264 FROM project_collaborators
1265 WHERE connection_id = $1
1266 ",
1267 )
1268 .bind(connection_id.0 as i32)
1269 .fetch_all(&mut tx)
1270 .await?;
1271
1272 // Leave projects.
1273 let mut left_projects = HashMap::default();
1274 if !project_ids.is_empty() {
1275 let mut params = "?,".repeat(project_ids.len());
1276 params.pop();
1277 let query = format!(
1278 "
1279 SELECT *
1280 FROM project_collaborators
1281 WHERE project_id IN ({params})
1282 "
1283 );
1284 let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query);
1285 for project_id in project_ids {
1286 query = query.bind(project_id);
1287 }
1288
1289 let mut project_collaborators = query.fetch(&mut tx);
1290 while let Some(collaborator) = project_collaborators.next().await {
1291 let collaborator = collaborator?;
1292 let left_project =
1293 left_projects
1294 .entry(collaborator.project_id)
1295 .or_insert(LeftProject {
1296 id: collaborator.project_id,
1297 host_user_id: Default::default(),
1298 connection_ids: Default::default(),
1299 host_connection_id: Default::default(),
1300 });
1301
1302 let collaborator_connection_id =
1303 ConnectionId(collaborator.connection_id as u32);
1304 if collaborator_connection_id != connection_id {
1305 left_project.connection_ids.push(collaborator_connection_id);
1306 }
1307
1308 if collaborator.is_host {
1309 left_project.host_user_id = collaborator.user_id;
1310 left_project.host_connection_id =
1311 ConnectionId(collaborator.connection_id as u32);
1312 }
1313 }
1314 }
1315 sqlx::query(
1316 "
1317 DELETE FROM project_collaborators
1318 WHERE connection_id = $1
1319 ",
1320 )
1321 .bind(connection_id.0 as i32)
1322 .execute(&mut tx)
1323 .await?;
1324
1325 // Unshare projects.
1326 sqlx::query(
1327 "
1328 DELETE FROM projects
1329 WHERE room_id = $1 AND host_connection_id = $2
1330 ",
1331 )
1332 .bind(room_id)
1333 .bind(connection_id.0 as i32)
1334 .execute(&mut tx)
1335 .await?;
1336
1337 let room = self.get_room(room_id, &mut tx).await?;
1338 Ok(Some(
1339 self.commit_room_transaction(
1340 room_id,
1341 tx,
1342 LeftRoom {
1343 room,
1344 left_projects,
1345 canceled_calls_to_user_ids,
1346 },
1347 )
1348 .await?,
1349 ))
1350 } else {
1351 Ok(None)
1352 }
1353 })
1354 .await
1355 }
1356
1357 pub async fn update_room_participant_location(
1358 &self,
1359 room_id: RoomId,
1360 connection_id: ConnectionId,
1361 location: proto::ParticipantLocation,
1362 ) -> Result<RoomGuard<proto::Room>> {
1363 self.transact(|tx| async {
1364 let mut tx = tx;
1365 let location_kind;
1366 let location_project_id;
1367 match location
1368 .variant
1369 .as_ref()
1370 .ok_or_else(|| anyhow!("invalid location"))?
1371 {
1372 proto::participant_location::Variant::SharedProject(project) => {
1373 location_kind = 0;
1374 location_project_id = Some(ProjectId::from_proto(project.id));
1375 }
1376 proto::participant_location::Variant::UnsharedProject(_) => {
1377 location_kind = 1;
1378 location_project_id = None;
1379 }
1380 proto::participant_location::Variant::External(_) => {
1381 location_kind = 2;
1382 location_project_id = None;
1383 }
1384 }
1385
1386 sqlx::query(
1387 "
1388 UPDATE room_participants
1389 SET location_kind = $1, location_project_id = $2
1390 WHERE room_id = $3 AND answering_connection_id = $4
1391 RETURNING 1
1392 ",
1393 )
1394 .bind(location_kind)
1395 .bind(location_project_id)
1396 .bind(room_id)
1397 .bind(connection_id.0 as i32)
1398 .fetch_one(&mut tx)
1399 .await?;
1400
1401 let room = self.get_room(room_id, &mut tx).await?;
1402 self.commit_room_transaction(room_id, tx, room).await
1403 })
1404 .await
1405 }
1406
1407 async fn get_guest_connection_ids(
1408 &self,
1409 project_id: ProjectId,
1410 tx: &mut sqlx::Transaction<'_, D>,
1411 ) -> Result<Vec<ConnectionId>> {
1412 let mut guest_connection_ids = Vec::new();
1413 let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>(
1414 "
1415 SELECT connection_id
1416 FROM project_collaborators
1417 WHERE project_id = $1 AND is_host = FALSE
1418 ",
1419 )
1420 .bind(project_id)
1421 .fetch(tx);
1422 while let Some(connection_id) = db_guest_connection_ids.next().await {
1423 guest_connection_ids.push(ConnectionId(connection_id? as u32));
1424 }
1425 Ok(guest_connection_ids)
1426 }
1427
    /// Loads the protobuf representation of room `room_id` using the caller's `tx`.
    ///
    /// - Participant rows with an `answering_connection_id` become `proto::Participant`s,
    ///   keyed by (and given a `peer_id` equal to) that connection id. Rows without one
    ///   are reported as `pending_participants`.
    /// - `location_kind` is decoded as: `0` with a project id → shared project,
    ///   `1` → unshared project, anything else (including NULL) → external.
    /// - Each project in the room is attached to the participant whose connection id
    ///   equals the project's `host_connection_id`, along with its worktree root names.
    ///   Projects whose host is not an answering participant are skipped.
    ///
    /// Fails if the `rooms` row does not exist. The participant order comes from
    /// `HashMap` iteration and is therefore unspecified.
    async fn get_room(
        &self,
        room_id: RoomId,
        tx: &mut sqlx::Transaction<'_, D>,
    ) -> Result<proto::Room> {
        let room: Room = sqlx::query_as(
            "
            SELECT *
            FROM rooms
            WHERE id = $1
            ",
        )
        .bind(room_id)
        .fetch_one(&mut *tx)
        .await?;

        let mut db_participants =
            sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<ProjectId>, UserId, Option<ProjectId>)>(
                "
                SELECT user_id, answering_connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
                FROM room_participants
                WHERE room_id = $1
                ",
            )
            .bind(room_id)
            .fetch(&mut *tx);

        // Answered participants, keyed by their answering connection id.
        let mut participants = HashMap::default();
        // Participants that have been called but have not answered yet.
        let mut pending_participants = Vec::new();
        while let Some(participant) = db_participants.next().await {
            let (
                user_id,
                answering_connection_id,
                location_kind,
                location_project_id,
                calling_user_id,
                initial_project_id,
            ) = participant?;
            if let Some(answering_connection_id) = answering_connection_id {
                let location = match (location_kind, location_project_id) {
                    (Some(0), Some(project_id)) => {
                        Some(proto::participant_location::Variant::SharedProject(
                            proto::participant_location::SharedProject {
                                id: project_id.to_proto(),
                            },
                        ))
                    }
                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
                        Default::default(),
                    )),
                    // Unknown/NULL kinds and a shared kind without a project id.
                    _ => Some(proto::participant_location::Variant::External(
                        Default::default(),
                    )),
                };
                participants.insert(
                    answering_connection_id,
                    proto::Participant {
                        user_id: user_id.to_proto(),
                        peer_id: answering_connection_id as u32,
                        projects: Default::default(),
                        location: Some(proto::ParticipantLocation { variant: location }),
                    },
                );
            } else {
                pending_participants.push(proto::PendingParticipant {
                    user_id: user_id.to_proto(),
                    calling_user_id: calling_user_id.to_proto(),
                    initial_project_id: initial_project_id.map(|id| id.to_proto()),
                });
            }
        }
        // Release the stream's borrow of `tx` before issuing the next query.
        drop(db_participants);

        // LEFT JOIN keeps projects that have no worktrees (their root name is NULL).
        let mut rows = sqlx::query_as::<_, (i32, ProjectId, Option<String>)>(
            "
            SELECT host_connection_id, projects.id, worktrees.root_name
            FROM projects
            LEFT JOIN worktrees ON projects.id = worktrees.project_id
            WHERE room_id = $1
            ",
        )
        .bind(room_id)
        .fetch(&mut *tx);

        while let Some(row) = rows.next().await {
            let (connection_id, project_id, worktree_root_name) = row?;
            if let Some(participant) = participants.get_mut(&connection_id) {
                // One row arrives per worktree, so reuse the project entry if present.
                let project = if let Some(project) = participant
                    .projects
                    .iter_mut()
                    .find(|project| project.id == project_id.to_proto())
                {
                    project
                } else {
                    participant.projects.push(proto::ParticipantProject {
                        id: project_id.to_proto(),
                        worktree_root_names: Default::default(),
                    });
                    participant.projects.last_mut().unwrap()
                };
                // `Option` is iterable: a NULL root name adds nothing.
                project.worktree_root_names.extend(worktree_root_name);
            }
        }

        Ok(proto::Room {
            id: room.id.to_proto(),
            live_kit_room: room.live_kit_room,
            participants: participants.into_values().collect(),
            pending_participants,
        })
    }
1539
1540 // projects
1541
1542 pub async fn project_count_excluding_admins(&self) -> Result<usize> {
1543 self.transact(|mut tx| async move {
1544 Ok(sqlx::query_scalar::<_, i32>(
1545 "
1546 SELECT COUNT(*)
1547 FROM projects, users
1548 WHERE projects.host_user_id = users.id AND users.admin IS FALSE
1549 ",
1550 )
1551 .fetch_one(&mut tx)
1552 .await? as usize)
1553 })
1554 .await
1555 }
1556
1557 pub async fn share_project(
1558 &self,
1559 expected_room_id: RoomId,
1560 connection_id: ConnectionId,
1561 worktrees: &[proto::WorktreeMetadata],
1562 ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
1563 self.transact(|mut tx| async move {
1564 let (sql, values) = self.build_query(
1565 Query::select()
1566 .columns([
1567 schema::room_participant::Definition::RoomId,
1568 schema::room_participant::Definition::UserId,
1569 ])
1570 .from(schema::room_participant::Definition::Table)
1571 .and_where(
1572 Expr::col(schema::room_participant::Definition::AnsweringConnectionId)
1573 .eq(connection_id.0),
1574 ),
1575 );
1576 let (room_id, user_id) = sqlx::query_as_with::<_, (RoomId, UserId), _>(&sql, values)
1577 .fetch_one(&mut tx)
1578 .await?;
1579 if room_id != expected_room_id {
1580 return Err(anyhow!("shared project on unexpected room"))?;
1581 }
1582
1583 let (sql, values) = self.build_query(
1584 Query::insert()
1585 .into_table(schema::project::Definition::Table)
1586 .columns([
1587 schema::project::Definition::RoomId,
1588 schema::project::Definition::HostUserId,
1589 schema::project::Definition::HostConnectionId,
1590 ])
1591 .values_panic([room_id.into(), user_id.into(), connection_id.0.into()])
1592 .returning_col(schema::project::Definition::Id),
1593 );
1594 let project_id: ProjectId = sqlx::query_scalar_with(&sql, values)
1595 .fetch_one(&mut tx)
1596 .await?;
1597
1598 if !worktrees.is_empty() {
1599 let mut query = Query::insert()
1600 .into_table(schema::worktree::Definition::Table)
1601 .columns([
1602 schema::worktree::Definition::ProjectId,
1603 schema::worktree::Definition::Id,
1604 schema::worktree::Definition::RootName,
1605 schema::worktree::Definition::AbsPath,
1606 schema::worktree::Definition::Visible,
1607 schema::worktree::Definition::ScanId,
1608 schema::worktree::Definition::IsComplete,
1609 ])
1610 .to_owned();
1611 for worktree in worktrees {
1612 query.values_panic([
1613 project_id.into(),
1614 worktree.id.into(),
1615 worktree.root_name.clone().into(),
1616 worktree.abs_path.clone().into(),
1617 worktree.visible.into(),
1618 0.into(),
1619 false.into(),
1620 ]);
1621 }
1622 let (sql, values) = self.build_query(&query);
1623 sqlx::query_with(&sql, values).execute(&mut tx).await?;
1624 }
1625
1626 sqlx::query(
1627 "
1628 INSERT INTO project_collaborators (
1629 project_id,
1630 connection_id,
1631 user_id,
1632 replica_id,
1633 is_host
1634 )
1635 VALUES ($1, $2, $3, $4, $5)
1636 ",
1637 )
1638 .bind(project_id)
1639 .bind(connection_id.0 as i32)
1640 .bind(user_id)
1641 .bind(0)
1642 .bind(true)
1643 .execute(&mut tx)
1644 .await?;
1645
1646 let room = self.get_room(room_id, &mut tx).await?;
1647 self.commit_room_transaction(room_id, tx, (project_id, room))
1648 .await
1649 })
1650 .await
1651 }
1652
1653 pub async fn unshare_project(
1654 &self,
1655 project_id: ProjectId,
1656 connection_id: ConnectionId,
1657 ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
1658 self.transact(|mut tx| async move {
1659 let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1660 let room_id: RoomId = sqlx::query_scalar(
1661 "
1662 DELETE FROM projects
1663 WHERE id = $1 AND host_connection_id = $2
1664 RETURNING room_id
1665 ",
1666 )
1667 .bind(project_id)
1668 .bind(connection_id.0 as i32)
1669 .fetch_one(&mut tx)
1670 .await?;
1671 let room = self.get_room(room_id, &mut tx).await?;
1672 self.commit_room_transaction(room_id, tx, (room, guest_connection_ids))
1673 .await
1674 })
1675 .await
1676 }
1677
1678 pub async fn update_project(
1679 &self,
1680 project_id: ProjectId,
1681 connection_id: ConnectionId,
1682 worktrees: &[proto::WorktreeMetadata],
1683 ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
1684 self.transact(|mut tx| async move {
1685 let room_id: RoomId = sqlx::query_scalar(
1686 "
1687 SELECT room_id
1688 FROM projects
1689 WHERE id = $1 AND host_connection_id = $2
1690 ",
1691 )
1692 .bind(project_id)
1693 .bind(connection_id.0 as i32)
1694 .fetch_one(&mut tx)
1695 .await?;
1696
1697 if !worktrees.is_empty() {
1698 let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len());
1699 params.pop();
1700 let query = format!(
1701 "
1702 INSERT INTO worktrees (
1703 project_id,
1704 id,
1705 root_name,
1706 abs_path,
1707 visible,
1708 scan_id,
1709 is_complete
1710 )
1711 VALUES {params}
1712 ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name
1713 "
1714 );
1715
1716 let mut query = sqlx::query(&query);
1717 for worktree in worktrees {
1718 query = query
1719 .bind(project_id)
1720 .bind(worktree.id as i32)
1721 .bind(&worktree.root_name)
1722 .bind(&worktree.abs_path)
1723 .bind(worktree.visible)
1724 .bind(0)
1725 .bind(false)
1726 }
1727 query.execute(&mut tx).await?;
1728 }
1729
1730 let mut params = "?,".repeat(worktrees.len());
1731 if !worktrees.is_empty() {
1732 params.pop();
1733 }
1734 let query = format!(
1735 "
1736 DELETE FROM worktrees
1737 WHERE project_id = ? AND id NOT IN ({params})
1738 ",
1739 );
1740
1741 let mut query = sqlx::query(&query).bind(project_id);
1742 for worktree in worktrees {
1743 query = query.bind(WorktreeId(worktree.id as i32));
1744 }
1745 query.execute(&mut tx).await?;
1746
1747 let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1748 let room = self.get_room(room_id, &mut tx).await?;
1749 self.commit_room_transaction(room_id, tx, (room, guest_connection_ids))
1750 .await
1751 })
1752 .await
1753 }
1754
1755 pub async fn update_worktree(
1756 &self,
1757 update: &proto::UpdateWorktree,
1758 connection_id: ConnectionId,
1759 ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1760 self.transact(|mut tx| async move {
1761 let project_id = ProjectId::from_proto(update.project_id);
1762 let worktree_id = WorktreeId::from_proto(update.worktree_id);
1763
1764 // Ensure the update comes from the host.
1765 let room_id: RoomId = sqlx::query_scalar(
1766 "
1767 SELECT room_id
1768 FROM projects
1769 WHERE id = $1 AND host_connection_id = $2
1770 ",
1771 )
1772 .bind(project_id)
1773 .bind(connection_id.0 as i32)
1774 .fetch_one(&mut tx)
1775 .await?;
1776
1777 // Update metadata.
1778 sqlx::query(
1779 "
1780 UPDATE worktrees
1781 SET
1782 root_name = $1,
1783 scan_id = $2,
1784 is_complete = $3,
1785 abs_path = $4
1786 WHERE project_id = $5 AND id = $6
1787 RETURNING 1
1788 ",
1789 )
1790 .bind(&update.root_name)
1791 .bind(update.scan_id as i64)
1792 .bind(update.is_last_update)
1793 .bind(&update.abs_path)
1794 .bind(project_id)
1795 .bind(worktree_id)
1796 .fetch_one(&mut tx)
1797 .await?;
1798
1799 if !update.updated_entries.is_empty() {
1800 let mut params =
1801 "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),".repeat(update.updated_entries.len());
1802 params.pop();
1803
1804 let query = format!(
1805 "
1806 INSERT INTO worktree_entries (
1807 project_id,
1808 worktree_id,
1809 id,
1810 is_dir,
1811 path,
1812 inode,
1813 mtime_seconds,
1814 mtime_nanos,
1815 is_symlink,
1816 is_ignored
1817 )
1818 VALUES {params}
1819 ON CONFLICT (project_id, worktree_id, id) DO UPDATE SET
1820 is_dir = excluded.is_dir,
1821 path = excluded.path,
1822 inode = excluded.inode,
1823 mtime_seconds = excluded.mtime_seconds,
1824 mtime_nanos = excluded.mtime_nanos,
1825 is_symlink = excluded.is_symlink,
1826 is_ignored = excluded.is_ignored
1827 "
1828 );
1829 let mut query = sqlx::query(&query);
1830 for entry in &update.updated_entries {
1831 let mtime = entry.mtime.clone().unwrap_or_default();
1832 query = query
1833 .bind(project_id)
1834 .bind(worktree_id)
1835 .bind(entry.id as i64)
1836 .bind(entry.is_dir)
1837 .bind(&entry.path)
1838 .bind(entry.inode as i64)
1839 .bind(mtime.seconds as i64)
1840 .bind(mtime.nanos as i32)
1841 .bind(entry.is_symlink)
1842 .bind(entry.is_ignored);
1843 }
1844 query.execute(&mut tx).await?;
1845 }
1846
1847 if !update.removed_entries.is_empty() {
1848 let mut params = "?,".repeat(update.removed_entries.len());
1849 params.pop();
1850 let query = format!(
1851 "
1852 DELETE FROM worktree_entries
1853 WHERE project_id = ? AND worktree_id = ? AND id IN ({params})
1854 "
1855 );
1856
1857 let mut query = sqlx::query(&query).bind(project_id).bind(worktree_id);
1858 for entry_id in &update.removed_entries {
1859 query = query.bind(*entry_id as i64);
1860 }
1861 query.execute(&mut tx).await?;
1862 }
1863
1864 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1865 self.commit_room_transaction(room_id, tx, connection_ids)
1866 .await
1867 })
1868 .await
1869 }
1870
1871 pub async fn update_diagnostic_summary(
1872 &self,
1873 update: &proto::UpdateDiagnosticSummary,
1874 connection_id: ConnectionId,
1875 ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1876 self.transact(|mut tx| async {
1877 let project_id = ProjectId::from_proto(update.project_id);
1878 let worktree_id = WorktreeId::from_proto(update.worktree_id);
1879 let summary = update
1880 .summary
1881 .as_ref()
1882 .ok_or_else(|| anyhow!("invalid summary"))?;
1883
1884 // Ensure the update comes from the host.
1885 let room_id: RoomId = sqlx::query_scalar(
1886 "
1887 SELECT room_id
1888 FROM projects
1889 WHERE id = $1 AND host_connection_id = $2
1890 ",
1891 )
1892 .bind(project_id)
1893 .bind(connection_id.0 as i32)
1894 .fetch_one(&mut tx)
1895 .await?;
1896
1897 // Update summary.
1898 sqlx::query(
1899 "
1900 INSERT INTO worktree_diagnostic_summaries (
1901 project_id,
1902 worktree_id,
1903 path,
1904 language_server_id,
1905 error_count,
1906 warning_count
1907 )
1908 VALUES ($1, $2, $3, $4, $5, $6)
1909 ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET
1910 language_server_id = excluded.language_server_id,
1911 error_count = excluded.error_count,
1912 warning_count = excluded.warning_count
1913 ",
1914 )
1915 .bind(project_id)
1916 .bind(worktree_id)
1917 .bind(&summary.path)
1918 .bind(summary.language_server_id as i64)
1919 .bind(summary.error_count as i32)
1920 .bind(summary.warning_count as i32)
1921 .execute(&mut tx)
1922 .await?;
1923
1924 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1925 self.commit_room_transaction(room_id, tx, connection_ids)
1926 .await
1927 })
1928 .await
1929 }
1930
1931 pub async fn start_language_server(
1932 &self,
1933 update: &proto::StartLanguageServer,
1934 connection_id: ConnectionId,
1935 ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1936 self.transact(|mut tx| async {
1937 let project_id = ProjectId::from_proto(update.project_id);
1938 let server = update
1939 .server
1940 .as_ref()
1941 .ok_or_else(|| anyhow!("invalid language server"))?;
1942
1943 // Ensure the update comes from the host.
1944 let room_id: RoomId = sqlx::query_scalar(
1945 "
1946 SELECT room_id
1947 FROM projects
1948 WHERE id = $1 AND host_connection_id = $2
1949 ",
1950 )
1951 .bind(project_id)
1952 .bind(connection_id.0 as i32)
1953 .fetch_one(&mut tx)
1954 .await?;
1955
1956 // Add the newly-started language server.
1957 sqlx::query(
1958 "
1959 INSERT INTO language_servers (project_id, id, name)
1960 VALUES ($1, $2, $3)
1961 ON CONFLICT (project_id, id) DO UPDATE SET
1962 name = excluded.name
1963 ",
1964 )
1965 .bind(project_id)
1966 .bind(server.id as i64)
1967 .bind(&server.name)
1968 .execute(&mut tx)
1969 .await?;
1970
1971 let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?;
1972 self.commit_room_transaction(room_id, tx, connection_ids)
1973 .await
1974 })
1975 .await
1976 }
1977
    /// Adds `connection_id` as a guest collaborator on `project_id`.
    ///
    /// Returns the project's full state (collaborators including the new guest,
    /// worktrees with entries and diagnostic summaries, language servers) together with
    /// the replica id assigned to the guest.
    ///
    /// Fails if the connection is not an answering participant of any room, or if the
    /// project was not shared in that room. The guest receives the smallest replica id,
    /// starting at 1, that no existing collaborator uses. Entry and summary rows whose
    /// worktree id is not among the project's worktrees are ignored.
    pub async fn join_project(
        &self,
        project_id: ProjectId,
        connection_id: ConnectionId,
    ) -> Result<RoomGuard<(Project, ReplicaId)>> {
        self.transact(|mut tx| async move {
            let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>(
                "
                SELECT room_id, user_id
                FROM room_participants
                WHERE answering_connection_id = $1
                ",
            )
            .bind(connection_id.0 as i32)
            .fetch_one(&mut tx)
            .await?;

            // Ensure project id was shared on this room.
            sqlx::query(
                "
                SELECT 1
                FROM projects
                WHERE id = $1 AND room_id = $2
                ",
            )
            .bind(project_id)
            .bind(room_id)
            .fetch_one(&mut tx)
            .await?;

            let mut collaborators = sqlx::query_as::<_, ProjectCollaborator>(
                "
                SELECT *
                FROM project_collaborators
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;
            // Pick the lowest free replica id, starting at 1.
            let replica_ids = collaborators
                .iter()
                .map(|c| c.replica_id)
                .collect::<HashSet<_>>();
            let mut replica_id = ReplicaId(1);
            while replica_ids.contains(&replica_id) {
                replica_id.0 += 1;
            }
            let new_collaborator = ProjectCollaborator {
                project_id,
                connection_id: connection_id.0 as i32,
                user_id,
                replica_id,
                is_host: false,
            };

            sqlx::query(
                "
                INSERT INTO project_collaborators (
                    project_id,
                    connection_id,
                    user_id,
                    replica_id,
                    is_host
                )
                VALUES ($1, $2, $3, $4, $5)
                ",
            )
            .bind(new_collaborator.project_id)
            .bind(new_collaborator.connection_id)
            .bind(new_collaborator.user_id)
            .bind(new_collaborator.replica_id)
            .bind(new_collaborator.is_host)
            .execute(&mut tx)
            .await?;
            collaborators.push(new_collaborator);

            let worktree_rows = sqlx::query_as::<_, WorktreeRow>(
                "
                SELECT *
                FROM worktrees
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;
            // Worktrees keyed by id so entries and summaries can be attached below.
            let mut worktrees = worktree_rows
                .into_iter()
                .map(|worktree_row| {
                    (
                        worktree_row.id,
                        Worktree {
                            id: worktree_row.id,
                            abs_path: worktree_row.abs_path,
                            root_name: worktree_row.root_name,
                            visible: worktree_row.visible,
                            entries: Default::default(),
                            diagnostic_summaries: Default::default(),
                            scan_id: worktree_row.scan_id as u64,
                            is_complete: worktree_row.is_complete,
                        },
                    )
                })
                .collect::<BTreeMap<_, _>>();

            // Populate worktree entries. The block scope ends the stream's borrow of `tx`.
            {
                let mut entries = sqlx::query_as::<_, WorktreeEntry>(
                    "
                    SELECT *
                    FROM worktree_entries
                    WHERE project_id = $1
                    ",
                )
                .bind(project_id)
                .fetch(&mut tx);
                while let Some(entry) = entries.next().await {
                    let entry = entry?;
                    if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) {
                        worktree.entries.push(proto::Entry {
                            id: entry.id as u64,
                            is_dir: entry.is_dir,
                            path: entry.path,
                            inode: entry.inode as u64,
                            mtime: Some(proto::Timestamp {
                                seconds: entry.mtime_seconds as u64,
                                nanos: entry.mtime_nanos as u32,
                            }),
                            is_symlink: entry.is_symlink,
                            is_ignored: entry.is_ignored,
                        });
                    }
                }
            }

            // Populate worktree diagnostic summaries.
            {
                let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>(
                    "
                    SELECT *
                    FROM worktree_diagnostic_summaries
                    WHERE project_id = $1
                    ",
                )
                .bind(project_id)
                .fetch(&mut tx);
                while let Some(summary) = summaries.next().await {
                    let summary = summary?;
                    if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) {
                        worktree
                            .diagnostic_summaries
                            .push(proto::DiagnosticSummary {
                                path: summary.path,
                                language_server_id: summary.language_server_id as u64,
                                error_count: summary.error_count as u32,
                                warning_count: summary.warning_count as u32,
                            });
                    }
                }
            }

            // Populate language servers.
            let language_servers = sqlx::query_as::<_, LanguageServer>(
                "
                SELECT *
                FROM language_servers
                WHERE project_id = $1
                ",
            )
            .bind(project_id)
            .fetch_all(&mut tx)
            .await?;

            self.commit_room_transaction(
                room_id,
                tx,
                (
                    Project {
                        collaborators,
                        worktrees,
                        language_servers: language_servers
                            .into_iter()
                            .map(|language_server| proto::LanguageServer {
                                id: language_server.id.to_proto(),
                                name: language_server.name,
                            })
                            .collect(),
                    },
                    replica_id as ReplicaId,
                ),
            )
            .await
        })
        .await
    }
2174
2175 pub async fn leave_project(
2176 &self,
2177 project_id: ProjectId,
2178 connection_id: ConnectionId,
2179 ) -> Result<RoomGuard<LeftProject>> {
2180 self.transact(|mut tx| async move {
2181 let result = sqlx::query(
2182 "
2183 DELETE FROM project_collaborators
2184 WHERE project_id = $1 AND connection_id = $2
2185 ",
2186 )
2187 .bind(project_id)
2188 .bind(connection_id.0 as i32)
2189 .execute(&mut tx)
2190 .await?;
2191
2192 if result.rows_affected() == 0 {
2193 Err(anyhow!("not a collaborator on this project"))?;
2194 }
2195
2196 let connection_ids = sqlx::query_scalar::<_, i32>(
2197 "
2198 SELECT connection_id
2199 FROM project_collaborators
2200 WHERE project_id = $1
2201 ",
2202 )
2203 .bind(project_id)
2204 .fetch_all(&mut tx)
2205 .await?
2206 .into_iter()
2207 .map(|id| ConnectionId(id as u32))
2208 .collect();
2209
2210 let (room_id, host_user_id, host_connection_id) =
2211 sqlx::query_as::<_, (RoomId, i32, i32)>(
2212 "
2213 SELECT room_id, host_user_id, host_connection_id
2214 FROM projects
2215 WHERE id = $1
2216 ",
2217 )
2218 .bind(project_id)
2219 .fetch_one(&mut tx)
2220 .await?;
2221
2222 self.commit_room_transaction(
2223 room_id,
2224 tx,
2225 LeftProject {
2226 id: project_id,
2227 host_user_id: UserId(host_user_id),
2228 host_connection_id: ConnectionId(host_connection_id as u32),
2229 connection_ids,
2230 },
2231 )
2232 .await
2233 })
2234 .await
2235 }
2236
2237 pub async fn project_collaborators(
2238 &self,
2239 project_id: ProjectId,
2240 connection_id: ConnectionId,
2241 ) -> Result<Vec<ProjectCollaborator>> {
2242 self.transact(|mut tx| async move {
2243 let collaborators = sqlx::query_as::<_, ProjectCollaborator>(
2244 "
2245 SELECT *
2246 FROM project_collaborators
2247 WHERE project_id = $1
2248 ",
2249 )
2250 .bind(project_id)
2251 .fetch_all(&mut tx)
2252 .await?;
2253
2254 if collaborators
2255 .iter()
2256 .any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
2257 {
2258 Ok(collaborators)
2259 } else {
2260 Err(anyhow!("no such project"))?
2261 }
2262 })
2263 .await
2264 }
2265
2266 pub async fn project_connection_ids(
2267 &self,
2268 project_id: ProjectId,
2269 connection_id: ConnectionId,
2270 ) -> Result<HashSet<ConnectionId>> {
2271 self.transact(|mut tx| async move {
2272 let connection_ids = sqlx::query_scalar::<_, i32>(
2273 "
2274 SELECT connection_id
2275 FROM project_collaborators
2276 WHERE project_id = $1
2277 ",
2278 )
2279 .bind(project_id)
2280 .fetch_all(&mut tx)
2281 .await?;
2282
2283 if connection_ids.contains(&(connection_id.0 as i32)) {
2284 Ok(connection_ids
2285 .into_iter()
2286 .map(|connection_id| ConnectionId(connection_id as u32))
2287 .collect())
2288 } else {
2289 Err(anyhow!("no such project"))?
2290 }
2291 })
2292 .await
2293 }
2294
2295 // contacts
2296
2297 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
2298 self.transact(|mut tx| async move {
2299 let query = "
2300 SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify, (room_participants.id IS NOT NULL) as busy
2301 FROM contacts
2302 LEFT JOIN room_participants ON room_participants.user_id = $1
2303 WHERE user_id_a = $1 OR user_id_b = $1;
2304 ";
2305
2306 let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool, bool)>(query)
2307 .bind(user_id)
2308 .fetch(&mut tx);
2309
2310 let mut contacts = Vec::new();
2311 while let Some(row) = rows.next().await {
2312 let (user_id_a, user_id_b, a_to_b, accepted, should_notify, busy) = row?;
2313 if user_id_a == user_id {
2314 if accepted {
2315 contacts.push(Contact::Accepted {
2316 user_id: user_id_b,
2317 should_notify: should_notify && a_to_b,
2318 busy
2319 });
2320 } else if a_to_b {
2321 contacts.push(Contact::Outgoing { user_id: user_id_b })
2322 } else {
2323 contacts.push(Contact::Incoming {
2324 user_id: user_id_b,
2325 should_notify,
2326 });
2327 }
2328 } else if accepted {
2329 contacts.push(Contact::Accepted {
2330 user_id: user_id_a,
2331 should_notify: should_notify && !a_to_b,
2332 busy
2333 });
2334 } else if a_to_b {
2335 contacts.push(Contact::Incoming {
2336 user_id: user_id_a,
2337 should_notify,
2338 });
2339 } else {
2340 contacts.push(Contact::Outgoing { user_id: user_id_a });
2341 }
2342 }
2343
2344 contacts.sort_unstable_by_key(|contact| contact.user_id());
2345
2346 Ok(contacts)
2347 })
2348 .await
2349 }
2350
2351 pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
2352 self.transact(|mut tx| async move {
2353 Ok(sqlx::query_scalar::<_, i32>(
2354 "
2355 SELECT 1
2356 FROM room_participants
2357 WHERE room_participants.user_id = $1
2358 ",
2359 )
2360 .bind(user_id)
2361 .fetch_optional(&mut tx)
2362 .await?
2363 .is_some())
2364 })
2365 .await
2366 }
2367
2368 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
2369 self.transact(|mut tx| async move {
2370 let (id_a, id_b) = if user_id_1 < user_id_2 {
2371 (user_id_1, user_id_2)
2372 } else {
2373 (user_id_2, user_id_1)
2374 };
2375
2376 let query = "
2377 SELECT 1 FROM contacts
2378 WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE
2379 LIMIT 1
2380 ";
2381 Ok(sqlx::query_scalar::<_, i32>(query)
2382 .bind(id_a.0)
2383 .bind(id_b.0)
2384 .fetch_optional(&mut tx)
2385 .await?
2386 .is_some())
2387 })
2388 .await
2389 }
2390
2391 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
2392 self.transact(|mut tx| async move {
2393 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
2394 (sender_id, receiver_id, true)
2395 } else {
2396 (receiver_id, sender_id, false)
2397 };
2398 let query = "
2399 INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
2400 VALUES ($1, $2, $3, FALSE, TRUE)
2401 ON CONFLICT (user_id_a, user_id_b) DO UPDATE
2402 SET
2403 accepted = TRUE,
2404 should_notify = FALSE
2405 WHERE
2406 NOT contacts.accepted AND
2407 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
2408 (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
2409 ";
2410 let result = sqlx::query(query)
2411 .bind(id_a.0)
2412 .bind(id_b.0)
2413 .bind(a_to_b)
2414 .execute(&mut tx)
2415 .await?;
2416
2417 if result.rows_affected() == 1 {
2418 tx.commit().await?;
2419 Ok(())
2420 } else {
2421 Err(anyhow!("contact already requested"))?
2422 }
2423 }).await
2424 }
2425
2426 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
2427 self.transact(|mut tx| async move {
2428 let (id_a, id_b) = if responder_id < requester_id {
2429 (responder_id, requester_id)
2430 } else {
2431 (requester_id, responder_id)
2432 };
2433 let query = "
2434 DELETE FROM contacts
2435 WHERE user_id_a = $1 AND user_id_b = $2;
2436 ";
2437 let result = sqlx::query(query)
2438 .bind(id_a.0)
2439 .bind(id_b.0)
2440 .execute(&mut tx)
2441 .await?;
2442
2443 if result.rows_affected() == 1 {
2444 tx.commit().await?;
2445 Ok(())
2446 } else {
2447 Err(anyhow!("no such contact"))?
2448 }
2449 })
2450 .await
2451 }
2452
2453 pub async fn dismiss_contact_notification(
2454 &self,
2455 user_id: UserId,
2456 contact_user_id: UserId,
2457 ) -> Result<()> {
2458 self.transact(|mut tx| async move {
2459 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
2460 (user_id, contact_user_id, true)
2461 } else {
2462 (contact_user_id, user_id, false)
2463 };
2464
2465 let query = "
2466 UPDATE contacts
2467 SET should_notify = FALSE
2468 WHERE
2469 user_id_a = $1 AND user_id_b = $2 AND
2470 (
2471 (a_to_b = $3 AND accepted) OR
2472 (a_to_b != $3 AND NOT accepted)
2473 );
2474 ";
2475
2476 let result = sqlx::query(query)
2477 .bind(id_a.0)
2478 .bind(id_b.0)
2479 .bind(a_to_b)
2480 .execute(&mut tx)
2481 .await?;
2482
2483 if result.rows_affected() == 0 {
2484 Err(anyhow!("no such contact request"))?
2485 } else {
2486 tx.commit().await?;
2487 Ok(())
2488 }
2489 })
2490 .await
2491 }
2492
2493 pub async fn respond_to_contact_request(
2494 &self,
2495 responder_id: UserId,
2496 requester_id: UserId,
2497 accept: bool,
2498 ) -> Result<()> {
2499 self.transact(|mut tx| async move {
2500 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
2501 (responder_id, requester_id, false)
2502 } else {
2503 (requester_id, responder_id, true)
2504 };
2505 let result = if accept {
2506 let query = "
2507 UPDATE contacts
2508 SET accepted = TRUE, should_notify = TRUE
2509 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
2510 ";
2511 sqlx::query(query)
2512 .bind(id_a.0)
2513 .bind(id_b.0)
2514 .bind(a_to_b)
2515 .execute(&mut tx)
2516 .await?
2517 } else {
2518 let query = "
2519 DELETE FROM contacts
2520 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
2521 ";
2522 sqlx::query(query)
2523 .bind(id_a.0)
2524 .bind(id_b.0)
2525 .bind(a_to_b)
2526 .execute(&mut tx)
2527 .await?
2528 };
2529 if result.rows_affected() == 1 {
2530 tx.commit().await?;
2531 Ok(())
2532 } else {
2533 Err(anyhow!("no such contact request"))?
2534 }
2535 })
2536 .await
2537 }
2538
2539 // access tokens
2540
2541 pub async fn create_access_token_hash(
2542 &self,
2543 user_id: UserId,
2544 access_token_hash: &str,
2545 max_access_token_count: usize,
2546 ) -> Result<()> {
2547 self.transact(|tx| async {
2548 let mut tx = tx;
2549 let insert_query = "
2550 INSERT INTO access_tokens (user_id, hash)
2551 VALUES ($1, $2);
2552 ";
2553 let cleanup_query = "
2554 DELETE FROM access_tokens
2555 WHERE id IN (
2556 SELECT id from access_tokens
2557 WHERE user_id = $1
2558 ORDER BY id DESC
2559 LIMIT 10000
2560 OFFSET $3
2561 )
2562 ";
2563
2564 sqlx::query(insert_query)
2565 .bind(user_id.0)
2566 .bind(access_token_hash)
2567 .execute(&mut tx)
2568 .await?;
2569 sqlx::query(cleanup_query)
2570 .bind(user_id.0)
2571 .bind(access_token_hash)
2572 .bind(max_access_token_count as i32)
2573 .execute(&mut tx)
2574 .await?;
2575 Ok(tx.commit().await?)
2576 })
2577 .await
2578 }
2579
2580 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
2581 self.transact(|mut tx| async move {
2582 let query = "
2583 SELECT hash
2584 FROM access_tokens
2585 WHERE user_id = $1
2586 ORDER BY id DESC
2587 ";
2588 Ok(sqlx::query_scalar(query)
2589 .bind(user_id.0)
2590 .fetch_all(&mut tx)
2591 .await?)
2592 })
2593 .await
2594 }
2595
2596 async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
2597 where
2598 F: Send + Fn(sqlx::Transaction<'static, D>) -> Fut,
2599 Fut: Send + Future<Output = Result<T>>,
2600 {
2601 let body = async {
2602 loop {
2603 let tx = self.begin_transaction().await?;
2604 match f(tx).await {
2605 Ok(result) => return Ok(result),
2606 Err(error) => match error {
2607 Error::Database(error)
2608 if error
2609 .as_database_error()
2610 .and_then(|error| error.code())
2611 .as_deref()
2612 == Some("40001") =>
2613 {
2614 // Retry (don't break the loop)
2615 }
2616 error @ _ => return Err(error),
2617 },
2618 }
2619 }
2620 };
2621
2622 #[cfg(test)]
2623 {
2624 if let Some(background) = self.background.as_ref() {
2625 background.simulate_random_delay().await;
2626 }
2627
2628 self.runtime.as_ref().unwrap().block_on(body)
2629 }
2630
2631 #[cfg(not(test))]
2632 {
2633 body.await
2634 }
2635 }
2636}
2637
/// Declares a `Copy` newtype around `i32` for a database id, named `$name`.
///
/// The generated type:
/// - is stored by `sqlx` and (de)serialized by `serde` transparently as its
///   inner `i32`;
/// - converts to and from `u64` for protobuf messages;
/// - displays as its inner integer;
/// - converts into `sea_query::Value::Int`.
macro_rules! id_type {
    ($name:ident) => {
        /// Integer database identifier, represented as its inner `i32`.
        #[derive(
            Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash,
            sqlx::Type, Serialize, Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// The largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from a protobuf `u64`.
            ///
            /// NOTE(review): this is an `as` cast, so values above `i32::MAX`
            /// wrap instead of being rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                let raw = value as i32;
                Self(raw)
            }

            /// Converts the id to a protobuf `u64` via an `as` cast.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                let Self(raw) = self;
                raw as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                std::fmt::Display::fmt(&self.0, f)
            }
        }

        impl From<$name> for sea_query::Value {
            fn from(id: $name) -> Self {
                Self::Int(Some(id.0))
            }
        }
    };
}
2686
id_type!(UserId);
/// A user account record, loadable directly from a query row (`FromRow`).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct User {
    pub id: UserId,
    /// The user's GitHub login.
    pub github_login: String,
    /// Numeric GitHub user id, when one is recorded.
    pub github_user_id: Option<i32>,
    pub email_address: Option<String>,
    pub admin: bool,
    /// The user's invite code, if one has been assigned.
    pub invite_code: Option<String>,
    // NOTE(review): presumably the number of invites this user may still send — confirm.
    pub invite_count: i32,
    // NOTE(review): presumably set once the user has connected at least once — confirm.
    pub connected_once: bool,
}
2699
id_type!(RoomId);
/// A room record, loadable directly from a query row (`FromRow`).
#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)]
pub struct Room {
    pub id: RoomId,
    // NOTE(review): presumably the name of the LiveKit room backing this room — confirm.
    pub live_kit_room: String,
}
2706
id_type!(ProjectId);
/// A project's collaborators, worktrees, and language servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    /// Worktrees keyed by their id (a `BTreeMap`, so iteration is in id order).
    pub worktrees: BTreeMap<WorktreeId, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}
2713
id_type!(ReplicaId);
/// One participant's membership in a project, loadable from a query row.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
pub struct ProjectCollaborator {
    pub project_id: ProjectId,
    /// Connection id stored as a plain `i32` (not as `rpc::ConnectionId`).
    pub connection_id: i32,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    /// Whether this collaborator is the project's host.
    pub is_host: bool,
}
2723
id_type!(WorktreeId);
/// Flat row shape of a worktree as read from the database (no entries or
/// diagnostics attached).
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeRow {
    pub id: WorktreeId,
    pub project_id: ProjectId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    /// Signed here, whereas `Worktree::scan_id` is a `u64`.
    pub scan_id: i64,
    pub is_complete: bool,
}
2735
/// A worktree together with its entries and diagnostic summaries, in protobuf
/// form.
pub struct Worktree {
    pub id: WorktreeId,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    /// Unsigned here, whereas the stored `WorktreeRow::scan_id` is an `i64`.
    pub scan_id: u64,
    pub is_complete: bool,
}
2746
/// A single file-system entry of a worktree, as read from a query row.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeEntry {
    id: i64,
    worktree_id: WorktreeId,
    is_dir: bool,
    path: String,
    inode: i64,
    // Modification time split into whole seconds and a nanosecond part.
    // NOTE(review): presumably relative to the Unix epoch — confirm.
    mtime_seconds: i64,
    mtime_nanos: i32,
    is_symlink: bool,
    is_ignored: bool,
}
2759
/// Error and warning counts for one path in a worktree, as reported by one
/// language server; read from a query row.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct WorktreeDiagnosticSummary {
    worktree_id: WorktreeId,
    path: String,
    language_server_id: i64,
    error_count: i32,
    warning_count: i32,
}
2768
id_type!(LanguageServerId);
/// A language server's id and name, as read from a query row.
#[derive(Clone, Debug, Default, FromRow, PartialEq)]
struct LanguageServer {
    id: LanguageServerId,
    name: String,
}
2775
/// Describes a project that a connection has left.
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    // NOTE(review): presumably the connections still in the project that should
    // be notified — confirm against the code that builds this value.
    pub connection_ids: Vec<ConnectionId>,
}
2782
/// Describes the outcome of leaving a room.
pub struct LeftRoom {
    /// The room's state, in protobuf form.
    pub room: proto::Room,
    /// Projects that were left as part of leaving the room, keyed by project id.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    /// Users whose pending calls were canceled as a result.
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
2788
/// A contact relationship with another user, in one of three states.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
    /// A mutually accepted contact.
    Accepted {
        user_id: UserId,
        should_notify: bool,
        // NOTE(review): presumably whether this user is currently in a call — confirm.
        busy: bool,
    },
    /// A request sent to `user_id` that has not been answered yet.
    Outgoing {
        user_id: UserId,
    },
    /// A request received from `user_id` that has not been answered yet.
    Incoming {
        user_id: UserId,
        should_notify: bool,
    },
}
2804
2805impl Contact {
2806 pub fn user_id(&self) -> UserId {
2807 match self {
2808 Contact::Accepted { user_id, .. } => *user_id,
2809 Contact::Outgoing { user_id } => *user_id,
2810 Contact::Incoming { user_id, .. } => *user_id,
2811 }
2812 }
2813}
2814
/// A pending contact request received from `requester_id`.
///
/// The derived `Ord` compares `requester_id` first, then `should_notify`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct IncomingContactRequest {
    pub requester_id: UserId,
    pub should_notify: bool,
}
2820
/// A waitlist signup submission, deserialized from an incoming payload.
#[derive(Clone, Deserialize)]
pub struct Signup {
    pub email_address: String,
    // Which platforms the signup selected.
    pub platform_mac: bool,
    pub platform_windows: bool,
    pub platform_linux: bool,
    pub editor_features: Vec<String>,
    pub programming_languages: Vec<String>,
    pub device_id: Option<String>,
}
2831
/// Aggregate waitlist counts, broken down by platform.
///
/// Every field is `#[sqlx(default)]`, so a column missing from the query row
/// yields `0` instead of a decode error.
#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)]
pub struct WaitlistSummary {
    #[sqlx(default)]
    pub count: i64,
    #[sqlx(default)]
    pub linux_count: i64,
    #[sqlx(default)]
    pub mac_count: i64,
    #[sqlx(default)]
    pub windows_count: i64,
    #[sqlx(default)]
    pub unknown_count: i64,
}
2845
/// An invite addressed to an email address, with its confirmation code.
#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)]
pub struct Invite {
    pub email_address: String,
    pub email_confirmation_code: String,
}
2851
/// Parameters for creating a new user account.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    // NOTE(review): presumably the initial invite count for the new user — confirm.
    pub invite_count: i32,
}
2858
/// Information returned after a new user account is created.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    /// The user whose invite led to this account, if any.
    pub inviting_user_id: Option<UserId>,
    /// The device id recorded with the user's signup, if any.
    pub signup_device_id: Option<String>,
}
2866
/// Generates a random 16-character invite code using `nanoid`.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}
2870
/// Generates a random 64-character email confirmation code using `nanoid`.
fn random_email_confirmation_code() -> String {
    nanoid::nanoid!(64)
}
2874
// In test builds, re-export the test database helpers (`SqliteTestDb`,
// `PostgresTestDb`) from this module.
#[cfg(test)]
pub use test::*;
2877
#[cfg(test)]
mod test {
    //! Test helpers that create throwaway databases: an in-memory SQLite one and
    //! a freshly created Postgres one. Each is driven by its own
    //! single-threaded Tokio runtime.

    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A migrated, in-memory SQLite database for tests.
    pub struct SqliteTestDb {
        /// The database under test; `Some` after `new`.
        pub db: Option<Arc<Db<sqlx::Sqlite>>>,
        /// A connection detached from the pool and owned by this value.
        // NOTE(review): presumably held so the in-memory database outlives pool
        // connections being closed — confirm.
        pub conn: sqlx::sqlite::SqliteConnection,
    }

    /// A migrated Postgres database created for a single test.
    pub struct PostgresTestDb {
        /// The database under test; an `Option` so `Drop` can `take()` it.
        pub db: Option<Arc<Db<sqlx::Postgres>>>,
        /// Connection URL of the test database, passed to `teardown` on drop.
        pub url: String,
    }

    impl SqliteTestDb {
        /// Creates a uniquely named in-memory SQLite database and runs the
        /// `migrations.sqlite` migrations against it.
        ///
        /// The runtime and `background` executor are stored on the `Db` so that
        /// its `transact` can use them in test builds.
        ///
        /// # Panics
        /// Panics if the runtime, connection, or migrations fail.
        pub fn new(background: Arc<Background>) -> Self {
            let mut rng = StdRng::from_entropy();
            // A random u128 suffix keeps concurrently running tests apart.
            let url = format!("file:zed-test-{}?mode=memory", rng.gen::<u128>());
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let (mut db, conn) = runtime.block_on(async {
                // NOTE(review): `5` is presumably the pool's max connection count — confirm.
                let db = Db::<sqlx::Sqlite>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
                db.migrate(migrations_path.as_ref(), false).await.unwrap();
                // Detach one connection from the pool so this value owns it.
                let conn = db.pool.acquire().await.unwrap().detach();
                (db, conn)
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                conn,
            }
        }

        /// Returns the database under test.
        ///
        /// # Panics
        /// Panics if `db` has been taken.
        pub fn db(&self) -> &Arc<Db<sqlx::Sqlite>> {
            self.db.as_ref().unwrap()
        }
    }

    impl PostgresTestDb {
        /// Creates a uniquely named Postgres database on `localhost` and runs the
        /// `migrations` directory against it.
        ///
        /// The runtime and `background` executor are stored on the `Db` so that
        /// its `transact` can use them in test builds.
        ///
        /// # Panics
        /// Panics if the database cannot be created or migrated.
        pub fn new(background: Arc<Background>) -> Self {
            // Process-wide lock, held for the rest of this function, so only one
            // test database is being created at a time.
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Db::<sqlx::Postgres>::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                url,
            }
        }

        /// Returns the database under test.
        ///
        /// # Panics
        /// Panics if `db` has been taken (it is only taken in `Drop`).
        pub fn db(&self) -> &Arc<Db<sqlx::Postgres>> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for PostgresTestDb {
        /// Takes the database out and calls `teardown` with the test URL.
        // NOTE(review): presumably `teardown` drops the test database — confirm.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            db.teardown(&self.url);
        }
    }
}