1mod access_token;
2mod contact;
3mod project;
4mod project_collaborator;
5mod room;
6mod room_participant;
7mod signup;
8#[cfg(test)]
9mod tests;
10mod user;
11mod worktree;
12
13use crate::{Error, Result};
14use anyhow::anyhow;
15use collections::HashMap;
16use dashmap::DashMap;
17use futures::StreamExt;
18use hyper::StatusCode;
19use rpc::{proto, ConnectionId};
20use sea_orm::{
21 entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
22 TransactionTrait,
23};
24use sea_orm::{
25 ActiveValue, ConnectionTrait, DatabaseBackend, FromQueryResult, IntoActiveModel, JoinType,
26 QueryOrder, QuerySelect, Statement,
27};
28use sea_query::{Alias, Expr, OnConflict, Query};
29use serde::{Deserialize, Serialize};
30use sqlx::migrate::{Migrate, Migration, MigrationSource};
31use sqlx::Connection;
32use std::ops::{Deref, DerefMut};
33use std::path::Path;
34use std::time::Duration;
35use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
36use tokio::sync::{Mutex, OwnedMutexGuard};
37
38pub use contact::Contact;
39pub use signup::Invite;
40pub use user::Model as User;
41
42pub struct Database {
43 options: ConnectOptions,
44 pool: DatabaseConnection,
45 rooms: DashMap<RoomId, Arc<Mutex<()>>>,
46 #[cfg(test)]
47 background: Option<std::sync::Arc<gpui::executor::Background>>,
48 #[cfg(test)]
49 runtime: Option<tokio::runtime::Runtime>,
50}
51
52impl Database {
53 pub async fn new(options: ConnectOptions) -> Result<Self> {
54 Ok(Self {
55 options: options.clone(),
56 pool: sea_orm::Database::connect(options).await?,
57 rooms: DashMap::with_capacity(16384),
58 #[cfg(test)]
59 background: None,
60 #[cfg(test)]
61 runtime: None,
62 })
63 }
64
65 pub async fn migrate(
66 &self,
67 migrations_path: &Path,
68 ignore_checksum_mismatch: bool,
69 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
70 let migrations = MigrationSource::resolve(migrations_path)
71 .await
72 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
73
74 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
75
76 connection.ensure_migrations_table().await?;
77 let applied_migrations: HashMap<_, _> = connection
78 .list_applied_migrations()
79 .await?
80 .into_iter()
81 .map(|m| (m.version, m))
82 .collect();
83
84 let mut new_migrations = Vec::new();
85 for migration in migrations {
86 match applied_migrations.get(&migration.version) {
87 Some(applied_migration) => {
88 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
89 {
90 Err(anyhow!(
91 "checksum mismatch for applied migration {}",
92 migration.description
93 ))?;
94 }
95 }
96 None => {
97 let elapsed = connection.apply(&migration).await?;
98 new_migrations.push((migration, elapsed));
99 }
100 }
101 }
102
103 Ok(new_migrations)
104 }
105
106 // users
107
108 pub async fn create_user(
109 &self,
110 email_address: &str,
111 admin: bool,
112 params: NewUserParams,
113 ) -> Result<NewUserResult> {
114 self.transact(|tx| async {
115 let user = user::Entity::insert(user::ActiveModel {
116 email_address: ActiveValue::set(Some(email_address.into())),
117 github_login: ActiveValue::set(params.github_login.clone()),
118 github_user_id: ActiveValue::set(Some(params.github_user_id)),
119 admin: ActiveValue::set(admin),
120 metrics_id: ActiveValue::set(Uuid::new_v4()),
121 ..Default::default()
122 })
123 .on_conflict(
124 OnConflict::column(user::Column::GithubLogin)
125 .update_column(user::Column::GithubLogin)
126 .to_owned(),
127 )
128 .exec_with_returning(&tx)
129 .await?;
130
131 tx.commit().await?;
132
133 Ok(NewUserResult {
134 user_id: user.id,
135 metrics_id: user.metrics_id.to_string(),
136 signup_device_id: None,
137 inviting_user_id: None,
138 })
139 })
140 .await
141 }
142
143 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
144 self.transact(|tx| async {
145 let tx = tx;
146 Ok(user::Entity::find()
147 .filter(user::Column::Id.is_in(ids.iter().copied()))
148 .all(&tx)
149 .await?)
150 })
151 .await
152 }
153
154 pub async fn get_user_by_github_account(
155 &self,
156 github_login: &str,
157 github_user_id: Option<i32>,
158 ) -> Result<Option<User>> {
159 self.transact(|tx| async {
160 let tx = tx;
161 if let Some(github_user_id) = github_user_id {
162 if let Some(user_by_github_user_id) = user::Entity::find()
163 .filter(user::Column::GithubUserId.eq(github_user_id))
164 .one(&tx)
165 .await?
166 {
167 let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
168 user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
169 Ok(Some(user_by_github_user_id.update(&tx).await?))
170 } else if let Some(user_by_github_login) = user::Entity::find()
171 .filter(user::Column::GithubLogin.eq(github_login))
172 .one(&tx)
173 .await?
174 {
175 let mut user_by_github_login = user_by_github_login.into_active_model();
176 user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
177 Ok(Some(user_by_github_login.update(&tx).await?))
178 } else {
179 Ok(None)
180 }
181 } else {
182 Ok(user::Entity::find()
183 .filter(user::Column::GithubLogin.eq(github_login))
184 .one(&tx)
185 .await?)
186 }
187 })
188 .await
189 }
190
191 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
192 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
193 enum QueryAs {
194 MetricsId,
195 }
196
197 self.transact(|tx| async move {
198 let metrics_id: Uuid = user::Entity::find_by_id(id)
199 .select_only()
200 .column(user::Column::MetricsId)
201 .into_values::<_, QueryAs>()
202 .one(&tx)
203 .await?
204 .ok_or_else(|| anyhow!("could not find user"))?;
205 Ok(metrics_id.to_string())
206 })
207 .await
208 }
209
210 // contacts
211
212 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
213 #[derive(Debug, FromQueryResult)]
214 struct ContactWithUserBusyStatuses {
215 user_id_a: UserId,
216 user_id_b: UserId,
217 a_to_b: bool,
218 accepted: bool,
219 should_notify: bool,
220 user_a_busy: bool,
221 user_b_busy: bool,
222 }
223
224 self.transact(|tx| async move {
225 let user_a_participant = Alias::new("user_a_participant");
226 let user_b_participant = Alias::new("user_b_participant");
227 let mut db_contacts = contact::Entity::find()
228 .column_as(
229 Expr::tbl(user_a_participant.clone(), room_participant::Column::Id)
230 .is_not_null(),
231 "user_a_busy",
232 )
233 .column_as(
234 Expr::tbl(user_b_participant.clone(), room_participant::Column::Id)
235 .is_not_null(),
236 "user_b_busy",
237 )
238 .filter(
239 contact::Column::UserIdA
240 .eq(user_id)
241 .or(contact::Column::UserIdB.eq(user_id)),
242 )
243 .join_as(
244 JoinType::LeftJoin,
245 contact::Relation::UserARoomParticipant.def(),
246 user_a_participant,
247 )
248 .join_as(
249 JoinType::LeftJoin,
250 contact::Relation::UserBRoomParticipant.def(),
251 user_b_participant,
252 )
253 .into_model::<ContactWithUserBusyStatuses>()
254 .stream(&tx)
255 .await?;
256
257 let mut contacts = Vec::new();
258 while let Some(db_contact) = db_contacts.next().await {
259 let db_contact = db_contact?;
260 if db_contact.user_id_a == user_id {
261 if db_contact.accepted {
262 contacts.push(Contact::Accepted {
263 user_id: db_contact.user_id_b,
264 should_notify: db_contact.should_notify && db_contact.a_to_b,
265 busy: db_contact.user_b_busy,
266 });
267 } else if db_contact.a_to_b {
268 contacts.push(Contact::Outgoing {
269 user_id: db_contact.user_id_b,
270 })
271 } else {
272 contacts.push(Contact::Incoming {
273 user_id: db_contact.user_id_b,
274 should_notify: db_contact.should_notify,
275 });
276 }
277 } else if db_contact.accepted {
278 contacts.push(Contact::Accepted {
279 user_id: db_contact.user_id_a,
280 should_notify: db_contact.should_notify && !db_contact.a_to_b,
281 busy: db_contact.user_a_busy,
282 });
283 } else if db_contact.a_to_b {
284 contacts.push(Contact::Incoming {
285 user_id: db_contact.user_id_a,
286 should_notify: db_contact.should_notify,
287 });
288 } else {
289 contacts.push(Contact::Outgoing {
290 user_id: db_contact.user_id_a,
291 });
292 }
293 }
294
295 contacts.sort_unstable_by_key(|contact| contact.user_id());
296
297 Ok(contacts)
298 })
299 .await
300 }
301
302 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
303 self.transact(|tx| async move {
304 let (id_a, id_b) = if user_id_1 < user_id_2 {
305 (user_id_1, user_id_2)
306 } else {
307 (user_id_2, user_id_1)
308 };
309
310 Ok(contact::Entity::find()
311 .filter(
312 contact::Column::UserIdA
313 .eq(id_a)
314 .and(contact::Column::UserIdB.eq(id_b))
315 .and(contact::Column::Accepted.eq(true)),
316 )
317 .one(&tx)
318 .await?
319 .is_some())
320 })
321 .await
322 }
323
324 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
325 self.transact(|mut tx| async move {
326 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
327 (sender_id, receiver_id, true)
328 } else {
329 (receiver_id, sender_id, false)
330 };
331
332 let rows_affected = contact::Entity::insert(contact::ActiveModel {
333 user_id_a: ActiveValue::set(id_a),
334 user_id_b: ActiveValue::set(id_b),
335 a_to_b: ActiveValue::set(a_to_b),
336 accepted: ActiveValue::set(false),
337 should_notify: ActiveValue::set(true),
338 ..Default::default()
339 })
340 .on_conflict(
341 OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB])
342 .values([
343 (contact::Column::Accepted, true.into()),
344 (contact::Column::ShouldNotify, false.into()),
345 ])
346 .action_and_where(
347 contact::Column::Accepted.eq(false).and(
348 contact::Column::AToB
349 .eq(a_to_b)
350 .and(contact::Column::UserIdA.eq(id_b))
351 .or(contact::Column::AToB
352 .ne(a_to_b)
353 .and(contact::Column::UserIdA.eq(id_a))),
354 ),
355 )
356 .to_owned(),
357 )
358 .exec_without_returning(&tx)
359 .await?;
360
361 if rows_affected == 1 {
362 tx.commit().await?;
363 Ok(())
364 } else {
365 Err(anyhow!("contact already requested"))?
366 }
367 })
368 .await
369 }
370
371 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
372 self.transact(|tx| async move {
373 let (id_a, id_b) = if responder_id < requester_id {
374 (responder_id, requester_id)
375 } else {
376 (requester_id, responder_id)
377 };
378
379 let result = contact::Entity::delete_many()
380 .filter(
381 contact::Column::UserIdA
382 .eq(id_a)
383 .and(contact::Column::UserIdB.eq(id_b)),
384 )
385 .exec(&tx)
386 .await?;
387
388 if result.rows_affected == 1 {
389 tx.commit().await?;
390 Ok(())
391 } else {
392 Err(anyhow!("no such contact"))?
393 }
394 })
395 .await
396 }
397
398 pub async fn dismiss_contact_notification(
399 &self,
400 user_id: UserId,
401 contact_user_id: UserId,
402 ) -> Result<()> {
403 self.transact(|tx| async move {
404 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
405 (user_id, contact_user_id, true)
406 } else {
407 (contact_user_id, user_id, false)
408 };
409
410 let result = contact::Entity::update_many()
411 .set(contact::ActiveModel {
412 should_notify: ActiveValue::set(false),
413 ..Default::default()
414 })
415 .filter(
416 contact::Column::UserIdA
417 .eq(id_a)
418 .and(contact::Column::UserIdB.eq(id_b))
419 .and(
420 contact::Column::AToB
421 .eq(a_to_b)
422 .and(contact::Column::Accepted.eq(true))
423 .or(contact::Column::AToB
424 .ne(a_to_b)
425 .and(contact::Column::Accepted.eq(false))),
426 ),
427 )
428 .exec(&tx)
429 .await?;
430 if result.rows_affected == 0 {
431 Err(anyhow!("no such contact request"))?
432 } else {
433 tx.commit().await?;
434 Ok(())
435 }
436 })
437 .await
438 }
439
440 pub async fn respond_to_contact_request(
441 &self,
442 responder_id: UserId,
443 requester_id: UserId,
444 accept: bool,
445 ) -> Result<()> {
446 self.transact(|tx| async move {
447 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
448 (responder_id, requester_id, false)
449 } else {
450 (requester_id, responder_id, true)
451 };
452 let rows_affected = if accept {
453 let result = contact::Entity::update_many()
454 .set(contact::ActiveModel {
455 accepted: ActiveValue::set(true),
456 should_notify: ActiveValue::set(true),
457 ..Default::default()
458 })
459 .filter(
460 contact::Column::UserIdA
461 .eq(id_a)
462 .and(contact::Column::UserIdB.eq(id_b))
463 .and(contact::Column::AToB.eq(a_to_b)),
464 )
465 .exec(&tx)
466 .await?;
467 result.rows_affected
468 } else {
469 let result = contact::Entity::delete_many()
470 .filter(
471 contact::Column::UserIdA
472 .eq(id_a)
473 .and(contact::Column::UserIdB.eq(id_b))
474 .and(contact::Column::AToB.eq(a_to_b))
475 .and(contact::Column::Accepted.eq(false)),
476 )
477 .exec(&tx)
478 .await?;
479
480 result.rows_affected
481 };
482
483 if rows_affected == 1 {
484 tx.commit().await?;
485 Ok(())
486 } else {
487 Err(anyhow!("no such contact request"))?
488 }
489 })
490 .await
491 }
492
493 pub fn fuzzy_like_string(string: &str) -> String {
494 let mut result = String::with_capacity(string.len() * 2 + 1);
495 for c in string.chars() {
496 if c.is_alphanumeric() {
497 result.push('%');
498 result.push(c);
499 }
500 }
501 result.push('%');
502 result
503 }
504
505 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
506 self.transact(|tx| async {
507 let tx = tx;
508 let like_string = Self::fuzzy_like_string(name_query);
509 let query = "
510 SELECT users.*
511 FROM users
512 WHERE github_login ILIKE $1
513 ORDER BY github_login <-> $2
514 LIMIT $3
515 ";
516
517 Ok(user::Entity::find()
518 .from_raw_sql(Statement::from_sql_and_values(
519 self.pool.get_database_backend(),
520 query.into(),
521 vec![like_string.into(), name_query.into(), limit.into()],
522 ))
523 .all(&tx)
524 .await?)
525 })
526 .await
527 }
528
529 // invite codes
530
531 pub async fn create_invite_from_code(
532 &self,
533 code: &str,
534 email_address: &str,
535 device_id: Option<&str>,
536 ) -> Result<Invite> {
537 self.transact(|tx| async move {
538 let existing_user = user::Entity::find()
539 .filter(user::Column::EmailAddress.eq(email_address))
540 .one(&tx)
541 .await?;
542
543 if existing_user.is_some() {
544 Err(anyhow!("email address is already in use"))?;
545 }
546
547 let inviter = match user::Entity::find()
548 .filter(user::Column::InviteCode.eq(code))
549 .one(&tx)
550 .await?
551 {
552 Some(inviter) => inviter,
553 None => {
554 return Err(Error::Http(
555 StatusCode::NOT_FOUND,
556 "invite code not found".to_string(),
557 ))?
558 }
559 };
560
561 if inviter.invite_count == 0 {
562 Err(Error::Http(
563 StatusCode::UNAUTHORIZED,
564 "no invites remaining".to_string(),
565 ))?;
566 }
567
568 let signup = signup::Entity::insert(signup::ActiveModel {
569 email_address: ActiveValue::set(email_address.into()),
570 email_confirmation_code: ActiveValue::set(random_email_confirmation_code()),
571 email_confirmation_sent: ActiveValue::set(false),
572 inviting_user_id: ActiveValue::set(Some(inviter.id)),
573 platform_linux: ActiveValue::set(false),
574 platform_mac: ActiveValue::set(false),
575 platform_windows: ActiveValue::set(false),
576 platform_unknown: ActiveValue::set(true),
577 device_id: ActiveValue::set(device_id.map(|device_id| device_id.into())),
578 ..Default::default()
579 })
580 .on_conflict(
581 OnConflict::column(signup::Column::EmailAddress)
582 .update_column(signup::Column::InvitingUserId)
583 .to_owned(),
584 )
585 .exec_with_returning(&tx)
586 .await?;
587 tx.commit().await?;
588
589 Ok(Invite {
590 email_address: signup.email_address,
591 email_confirmation_code: signup.email_confirmation_code,
592 })
593 })
594 .await
595 }
596
597 pub async fn create_user_from_invite(
598 &self,
599 invite: &Invite,
600 user: NewUserParams,
601 ) -> Result<Option<NewUserResult>> {
602 self.transact(|tx| async {
603 let tx = tx;
604 let signup = signup::Entity::find()
605 .filter(
606 signup::Column::EmailAddress
607 .eq(invite.email_address.as_str())
608 .and(
609 signup::Column::EmailConfirmationCode
610 .eq(invite.email_confirmation_code.as_str()),
611 ),
612 )
613 .one(&tx)
614 .await?
615 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
616
617 if signup.user_id.is_some() {
618 return Ok(None);
619 }
620
621 let user = user::Entity::insert(user::ActiveModel {
622 email_address: ActiveValue::set(Some(invite.email_address.clone())),
623 github_login: ActiveValue::set(user.github_login.clone()),
624 github_user_id: ActiveValue::set(Some(user.github_user_id)),
625 admin: ActiveValue::set(false),
626 invite_count: ActiveValue::set(user.invite_count),
627 invite_code: ActiveValue::set(Some(random_invite_code())),
628 metrics_id: ActiveValue::set(Uuid::new_v4()),
629 ..Default::default()
630 })
631 .on_conflict(
632 OnConflict::column(user::Column::GithubLogin)
633 .update_columns([
634 user::Column::EmailAddress,
635 user::Column::GithubUserId,
636 user::Column::Admin,
637 ])
638 .to_owned(),
639 )
640 .exec_with_returning(&tx)
641 .await?;
642
643 let mut signup = signup.into_active_model();
644 signup.user_id = ActiveValue::set(Some(user.id));
645 let signup = signup.update(&tx).await?;
646
647 if let Some(inviting_user_id) = signup.inviting_user_id {
648 let result = user::Entity::update_many()
649 .filter(
650 user::Column::Id
651 .eq(inviting_user_id)
652 .and(user::Column::InviteCount.gt(0)),
653 )
654 .col_expr(
655 user::Column::InviteCount,
656 Expr::col(user::Column::InviteCount).sub(1),
657 )
658 .exec(&tx)
659 .await?;
660
661 if result.rows_affected == 0 {
662 Err(Error::Http(
663 StatusCode::UNAUTHORIZED,
664 "no invites remaining".to_string(),
665 ))?;
666 }
667
668 contact::Entity::insert(contact::ActiveModel {
669 user_id_a: ActiveValue::set(inviting_user_id),
670 user_id_b: ActiveValue::set(user.id),
671 a_to_b: ActiveValue::set(true),
672 should_notify: ActiveValue::set(true),
673 accepted: ActiveValue::set(true),
674 ..Default::default()
675 })
676 .on_conflict(OnConflict::new().do_nothing().to_owned())
677 .exec_without_returning(&tx)
678 .await?;
679 }
680
681 tx.commit().await?;
682 Ok(Some(NewUserResult {
683 user_id: user.id,
684 metrics_id: user.metrics_id.to_string(),
685 inviting_user_id: signup.inviting_user_id,
686 signup_device_id: signup.device_id,
687 }))
688 })
689 .await
690 }
691
692 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
693 self.transact(|tx| async move {
694 if count > 0 {
695 user::Entity::update_many()
696 .filter(
697 user::Column::Id
698 .eq(id)
699 .and(user::Column::InviteCode.is_null()),
700 )
701 .col_expr(user::Column::InviteCode, random_invite_code().into())
702 .exec(&tx)
703 .await?;
704 }
705
706 user::Entity::update_many()
707 .filter(user::Column::Id.eq(id))
708 .col_expr(user::Column::InviteCount, count.into())
709 .exec(&tx)
710 .await?;
711 tx.commit().await?;
712 Ok(())
713 })
714 .await
715 }
716
717 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
718 self.transact(|tx| async move {
719 match user::Entity::find_by_id(id).one(&tx).await? {
720 Some(user) if user.invite_code.is_some() => {
721 Ok(Some((user.invite_code.unwrap(), user.invite_count as u32)))
722 }
723 _ => Ok(None),
724 }
725 })
726 .await
727 }
728
729 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
730 self.transact(|tx| async move {
731 user::Entity::find()
732 .filter(user::Column::InviteCode.eq(code))
733 .one(&tx)
734 .await?
735 .ok_or_else(|| {
736 Error::Http(
737 StatusCode::NOT_FOUND,
738 "that invite code does not exist".to_string(),
739 )
740 })
741 })
742 .await
743 }
744
745 // projects
746
747 pub async fn share_project(
748 &self,
749 room_id: RoomId,
750 connection_id: ConnectionId,
751 worktrees: &[proto::WorktreeMetadata],
752 ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
753 self.transact(|tx| async move {
754 let participant = room_participant::Entity::find()
755 .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
756 .one(&tx)
757 .await?
758 .ok_or_else(|| anyhow!("could not find participant"))?;
759 if participant.room_id != room_id {
760 return Err(anyhow!("shared project on unexpected room"))?;
761 }
762
763 let project = project::ActiveModel {
764 room_id: ActiveValue::set(participant.room_id),
765 host_user_id: ActiveValue::set(participant.user_id),
766 host_connection_id: ActiveValue::set(connection_id.0 as i32),
767 ..Default::default()
768 }
769 .insert(&tx)
770 .await?;
771
772 worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel {
773 id: ActiveValue::set(worktree.id as i32),
774 project_id: ActiveValue::set(project.id),
775 abs_path: ActiveValue::set(worktree.abs_path.clone()),
776 root_name: ActiveValue::set(worktree.root_name.clone()),
777 visible: ActiveValue::set(worktree.visible),
778 scan_id: ActiveValue::set(0),
779 is_complete: ActiveValue::set(false),
780 }))
781 .exec(&tx)
782 .await?;
783
784 project_collaborator::ActiveModel {
785 project_id: ActiveValue::set(project.id),
786 connection_id: ActiveValue::set(connection_id.0 as i32),
787 user_id: ActiveValue::set(participant.user_id),
788 replica_id: ActiveValue::set(0),
789 is_host: ActiveValue::set(true),
790 ..Default::default()
791 }
792 .insert(&tx)
793 .await?;
794
795 let room = self.get_room(room_id, &tx).await?;
796 self.commit_room_transaction(room_id, tx, (project.id, room))
797 .await
798 })
799 .await
800 }
801
802 async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
803 let db_room = room::Entity::find_by_id(room_id)
804 .one(tx)
805 .await?
806 .ok_or_else(|| anyhow!("could not find room"))?;
807
808 let mut db_participants = db_room
809 .find_related(room_participant::Entity)
810 .stream(tx)
811 .await?;
812 let mut participants = HashMap::default();
813 let mut pending_participants = Vec::new();
814 while let Some(db_participant) = db_participants.next().await {
815 let db_participant = db_participant?;
816 if let Some(answering_connection_id) = db_participant.answering_connection_id {
817 let location = match (
818 db_participant.location_kind,
819 db_participant.location_project_id,
820 ) {
821 (Some(0), Some(project_id)) => {
822 Some(proto::participant_location::Variant::SharedProject(
823 proto::participant_location::SharedProject {
824 id: project_id.to_proto(),
825 },
826 ))
827 }
828 (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
829 Default::default(),
830 )),
831 _ => Some(proto::participant_location::Variant::External(
832 Default::default(),
833 )),
834 };
835 participants.insert(
836 answering_connection_id,
837 proto::Participant {
838 user_id: db_participant.user_id.to_proto(),
839 peer_id: answering_connection_id as u32,
840 projects: Default::default(),
841 location: Some(proto::ParticipantLocation { variant: location }),
842 },
843 );
844 } else {
845 pending_participants.push(proto::PendingParticipant {
846 user_id: db_participant.user_id.to_proto(),
847 calling_user_id: db_participant.calling_user_id.to_proto(),
848 initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()),
849 });
850 }
851 }
852
853 let mut db_projects = db_room
854 .find_related(project::Entity)
855 .find_with_related(worktree::Entity)
856 .stream(tx)
857 .await?;
858
859 while let Some(row) = db_projects.next().await {
860 let (db_project, db_worktree) = row?;
861 if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
862 let project = if let Some(project) = participant
863 .projects
864 .iter_mut()
865 .find(|project| project.id == db_project.id.to_proto())
866 {
867 project
868 } else {
869 participant.projects.push(proto::ParticipantProject {
870 id: db_project.id.to_proto(),
871 worktree_root_names: Default::default(),
872 });
873 participant.projects.last_mut().unwrap()
874 };
875
876 if let Some(db_worktree) = db_worktree {
877 project.worktree_root_names.push(db_worktree.root_name);
878 }
879 }
880 }
881
882 Ok(proto::Room {
883 id: db_room.id.to_proto(),
884 live_kit_room: db_room.live_kit_room,
885 participants: participants.into_values().collect(),
886 pending_participants,
887 })
888 }
889
890 async fn commit_room_transaction<T>(
891 &self,
892 room_id: RoomId,
893 tx: DatabaseTransaction,
894 data: T,
895 ) -> Result<RoomGuard<T>> {
896 let lock = self.rooms.entry(room_id).or_default().clone();
897 let _guard = lock.lock_owned().await;
898 tx.commit().await?;
899 Ok(RoomGuard {
900 data,
901 _guard,
902 _not_send: PhantomData,
903 })
904 }
905
906 pub async fn create_access_token_hash(
907 &self,
908 user_id: UserId,
909 access_token_hash: &str,
910 max_access_token_count: usize,
911 ) -> Result<()> {
912 self.transact(|tx| async {
913 let tx = tx;
914
915 access_token::ActiveModel {
916 user_id: ActiveValue::set(user_id),
917 hash: ActiveValue::set(access_token_hash.into()),
918 ..Default::default()
919 }
920 .insert(&tx)
921 .await?;
922
923 access_token::Entity::delete_many()
924 .filter(
925 access_token::Column::Id.in_subquery(
926 Query::select()
927 .column(access_token::Column::Id)
928 .from(access_token::Entity)
929 .and_where(access_token::Column::UserId.eq(user_id))
930 .order_by(access_token::Column::Id, sea_orm::Order::Desc)
931 .limit(10000)
932 .offset(max_access_token_count as u64)
933 .to_owned(),
934 ),
935 )
936 .exec(&tx)
937 .await?;
938 tx.commit().await?;
939 Ok(())
940 })
941 .await
942 }
943
944 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
945 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
946 enum QueryAs {
947 Hash,
948 }
949
950 self.transact(|tx| async move {
951 Ok(access_token::Entity::find()
952 .select_only()
953 .column(access_token::Column::Hash)
954 .filter(access_token::Column::UserId.eq(user_id))
955 .order_by_desc(access_token::Column::Id)
956 .into_values::<_, QueryAs>()
957 .all(&tx)
958 .await?)
959 })
960 .await
961 }
962
963 async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
964 where
965 F: Send + Fn(DatabaseTransaction) -> Fut,
966 Fut: Send + Future<Output = Result<T>>,
967 {
968 let body = async {
969 loop {
970 let tx = self.pool.begin().await?;
971
972 // In Postgres, serializable transactions are opt-in
973 if let DatabaseBackend::Postgres = self.pool.get_database_backend() {
974 tx.execute(Statement::from_string(
975 DatabaseBackend::Postgres,
976 "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(),
977 ))
978 .await?;
979 }
980
981 match f(tx).await {
982 Ok(result) => return Ok(result),
983 Err(error) => match error {
984 Error::Database2(
985 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
986 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
987 ) if error
988 .as_database_error()
989 .and_then(|error| error.code())
990 .as_deref()
991 == Some("40001") =>
992 {
993 // Retry (don't break the loop)
994 }
995 error @ _ => return Err(error),
996 },
997 }
998 }
999 };
1000
1001 #[cfg(test)]
1002 {
1003 if let Some(background) = self.background.as_ref() {
1004 background.simulate_random_delay().await;
1005 }
1006
1007 self.runtime.as_ref().unwrap().block_on(body)
1008 }
1009
1010 #[cfg(not(test))]
1011 {
1012 body.await
1013 }
1014 }
1015}
1016
1017pub struct RoomGuard<T> {
1018 data: T,
1019 _guard: OwnedMutexGuard<()>,
1020 _not_send: PhantomData<Rc<()>>,
1021}
1022
1023impl<T> Deref for RoomGuard<T> {
1024 type Target = T;
1025
1026 fn deref(&self) -> &T {
1027 &self.data
1028 }
1029}
1030
1031impl<T> DerefMut for RoomGuard<T> {
1032 fn deref_mut(&mut self) -> &mut T {
1033 &mut self.data
1034 }
1035}
1036
1037#[derive(Debug, Serialize, Deserialize)]
1038pub struct NewUserParams {
1039 pub github_login: String,
1040 pub github_user_id: i32,
1041 pub invite_count: i32,
1042}
1043
1044#[derive(Debug)]
1045pub struct NewUserResult {
1046 pub user_id: UserId,
1047 pub metrics_id: String,
1048 pub inviting_user_id: Option<UserId>,
1049 pub signup_device_id: Option<String>,
1050}
1051
1052fn random_invite_code() -> String {
1053 nanoid::nanoid!(16)
1054}
1055
1056fn random_email_confirmation_code() -> String {
1057 nanoid::nanoid!(64)
1058}
1059
/// Defines a `Copy` newtype id around an `i32`, plus the impls needed to use
/// it with sqlx, serde, sea-query/sea-orm and protobuf (`u64`).
///
/// NOTE: `from_proto`/`to_proto` use `as` casts, so values outside the `i32`
/// range (and negative ids) do not round-trip.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }

        impl From<$name> for sea_query::Value {
            fn from(value: $name) -> Self {
                sea_query::Value::Int(Some(value.0))
            }
        }

        impl sea_orm::TryGetable for $name {
            fn try_get(
                res: &sea_orm::QueryResult,
                pre: &str,
                col: &str,
            ) -> Result<Self, sea_orm::TryGetError> {
                Ok(Self(i32::try_get(res, pre, col)?))
            }
        }

        impl sea_query::ValueType for $name {
            // Accepts any non-null integer value that fits in an `i32`.
            fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
                match v {
                    Value::TinyInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Int(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::TinyUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Unsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    _ => Err(sea_query::ValueTypeErr),
                }
            }

            fn type_name() -> String {
                stringify!($name).into()
            }

            fn array_type() -> sea_query::ArrayType {
                sea_query::ArrayType::Int
            }

            fn column_type() -> sea_query::ColumnType {
                sea_query::ColumnType::Integer(None)
            }
        }

        impl sea_orm::TryFromU64 for $name {
            // Fails when `n` does not fit in an `i32`.
            fn try_from_u64(n: u64) -> Result<Self, DbErr> {
                Ok(Self(n.try_into().map_err(|_| {
                    DbErr::ConvertFromU64(concat!(
                        "error converting u64 to ",
                        stringify!($name)
                    ))
                })?))
            }
        }

        impl sea_query::Nullable for $name {
            fn null() -> Value {
                Value::Int(None)
            }
        }
    };
}
1180
1181id_type!(AccessTokenId);
1182id_type!(ContactId);
1183id_type!(UserId);
1184id_type!(RoomId);
1185id_type!(RoomParticipantId);
1186id_type!(ProjectId);
1187id_type!(ProjectCollaboratorId);
1188id_type!(SignupId);
1189id_type!(WorktreeId);
1190
1191#[cfg(test)]
1192pub use test::*;
1193
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sea_orm::ConnectionTrait;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A throwaway database for tests, backed by either in-memory SQLite or a
    /// freshly created Postgres database. The Postgres database is terminated
    /// and dropped again when the `TestDb` is dropped.
    pub struct TestDb {
        pub db: Option<Arc<Database>>,
        pub connection: Option<sqlx::AnyConnection>,
    }

    /// Builds the single-threaded tokio runtime (IO and time enabled) that a
    /// `TestDb` drives its queries on.
    fn build_test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    impl TestDb {
        /// Creates an in-memory SQLite database and loads the SQLite test
        /// schema into it.
        pub fn sqlite(background: Arc<Background>) -> Self {
            let url = "sqlite::memory:".to_string();
            let runtime = build_test_runtime();

            let mut db = runtime.block_on(async {
                let mut options = ConnectOptions::new(url);
                options.max_connections(5);
                let db = Database::new(options).await.unwrap();
                let sql = include_str!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/migrations.sqlite/20221109000000_test_schema.sql"
                ));
                db.pool
                    .execute(sea_orm::Statement::from_string(
                        db.pool.get_database_backend(),
                        sql.into(),
                    ))
                    .await
                    .unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                connection: None,
            }
        }

        /// Creates a randomly named Postgres database on localhost and runs
        /// the real migrations against it. Creation is serialized through a
        /// process-wide lock.
        pub fn postgres(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = build_test_runtime();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let mut options = ConnectOptions::new(url);
                options
                    .max_connections(5)
                    .idle_timeout(Duration::from_secs(0));
                let db = Database::new(options).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                connection: None,
            }
        }

        /// Returns the wrapped database.
        pub fn db(&self) -> &Arc<Database> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        /// For Postgres, terminates other backends connected to the test
        /// database and drops it; failures are only logged.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            if let DatabaseBackend::Postgres = db.pool.get_database_backend() {
                db.runtime.as_ref().unwrap().block_on(async {
                    use util::ResultExt;
                    let query = "
                        SELECT pg_terminate_backend(pg_stat_activity.pid)
                        FROM pg_stat_activity
                        WHERE
                            pg_stat_activity.datname = current_database() AND
                            pid <> pg_backend_pid();
                    ";
                    db.pool
                        .execute(sea_orm::Statement::from_string(
                            db.pool.get_database_backend(),
                            query.into(),
                        ))
                        .await
                        .log_err();
                    sqlx::Postgres::drop_database(db.options.get_url())
                        .await
                        .log_err();
                })
            }
        }
    }
}