1mod access_token;
2mod contact;
3mod project;
4mod project_collaborator;
5mod room;
6mod room_participant;
7mod signup;
8#[cfg(test)]
9mod tests;
10mod user;
11mod worktree;
12
13use crate::{Error, Result};
14use anyhow::anyhow;
15use collections::HashMap;
16pub use contact::Contact;
17use dashmap::DashMap;
18use futures::StreamExt;
19use hyper::StatusCode;
20use rpc::{proto, ConnectionId};
21pub use sea_orm::ConnectOptions;
22use sea_orm::{
23 entity::prelude::*, ActiveValue, ConnectionTrait, DatabaseBackend, DatabaseConnection,
24 DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, JoinType, QueryOrder,
25 QuerySelect, Statement, TransactionTrait,
26};
27use sea_query::{Alias, Expr, OnConflict, Query};
28use serde::{Deserialize, Serialize};
29pub use signup::{Invite, NewSignup, WaitlistSummary};
30use sqlx::migrate::{Migrate, Migration, MigrationSource};
31use sqlx::Connection;
32use std::ops::{Deref, DerefMut};
33use std::path::Path;
34use std::time::Duration;
35use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
36use tokio::sync::{Mutex, OwnedMutexGuard};
37pub use user::Model as User;
38
/// Handle to the collaboration database.
///
/// Wraps a sea-orm connection pool and keeps one async mutex per room, which
/// `commit_room_transaction` acquires before committing room-mutating
/// transactions.
pub struct Database {
    /// Options the pool was created from; `migrate` reuses their URL to open
    /// a separate sqlx connection.
    options: ConnectOptions,
    pool: DatabaseConnection,
    /// Lazily created per-room locks, keyed by room id.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// Test-only executor used to inject a random delay before each transaction.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only tokio runtime on which `transact` drives transaction bodies.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
48
49impl Database {
50 pub async fn new(options: ConnectOptions) -> Result<Self> {
51 Ok(Self {
52 options: options.clone(),
53 pool: sea_orm::Database::connect(options).await?,
54 rooms: DashMap::with_capacity(16384),
55 #[cfg(test)]
56 background: None,
57 #[cfg(test)]
58 runtime: None,
59 })
60 }
61
62 pub async fn migrate(
63 &self,
64 migrations_path: &Path,
65 ignore_checksum_mismatch: bool,
66 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
67 let migrations = MigrationSource::resolve(migrations_path)
68 .await
69 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
70
71 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
72
73 connection.ensure_migrations_table().await?;
74 let applied_migrations: HashMap<_, _> = connection
75 .list_applied_migrations()
76 .await?
77 .into_iter()
78 .map(|m| (m.version, m))
79 .collect();
80
81 let mut new_migrations = Vec::new();
82 for migration in migrations {
83 match applied_migrations.get(&migration.version) {
84 Some(applied_migration) => {
85 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
86 {
87 Err(anyhow!(
88 "checksum mismatch for applied migration {}",
89 migration.description
90 ))?;
91 }
92 }
93 None => {
94 let elapsed = connection.apply(&migration).await?;
95 new_migrations.push((migration, elapsed));
96 }
97 }
98 }
99
100 Ok(new_migrations)
101 }
102
103 // users
104
105 pub async fn create_user(
106 &self,
107 email_address: &str,
108 admin: bool,
109 params: NewUserParams,
110 ) -> Result<NewUserResult> {
111 self.transact(|tx| async {
112 let user = user::Entity::insert(user::ActiveModel {
113 email_address: ActiveValue::set(Some(email_address.into())),
114 github_login: ActiveValue::set(params.github_login.clone()),
115 github_user_id: ActiveValue::set(Some(params.github_user_id)),
116 admin: ActiveValue::set(admin),
117 metrics_id: ActiveValue::set(Uuid::new_v4()),
118 ..Default::default()
119 })
120 .on_conflict(
121 OnConflict::column(user::Column::GithubLogin)
122 .update_column(user::Column::GithubLogin)
123 .to_owned(),
124 )
125 .exec_with_returning(&tx)
126 .await?;
127
128 tx.commit().await?;
129
130 Ok(NewUserResult {
131 user_id: user.id,
132 metrics_id: user.metrics_id.to_string(),
133 signup_device_id: None,
134 inviting_user_id: None,
135 })
136 })
137 .await
138 }
139
140 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<user::Model>> {
141 self.transact(|tx| async move { Ok(user::Entity::find_by_id(id).one(&tx).await?) })
142 .await
143 }
144
145 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
146 self.transact(|tx| async {
147 let tx = tx;
148 Ok(user::Entity::find()
149 .filter(user::Column::Id.is_in(ids.iter().copied()))
150 .all(&tx)
151 .await?)
152 })
153 .await
154 }
155
156 pub async fn get_user_by_github_account(
157 &self,
158 github_login: &str,
159 github_user_id: Option<i32>,
160 ) -> Result<Option<User>> {
161 self.transact(|tx| async {
162 let tx = tx;
163 if let Some(github_user_id) = github_user_id {
164 if let Some(user_by_github_user_id) = user::Entity::find()
165 .filter(user::Column::GithubUserId.eq(github_user_id))
166 .one(&tx)
167 .await?
168 {
169 let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
170 user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
171 Ok(Some(user_by_github_user_id.update(&tx).await?))
172 } else if let Some(user_by_github_login) = user::Entity::find()
173 .filter(user::Column::GithubLogin.eq(github_login))
174 .one(&tx)
175 .await?
176 {
177 let mut user_by_github_login = user_by_github_login.into_active_model();
178 user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
179 Ok(Some(user_by_github_login.update(&tx).await?))
180 } else {
181 Ok(None)
182 }
183 } else {
184 Ok(user::Entity::find()
185 .filter(user::Column::GithubLogin.eq(github_login))
186 .one(&tx)
187 .await?)
188 }
189 })
190 .await
191 }
192
193 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
194 self.transact(|tx| async move {
195 Ok(user::Entity::find()
196 .order_by_asc(user::Column::GithubLogin)
197 .limit(limit as u64)
198 .offset(page as u64 * limit as u64)
199 .all(&tx)
200 .await?)
201 })
202 .await
203 }
204
205 pub async fn get_users_with_no_invites(
206 &self,
207 invited_by_another_user: bool,
208 ) -> Result<Vec<User>> {
209 self.transact(|tx| async move {
210 Ok(user::Entity::find()
211 .filter(
212 user::Column::InviteCount
213 .eq(0)
214 .and(if invited_by_another_user {
215 user::Column::InviterId.is_not_null()
216 } else {
217 user::Column::InviterId.is_null()
218 }),
219 )
220 .all(&tx)
221 .await?)
222 })
223 .await
224 }
225
226 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
227 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
228 enum QueryAs {
229 MetricsId,
230 }
231
232 self.transact(|tx| async move {
233 let metrics_id: Uuid = user::Entity::find_by_id(id)
234 .select_only()
235 .column(user::Column::MetricsId)
236 .into_values::<_, QueryAs>()
237 .one(&tx)
238 .await?
239 .ok_or_else(|| anyhow!("could not find user"))?;
240 Ok(metrics_id.to_string())
241 })
242 .await
243 }
244
245 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
246 self.transact(|tx| async move {
247 user::Entity::update_many()
248 .filter(user::Column::Id.eq(id))
249 .col_expr(user::Column::Admin, is_admin.into())
250 .exec(&tx)
251 .await?;
252 tx.commit().await?;
253 Ok(())
254 })
255 .await
256 }
257
258 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
259 self.transact(|tx| async move {
260 access_token::Entity::delete_many()
261 .filter(access_token::Column::UserId.eq(id))
262 .exec(&tx)
263 .await?;
264 user::Entity::delete_by_id(id).exec(&tx).await?;
265 tx.commit().await?;
266 Ok(())
267 })
268 .await
269 }
270
271 // contacts
272
273 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
274 #[derive(Debug, FromQueryResult)]
275 struct ContactWithUserBusyStatuses {
276 user_id_a: UserId,
277 user_id_b: UserId,
278 a_to_b: bool,
279 accepted: bool,
280 should_notify: bool,
281 user_a_busy: bool,
282 user_b_busy: bool,
283 }
284
285 self.transact(|tx| async move {
286 let user_a_participant = Alias::new("user_a_participant");
287 let user_b_participant = Alias::new("user_b_participant");
288 let mut db_contacts = contact::Entity::find()
289 .column_as(
290 Expr::tbl(user_a_participant.clone(), room_participant::Column::Id)
291 .is_not_null(),
292 "user_a_busy",
293 )
294 .column_as(
295 Expr::tbl(user_b_participant.clone(), room_participant::Column::Id)
296 .is_not_null(),
297 "user_b_busy",
298 )
299 .filter(
300 contact::Column::UserIdA
301 .eq(user_id)
302 .or(contact::Column::UserIdB.eq(user_id)),
303 )
304 .join_as(
305 JoinType::LeftJoin,
306 contact::Relation::UserARoomParticipant.def(),
307 user_a_participant,
308 )
309 .join_as(
310 JoinType::LeftJoin,
311 contact::Relation::UserBRoomParticipant.def(),
312 user_b_participant,
313 )
314 .into_model::<ContactWithUserBusyStatuses>()
315 .stream(&tx)
316 .await?;
317
318 let mut contacts = Vec::new();
319 while let Some(db_contact) = db_contacts.next().await {
320 let db_contact = db_contact?;
321 if db_contact.user_id_a == user_id {
322 if db_contact.accepted {
323 contacts.push(Contact::Accepted {
324 user_id: db_contact.user_id_b,
325 should_notify: db_contact.should_notify && db_contact.a_to_b,
326 busy: db_contact.user_b_busy,
327 });
328 } else if db_contact.a_to_b {
329 contacts.push(Contact::Outgoing {
330 user_id: db_contact.user_id_b,
331 })
332 } else {
333 contacts.push(Contact::Incoming {
334 user_id: db_contact.user_id_b,
335 should_notify: db_contact.should_notify,
336 });
337 }
338 } else if db_contact.accepted {
339 contacts.push(Contact::Accepted {
340 user_id: db_contact.user_id_a,
341 should_notify: db_contact.should_notify && !db_contact.a_to_b,
342 busy: db_contact.user_a_busy,
343 });
344 } else if db_contact.a_to_b {
345 contacts.push(Contact::Incoming {
346 user_id: db_contact.user_id_a,
347 should_notify: db_contact.should_notify,
348 });
349 } else {
350 contacts.push(Contact::Outgoing {
351 user_id: db_contact.user_id_a,
352 });
353 }
354 }
355
356 contacts.sort_unstable_by_key(|contact| contact.user_id());
357
358 Ok(contacts)
359 })
360 .await
361 }
362
363 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
364 self.transact(|tx| async move {
365 let (id_a, id_b) = if user_id_1 < user_id_2 {
366 (user_id_1, user_id_2)
367 } else {
368 (user_id_2, user_id_1)
369 };
370
371 Ok(contact::Entity::find()
372 .filter(
373 contact::Column::UserIdA
374 .eq(id_a)
375 .and(contact::Column::UserIdB.eq(id_b))
376 .and(contact::Column::Accepted.eq(true)),
377 )
378 .one(&tx)
379 .await?
380 .is_some())
381 })
382 .await
383 }
384
    /// Records a contact request from `sender_id` to `receiver_id`.
    ///
    /// Contacts are stored once per pair, with the smaller id in `user_id_a`;
    /// `a_to_b` records whether the request went from `user_id_a` to
    /// `user_id_b`. If a row for the pair already exists, the conflict update
    /// only fires when the existing row is unaccepted and the WHERE predicate
    /// below holds. In that case the contact becomes accepted, with
    /// `should_notify` cleared. Otherwise no row is expected to change, and
    /// the call fails with "contact already requested".
    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
        self.transact(|tx| async move {
            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
                (sender_id, receiver_id, true)
            } else {
                (receiver_id, sender_id, false)
            };

            let rows_affected = contact::Entity::insert(contact::ActiveModel {
                user_id_a: ActiveValue::set(id_a),
                user_id_b: ActiveValue::set(id_b),
                a_to_b: ActiveValue::set(a_to_b),
                accepted: ActiveValue::set(false),
                should_notify: ActiveValue::set(true),
                ..Default::default()
            })
            .on_conflict(
                OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB])
                    .values([
                        (contact::Column::Accepted, true.into()),
                        (contact::Column::ShouldNotify, false.into()),
                    ])
                    // NOTE(review): the conflicting row has user_id_a == id_a,
                    // so `UserIdA = id_b` only holds when id_a == id_b. In
                    // practice the second disjunct (an existing unaccepted
                    // request in the opposite direction) is what turns a
                    // mutual request into an acceptance.
                    .action_and_where(
                        contact::Column::Accepted.eq(false).and(
                            contact::Column::AToB
                                .eq(a_to_b)
                                .and(contact::Column::UserIdA.eq(id_b))
                                .or(contact::Column::AToB
                                    .ne(a_to_b)
                                    .and(contact::Column::UserIdA.eq(id_a))),
                        ),
                    )
                    .to_owned(),
            )
            .exec_without_returning(&tx)
            .await?;

            // Exactly one row means either a fresh request was inserted or
            // a reciprocal request was accepted.
            if rows_affected == 1 {
                tx.commit().await?;
                Ok(())
            } else {
                Err(anyhow!("contact already requested"))?
            }
        })
        .await
    }
431
432 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
433 self.transact(|tx| async move {
434 let (id_a, id_b) = if responder_id < requester_id {
435 (responder_id, requester_id)
436 } else {
437 (requester_id, responder_id)
438 };
439
440 let result = contact::Entity::delete_many()
441 .filter(
442 contact::Column::UserIdA
443 .eq(id_a)
444 .and(contact::Column::UserIdB.eq(id_b)),
445 )
446 .exec(&tx)
447 .await?;
448
449 if result.rows_affected == 1 {
450 tx.commit().await?;
451 Ok(())
452 } else {
453 Err(anyhow!("no such contact"))?
454 }
455 })
456 .await
457 }
458
459 pub async fn dismiss_contact_notification(
460 &self,
461 user_id: UserId,
462 contact_user_id: UserId,
463 ) -> Result<()> {
464 self.transact(|tx| async move {
465 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
466 (user_id, contact_user_id, true)
467 } else {
468 (contact_user_id, user_id, false)
469 };
470
471 let result = contact::Entity::update_many()
472 .set(contact::ActiveModel {
473 should_notify: ActiveValue::set(false),
474 ..Default::default()
475 })
476 .filter(
477 contact::Column::UserIdA
478 .eq(id_a)
479 .and(contact::Column::UserIdB.eq(id_b))
480 .and(
481 contact::Column::AToB
482 .eq(a_to_b)
483 .and(contact::Column::Accepted.eq(true))
484 .or(contact::Column::AToB
485 .ne(a_to_b)
486 .and(contact::Column::Accepted.eq(false))),
487 ),
488 )
489 .exec(&tx)
490 .await?;
491 if result.rows_affected == 0 {
492 Err(anyhow!("no such contact request"))?
493 } else {
494 tx.commit().await?;
495 Ok(())
496 }
497 })
498 .await
499 }
500
501 pub async fn respond_to_contact_request(
502 &self,
503 responder_id: UserId,
504 requester_id: UserId,
505 accept: bool,
506 ) -> Result<()> {
507 self.transact(|tx| async move {
508 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
509 (responder_id, requester_id, false)
510 } else {
511 (requester_id, responder_id, true)
512 };
513 let rows_affected = if accept {
514 let result = contact::Entity::update_many()
515 .set(contact::ActiveModel {
516 accepted: ActiveValue::set(true),
517 should_notify: ActiveValue::set(true),
518 ..Default::default()
519 })
520 .filter(
521 contact::Column::UserIdA
522 .eq(id_a)
523 .and(contact::Column::UserIdB.eq(id_b))
524 .and(contact::Column::AToB.eq(a_to_b)),
525 )
526 .exec(&tx)
527 .await?;
528 result.rows_affected
529 } else {
530 let result = contact::Entity::delete_many()
531 .filter(
532 contact::Column::UserIdA
533 .eq(id_a)
534 .and(contact::Column::UserIdB.eq(id_b))
535 .and(contact::Column::AToB.eq(a_to_b))
536 .and(contact::Column::Accepted.eq(false)),
537 )
538 .exec(&tx)
539 .await?;
540
541 result.rows_affected
542 };
543
544 if rows_affected == 1 {
545 tx.commit().await?;
546 Ok(())
547 } else {
548 Err(anyhow!("no such contact request"))?
549 }
550 })
551 .await
552 }
553
554 pub fn fuzzy_like_string(string: &str) -> String {
555 let mut result = String::with_capacity(string.len() * 2 + 1);
556 for c in string.chars() {
557 if c.is_alphanumeric() {
558 result.push('%');
559 result.push(c);
560 }
561 }
562 result.push('%');
563 result
564 }
565
566 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
567 self.transact(|tx| async {
568 let tx = tx;
569 let like_string = Self::fuzzy_like_string(name_query);
570 let query = "
571 SELECT users.*
572 FROM users
573 WHERE github_login ILIKE $1
574 ORDER BY github_login <-> $2
575 LIMIT $3
576 ";
577
578 Ok(user::Entity::find()
579 .from_raw_sql(Statement::from_sql_and_values(
580 self.pool.get_database_backend(),
581 query.into(),
582 vec![like_string.into(), name_query.into(), limit.into()],
583 ))
584 .all(&tx)
585 .await?)
586 })
587 .await
588 }
589
590 // signups
591
592 pub async fn create_signup(&self, signup: NewSignup) -> Result<()> {
593 self.transact(|tx| async {
594 signup::ActiveModel {
595 email_address: ActiveValue::set(signup.email_address.clone()),
596 email_confirmation_code: ActiveValue::set(random_email_confirmation_code()),
597 email_confirmation_sent: ActiveValue::set(false),
598 platform_mac: ActiveValue::set(signup.platform_mac),
599 platform_windows: ActiveValue::set(signup.platform_windows),
600 platform_linux: ActiveValue::set(signup.platform_linux),
601 platform_unknown: ActiveValue::set(false),
602 editor_features: ActiveValue::set(Some(signup.editor_features.clone())),
603 programming_languages: ActiveValue::set(Some(signup.programming_languages.clone())),
604 device_id: ActiveValue::set(signup.device_id.clone()),
605 ..Default::default()
606 }
607 .insert(&tx)
608 .await?;
609 tx.commit().await?;
610 Ok(())
611 })
612 .await
613 }
614
615 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
616 self.transact(|tx| async move {
617 let query = "
618 SELECT
619 COUNT(*) as count,
620 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
621 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
622 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
623 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
624 FROM (
625 SELECT *
626 FROM signups
627 WHERE
628 NOT email_confirmation_sent
629 ) AS unsent
630 ";
631 Ok(
632 WaitlistSummary::find_by_statement(Statement::from_sql_and_values(
633 self.pool.get_database_backend(),
634 query.into(),
635 vec![],
636 ))
637 .one(&tx)
638 .await?
639 .ok_or_else(|| anyhow!("invalid result"))?,
640 )
641 })
642 .await
643 }
644
645 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
646 let emails = invites
647 .iter()
648 .map(|s| s.email_address.as_str())
649 .collect::<Vec<_>>();
650 self.transact(|tx| async {
651 signup::Entity::update_many()
652 .filter(signup::Column::EmailAddress.is_in(emails.iter().copied()))
653 .col_expr(signup::Column::EmailConfirmationSent, true.into())
654 .exec(&tx)
655 .await?;
656 tx.commit().await?;
657 Ok(())
658 })
659 .await
660 }
661
662 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
663 self.transact(|tx| async move {
664 Ok(signup::Entity::find()
665 .select_only()
666 .column(signup::Column::EmailAddress)
667 .column(signup::Column::EmailConfirmationCode)
668 .filter(
669 signup::Column::EmailConfirmationSent.eq(false).and(
670 signup::Column::PlatformMac
671 .eq(true)
672 .or(signup::Column::PlatformUnknown.eq(true)),
673 ),
674 )
675 .limit(count as u64)
676 .into_model()
677 .all(&tx)
678 .await?)
679 })
680 .await
681 }
682
683 // invite codes
684
685 pub async fn create_invite_from_code(
686 &self,
687 code: &str,
688 email_address: &str,
689 device_id: Option<&str>,
690 ) -> Result<Invite> {
691 self.transact(|tx| async move {
692 let existing_user = user::Entity::find()
693 .filter(user::Column::EmailAddress.eq(email_address))
694 .one(&tx)
695 .await?;
696
697 if existing_user.is_some() {
698 Err(anyhow!("email address is already in use"))?;
699 }
700
701 let inviter = match user::Entity::find()
702 .filter(user::Column::InviteCode.eq(code))
703 .one(&tx)
704 .await?
705 {
706 Some(inviter) => inviter,
707 None => {
708 return Err(Error::Http(
709 StatusCode::NOT_FOUND,
710 "invite code not found".to_string(),
711 ))?
712 }
713 };
714
715 if inviter.invite_count == 0 {
716 Err(Error::Http(
717 StatusCode::UNAUTHORIZED,
718 "no invites remaining".to_string(),
719 ))?;
720 }
721
722 let signup = signup::Entity::insert(signup::ActiveModel {
723 email_address: ActiveValue::set(email_address.into()),
724 email_confirmation_code: ActiveValue::set(random_email_confirmation_code()),
725 email_confirmation_sent: ActiveValue::set(false),
726 inviting_user_id: ActiveValue::set(Some(inviter.id)),
727 platform_linux: ActiveValue::set(false),
728 platform_mac: ActiveValue::set(false),
729 platform_windows: ActiveValue::set(false),
730 platform_unknown: ActiveValue::set(true),
731 device_id: ActiveValue::set(device_id.map(|device_id| device_id.into())),
732 ..Default::default()
733 })
734 .on_conflict(
735 OnConflict::column(signup::Column::EmailAddress)
736 .update_column(signup::Column::InvitingUserId)
737 .to_owned(),
738 )
739 .exec_with_returning(&tx)
740 .await?;
741 tx.commit().await?;
742
743 Ok(Invite {
744 email_address: signup.email_address,
745 email_confirmation_code: signup.email_confirmation_code,
746 })
747 })
748 .await
749 }
750
751 pub async fn create_user_from_invite(
752 &self,
753 invite: &Invite,
754 user: NewUserParams,
755 ) -> Result<Option<NewUserResult>> {
756 self.transact(|tx| async {
757 let tx = tx;
758 let signup = signup::Entity::find()
759 .filter(
760 signup::Column::EmailAddress
761 .eq(invite.email_address.as_str())
762 .and(
763 signup::Column::EmailConfirmationCode
764 .eq(invite.email_confirmation_code.as_str()),
765 ),
766 )
767 .one(&tx)
768 .await?
769 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
770
771 if signup.user_id.is_some() {
772 return Ok(None);
773 }
774
775 let user = user::Entity::insert(user::ActiveModel {
776 email_address: ActiveValue::set(Some(invite.email_address.clone())),
777 github_login: ActiveValue::set(user.github_login.clone()),
778 github_user_id: ActiveValue::set(Some(user.github_user_id)),
779 admin: ActiveValue::set(false),
780 invite_count: ActiveValue::set(user.invite_count),
781 invite_code: ActiveValue::set(Some(random_invite_code())),
782 metrics_id: ActiveValue::set(Uuid::new_v4()),
783 ..Default::default()
784 })
785 .on_conflict(
786 OnConflict::column(user::Column::GithubLogin)
787 .update_columns([
788 user::Column::EmailAddress,
789 user::Column::GithubUserId,
790 user::Column::Admin,
791 ])
792 .to_owned(),
793 )
794 .exec_with_returning(&tx)
795 .await?;
796
797 let mut signup = signup.into_active_model();
798 signup.user_id = ActiveValue::set(Some(user.id));
799 let signup = signup.update(&tx).await?;
800
801 if let Some(inviting_user_id) = signup.inviting_user_id {
802 let result = user::Entity::update_many()
803 .filter(
804 user::Column::Id
805 .eq(inviting_user_id)
806 .and(user::Column::InviteCount.gt(0)),
807 )
808 .col_expr(
809 user::Column::InviteCount,
810 Expr::col(user::Column::InviteCount).sub(1),
811 )
812 .exec(&tx)
813 .await?;
814
815 if result.rows_affected == 0 {
816 Err(Error::Http(
817 StatusCode::UNAUTHORIZED,
818 "no invites remaining".to_string(),
819 ))?;
820 }
821
822 contact::Entity::insert(contact::ActiveModel {
823 user_id_a: ActiveValue::set(inviting_user_id),
824 user_id_b: ActiveValue::set(user.id),
825 a_to_b: ActiveValue::set(true),
826 should_notify: ActiveValue::set(true),
827 accepted: ActiveValue::set(true),
828 ..Default::default()
829 })
830 .on_conflict(OnConflict::new().do_nothing().to_owned())
831 .exec_without_returning(&tx)
832 .await?;
833 }
834
835 tx.commit().await?;
836 Ok(Some(NewUserResult {
837 user_id: user.id,
838 metrics_id: user.metrics_id.to_string(),
839 inviting_user_id: signup.inviting_user_id,
840 signup_device_id: signup.device_id,
841 }))
842 })
843 .await
844 }
845
846 pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
847 self.transact(|tx| async move {
848 if count > 0 {
849 user::Entity::update_many()
850 .filter(
851 user::Column::Id
852 .eq(id)
853 .and(user::Column::InviteCode.is_null()),
854 )
855 .col_expr(user::Column::InviteCode, random_invite_code().into())
856 .exec(&tx)
857 .await?;
858 }
859
860 user::Entity::update_many()
861 .filter(user::Column::Id.eq(id))
862 .col_expr(user::Column::InviteCount, count.into())
863 .exec(&tx)
864 .await?;
865 tx.commit().await?;
866 Ok(())
867 })
868 .await
869 }
870
871 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
872 self.transact(|tx| async move {
873 match user::Entity::find_by_id(id).one(&tx).await? {
874 Some(user) if user.invite_code.is_some() => {
875 Ok(Some((user.invite_code.unwrap(), user.invite_count as u32)))
876 }
877 _ => Ok(None),
878 }
879 })
880 .await
881 }
882
883 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
884 self.transact(|tx| async move {
885 user::Entity::find()
886 .filter(user::Column::InviteCode.eq(code))
887 .one(&tx)
888 .await?
889 .ok_or_else(|| {
890 Error::Http(
891 StatusCode::NOT_FOUND,
892 "that invite code does not exist".to_string(),
893 )
894 })
895 })
896 .await
897 }
898
899 // projects
900
901 pub async fn share_project(
902 &self,
903 room_id: RoomId,
904 connection_id: ConnectionId,
905 worktrees: &[proto::WorktreeMetadata],
906 ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
907 self.transact(|tx| async move {
908 let participant = room_participant::Entity::find()
909 .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
910 .one(&tx)
911 .await?
912 .ok_or_else(|| anyhow!("could not find participant"))?;
913 if participant.room_id != room_id {
914 return Err(anyhow!("shared project on unexpected room"))?;
915 }
916
917 let project = project::ActiveModel {
918 room_id: ActiveValue::set(participant.room_id),
919 host_user_id: ActiveValue::set(participant.user_id),
920 host_connection_id: ActiveValue::set(connection_id.0 as i32),
921 ..Default::default()
922 }
923 .insert(&tx)
924 .await?;
925
926 worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel {
927 id: ActiveValue::set(worktree.id as i32),
928 project_id: ActiveValue::set(project.id),
929 abs_path: ActiveValue::set(worktree.abs_path.clone()),
930 root_name: ActiveValue::set(worktree.root_name.clone()),
931 visible: ActiveValue::set(worktree.visible),
932 scan_id: ActiveValue::set(0),
933 is_complete: ActiveValue::set(false),
934 }))
935 .exec(&tx)
936 .await?;
937
938 project_collaborator::ActiveModel {
939 project_id: ActiveValue::set(project.id),
940 connection_id: ActiveValue::set(connection_id.0 as i32),
941 user_id: ActiveValue::set(participant.user_id),
942 replica_id: ActiveValue::set(0),
943 is_host: ActiveValue::set(true),
944 ..Default::default()
945 }
946 .insert(&tx)
947 .await?;
948
949 let room = self.get_room(room_id, &tx).await?;
950 self.commit_room_transaction(room_id, tx, (project.id, room))
951 .await
952 })
953 .await
954 }
955
    /// Loads the protobuf representation of room `room_id` inside `tx`.
    ///
    /// - Participants with an `answering_connection_id` become `participants`.
    ///   While building, they are keyed by that connection id.
    /// - Participants without one are reported as `pending_participants`.
    /// - Each project whose host connection belongs to a connected
    ///   participant is attached to that participant, together with its
    ///   worktree root names.
    ///
    /// Fails with "could not find room" if the room does not exist.
    async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
        let db_room = room::Entity::find_by_id(room_id)
            .one(tx)
            .await?
            .ok_or_else(|| anyhow!("could not find room"))?;

        let mut db_participants = db_room
            .find_related(room_participant::Entity)
            .stream(tx)
            .await?;
        let mut participants = HashMap::default();
        let mut pending_participants = Vec::new();
        while let Some(db_participant) = db_participants.next().await {
            let db_participant = db_participant?;
            if let Some(answering_connection_id) = db_participant.answering_connection_id {
                // location_kind 0 with a project id is a shared project and 1
                // is an unshared project. Anything else, including kind 0
                // without a project id, is reported as external.
                let location = match (
                    db_participant.location_kind,
                    db_participant.location_project_id,
                ) {
                    (Some(0), Some(project_id)) => {
                        Some(proto::participant_location::Variant::SharedProject(
                            proto::participant_location::SharedProject {
                                id: project_id.to_proto(),
                            },
                        ))
                    }
                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
                        Default::default(),
                    )),
                    _ => Some(proto::participant_location::Variant::External(
                        Default::default(),
                    )),
                };
                participants.insert(
                    answering_connection_id,
                    proto::Participant {
                        user_id: db_participant.user_id.to_proto(),
                        peer_id: answering_connection_id as u32,
                        projects: Default::default(),
                        location: Some(proto::ParticipantLocation { variant: location }),
                    },
                );
            } else {
                pending_participants.push(proto::PendingParticipant {
                    user_id: db_participant.user_id.to_proto(),
                    calling_user_id: db_participant.calling_user_id.to_proto(),
                    initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()),
                });
            }
        }

        // One row per (project, worktree) pair; the worktree is optional.
        let mut db_projects = db_room
            .find_related(project::Entity)
            .find_with_related(worktree::Entity)
            .stream(tx)
            .await?;

        while let Some(row) = db_projects.next().await {
            let (db_project, db_worktree) = row?;
            // Projects whose host connection is not a connected participant
            // are skipped.
            if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
                let project = if let Some(project) = participant
                    .projects
                    .iter_mut()
                    .find(|project| project.id == db_project.id.to_proto())
                {
                    project
                } else {
                    participant.projects.push(proto::ParticipantProject {
                        id: db_project.id.to_proto(),
                        worktree_root_names: Default::default(),
                    });
                    participant.projects.last_mut().unwrap()
                };

                if let Some(db_worktree) = db_worktree {
                    project.worktree_root_names.push(db_worktree.root_name);
                }
            }
        }

        Ok(proto::Room {
            id: db_room.id.to_proto(),
            live_kit_room: db_room.live_kit_room,
            participants: participants.into_values().collect(),
            pending_participants,
        })
    }
1043
1044 async fn commit_room_transaction<T>(
1045 &self,
1046 room_id: RoomId,
1047 tx: DatabaseTransaction,
1048 data: T,
1049 ) -> Result<RoomGuard<T>> {
1050 let lock = self.rooms.entry(room_id).or_default().clone();
1051 let _guard = lock.lock_owned().await;
1052 tx.commit().await?;
1053 Ok(RoomGuard {
1054 data,
1055 _guard,
1056 _not_send: PhantomData,
1057 })
1058 }
1059
1060 pub async fn create_access_token_hash(
1061 &self,
1062 user_id: UserId,
1063 access_token_hash: &str,
1064 max_access_token_count: usize,
1065 ) -> Result<()> {
1066 self.transact(|tx| async {
1067 let tx = tx;
1068
1069 access_token::ActiveModel {
1070 user_id: ActiveValue::set(user_id),
1071 hash: ActiveValue::set(access_token_hash.into()),
1072 ..Default::default()
1073 }
1074 .insert(&tx)
1075 .await?;
1076
1077 access_token::Entity::delete_many()
1078 .filter(
1079 access_token::Column::Id.in_subquery(
1080 Query::select()
1081 .column(access_token::Column::Id)
1082 .from(access_token::Entity)
1083 .and_where(access_token::Column::UserId.eq(user_id))
1084 .order_by(access_token::Column::Id, sea_orm::Order::Desc)
1085 .limit(10000)
1086 .offset(max_access_token_count as u64)
1087 .to_owned(),
1088 ),
1089 )
1090 .exec(&tx)
1091 .await?;
1092 tx.commit().await?;
1093 Ok(())
1094 })
1095 .await
1096 }
1097
1098 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
1099 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
1100 enum QueryAs {
1101 Hash,
1102 }
1103
1104 self.transact(|tx| async move {
1105 Ok(access_token::Entity::find()
1106 .select_only()
1107 .column(access_token::Column::Hash)
1108 .filter(access_token::Column::UserId.eq(user_id))
1109 .order_by_desc(access_token::Column::Id)
1110 .into_values::<_, QueryAs>()
1111 .all(&tx)
1112 .await?)
1113 })
1114 .await
1115 }
1116
    /// Runs `f` inside a fresh database transaction.
    ///
    /// If `f` fails with SQLSTATE `40001` (the standard serialization-failure
    /// code), the whole attempt is retried with a new transaction. There is
    /// no retry limit and no backoff. Any other error is returned unchanged.
    ///
    /// `f` receives ownership of the transaction and must commit it itself;
    /// this function never commits. Dropping an uncommitted transaction
    /// rolls it back. On Postgres the transaction is first switched to
    /// SERIALIZABLE isolation.
    ///
    /// In test builds, an optional random delay is injected first. The body
    /// is then driven with `block_on` on the test-only tokio runtime, which
    /// must be set or this panics.
    async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
    where
        F: Send + Fn(DatabaseTransaction) -> Fut,
        Fut: Send + Future<Output = Result<T>>,
    {
        let body = async {
            loop {
                let tx = self.pool.begin().await?;

                // In Postgres, serializable transactions are opt-in
                if let DatabaseBackend::Postgres = self.pool.get_database_backend() {
                    tx.execute(Statement::from_string(
                        DatabaseBackend::Postgres,
                        "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(),
                    ))
                    .await?;
                }

                match f(tx).await {
                    Ok(result) => return Ok(result),
                    Err(error) => match error {
                        // Only sqlx-level exec/query errors carrying SQLSTATE
                        // 40001 are retried.
                        Error::Database2(
                            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
                            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
                        ) if error
                            .as_database_error()
                            .and_then(|error| error.code())
                            .as_deref()
                            == Some("40001") =>
                        {
                            // Retry (don't break the loop)
                        }
                        error @ _ => return Err(error),
                    },
                }
            }
        };

        #[cfg(test)]
        {
            if let Some(background) = self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            // Panics if the test harness did not install a runtime.
            self.runtime.as_ref().unwrap().block_on(body)
        }

        #[cfg(not(test))]
        {
            body.await
        }
    }
1169}
1170
/// Result of a room-mutating transaction, bundled with that room's lock.
///
/// The room mutex taken by `Database::commit_room_transaction` stays locked
/// for as long as this value is alive. Other commits for the same room wait
/// until it is dropped. Derefs to the wrapped data.
pub struct RoomGuard<T> {
    data: T,
    /// Held only to keep the room lock until this value is dropped.
    _guard: OwnedMutexGuard<()>,
    /// `Rc` is `!Send`, so this marker makes `RoomGuard` `!Send`.
    // Presumably this keeps the room lock from being carried across
    // threads; confirm intent.
    _not_send: PhantomData<Rc<()>>,
}
1176
/// Read access to the wrapped data while the room lock is held.
impl<T> Deref for RoomGuard<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.data
    }
}
1184
1185impl<T> DerefMut for RoomGuard<T> {
1186 fn deref_mut(&mut self) -> &mut T {
1187 &mut self.data
1188 }
1189}
1190
/// Input describing a user to create, (de)serializable with serde.
///
/// Field order is also the serialized field order.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    /// The user's GitHub login.
    pub github_login: String,
    /// The user's numeric GitHub user id.
    pub github_user_id: i32,
    /// Presumably the number of invites to grant the new user; confirm against
    /// the user-creation code.
    pub invite_count: i32,
}
1197
/// Information returned after creating a user (the producing code is outside
/// this chunk).
#[derive(Debug)]
pub struct NewUserResult {
    /// Id of the created user.
    pub user_id: UserId,
    /// Metrics identifier for the user, as a string. Presumably an opaque
    /// analytics id; confirm.
    pub metrics_id: String,
    /// Presumably the user whose invite led to this signup, if any; verify
    /// against callers.
    pub inviting_user_id: Option<UserId>,
    /// Presumably the device id recorded at signup, if any; verify against
    /// callers.
    pub signup_device_id: Option<String>,
}
1205
1206fn random_invite_code() -> String {
1207 nanoid::nanoid!(16)
1208}
1209
1210fn random_email_confirmation_code() -> String {
1211 nanoid::nanoid!(64)
1212}
1213
// Defines `pub struct $name(pub i32)`, a newtype for an integer database id.
//
// Besides the usual value-type derives, it gets transparent `sqlx::Type` and
// serde (de)serialization, plus the sea-query/sea-orm glue needed to use it as
// a column value and primary key:
// - `From<$name> for sea_query::Value` (as `Value::Int`),
// - `sea_orm::TryGetable` (read from the row as `i32`),
// - `sea_query::ValueType` (accepts any integer `Value` that fits in `i32`),
// - `sea_orm::TryFromU64` (fails if the `u64` does not fit in `i32`),
// - `sea_query::Nullable` (null is `Value::Int(None)`).
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        // `#[allow(unused)]` below silences dead-code warnings for id types
        // that never use a given helper.
        impl $name {
            /// The largest representable id (`i32::MAX`).
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from its `u64` proto representation.
            ///
            /// NOTE(review): this is an `as` cast, so values above `i32::MAX`
            /// are truncated to their low 32 bits instead of being rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts the id to its `u64` proto representation.
            ///
            /// NOTE(review): a negative id sign-extends to a very large `u64`.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }

        impl From<$name> for sea_query::Value {
            fn from(value: $name) -> Self {
                sea_query::Value::Int(Some(value.0))
            }
        }

        impl sea_orm::TryGetable for $name {
            fn try_get(
                res: &sea_orm::QueryResult,
                pre: &str,
                col: &str,
            ) -> Result<Self, sea_orm::TryGetError> {
                Ok(Self(i32::try_get(res, pre, col)?))
            }
        }

        impl sea_query::ValueType for $name {
            // Accepts any non-null integer variant whose value fits in `i32`.
            fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
                match v {
                    Value::TinyInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Int(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::TinyUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Unsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    _ => Err(sea_query::ValueTypeErr),
                }
            }

            fn type_name() -> String {
                stringify!($name).into()
            }

            fn array_type() -> sea_query::ArrayType {
                sea_query::ArrayType::Int
            }

            fn column_type() -> sea_query::ColumnType {
                sea_query::ColumnType::Integer(None)
            }
        }

        impl sea_orm::TryFromU64 for $name {
            // Fails when `n` does not fit in the underlying `i32`.
            fn try_from_u64(n: u64) -> Result<Self, DbErr> {
                Ok(Self(n.try_into().map_err(|_| {
                    DbErr::ConvertFromU64(concat!(
                        "error converting u64 to ",
                        stringify!($name)
                    ))
                })?))
            }
        }

        impl sea_query::Nullable for $name {
            fn null() -> Value {
                Value::Int(None)
            }
        }
    };
}
1334
1335id_type!(AccessTokenId);
1336id_type!(ContactId);
1337id_type!(UserId);
1338id_type!(RoomId);
1339id_type!(RoomParticipantId);
1340id_type!(ProjectId);
1341id_type!(ProjectCollaboratorId);
1342id_type!(SignupId);
1343id_type!(WorktreeId);
1344
1345#[cfg(test)]
1346pub use test::*;
1347
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sea_orm::ConnectionTrait;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A throwaway `Database` for tests, backed by SQLite or Postgres.
    ///
    /// On drop, a Postgres-backed instance terminates the other sessions on its
    /// database and then drops the database. Cleanup errors are logged, not
    /// propagated.
    pub struct TestDb {
        pub db: Option<Arc<Database>>,
        pub connection: Option<sqlx::AnyConnection>,
    }

    /// Builds the current-thread Tokio runtime (IO and timers enabled) that a
    /// test database runs its queries on.
    fn build_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    impl TestDb {
        /// Creates an in-memory SQLite database and loads
        /// `migrations.sqlite/20221109000000_test_schema.sql` into it.
        ///
        /// # Panics
        ///
        /// Panics if the runtime cannot be built, the connection fails, or the
        /// schema fails to apply.
        pub fn sqlite(background: Arc<Background>) -> Self {
            let url = "sqlite::memory:".to_string();
            let runtime = build_runtime();

            let db = runtime.block_on(async {
                let mut options = ConnectOptions::new(url);
                options.max_connections(5);
                let db = Database::new(options).await.unwrap();
                let sql = include_str!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/migrations.sqlite/20221109000000_test_schema.sql"
                ));
                db.pool
                    .execute(sea_orm::Statement::from_string(
                        db.pool.get_database_backend(),
                        sql.into(),
                    ))
                    .await
                    .unwrap();
                db
            });

            Self::from_database(db, background, runtime)
        }

        /// Creates a uniquely named Postgres database
        /// (`zed-test-<random u128>` on `postgres://postgres@localhost`) and runs
        /// the migrations in `$CARGO_MANIFEST_DIR/migrations` against it.
        ///
        /// Construction is serialized across the process by a global lock.
        ///
        /// # Panics
        ///
        /// Panics if the database cannot be created or connected to, or if the
        /// migrations fail.
        pub fn postgres(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Held until this function returns.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = build_runtime();

            let db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let mut options = ConnectOptions::new(url);
                options
                    .max_connections(5)
                    .idle_timeout(Duration::from_secs(0));
                let db = Database::new(options).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            Self::from_database(db, background, runtime)
        }

        /// Attaches the test executor and runtime to `db` (used by
        /// `Database::transact` in test builds) and wraps it in a `TestDb`.
        fn from_database(
            mut db: Database,
            background: Arc<Background>,
            runtime: tokio::runtime::Runtime,
        ) -> Self {
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                connection: None,
            }
        }

        /// Returns the wrapped database.
        ///
        /// # Panics
        ///
        /// Panics if called after the database has been taken (only `Drop` does
        /// this).
        pub fn db(&self) -> &Arc<Database> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        /// For Postgres, best-effort cleanup: terminate every other backend
        /// connected to the test database, then drop the database. Failures are
        /// logged via `log_err`. SQLite needs no cleanup here.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            if let DatabaseBackend::Postgres = db.pool.get_database_backend() {
                db.runtime.as_ref().unwrap().block_on(async {
                    use util::ResultExt;
                    let query = "
                        SELECT pg_terminate_backend(pg_stat_activity.pid)
                        FROM pg_stat_activity
                        WHERE
                            pg_stat_activity.datname = current_database() AND
                            pid <> pg_backend_pid();
                    ";
                    db.pool
                        .execute(sea_orm::Statement::from_string(
                            db.pool.get_database_backend(),
                            query.into(),
                        ))
                        .await
                        .log_err();
                    sqlx::Postgres::drop_database(db.options.get_url())
                        .await
                        .log_err();
                })
            }
        }
    }
}