1mod access_token;
2mod channel;
3mod channel_member;
4mod channel_path;
5mod contact;
6mod follower;
7mod language_server;
8mod project;
9mod project_collaborator;
10mod room;
11mod room_participant;
12mod server;
13mod signup;
14#[cfg(test)]
15mod tests;
16mod user;
17mod worktree;
18mod worktree_diagnostic_summary;
19mod worktree_entry;
20mod worktree_repository;
21mod worktree_repository_statuses;
22mod worktree_settings_file;
23
24use crate::executor::Executor;
25use crate::{Error, Result};
26use anyhow::anyhow;
27use collections::{BTreeMap, HashMap, HashSet};
28pub use contact::Contact;
29use dashmap::DashMap;
30use futures::StreamExt;
31use hyper::StatusCode;
32use rand::prelude::StdRng;
33use rand::{Rng, SeedableRng};
34use rpc::{proto, ConnectionId};
35use sea_orm::Condition;
36pub use sea_orm::ConnectOptions;
37use sea_orm::{
38 entity::prelude::*, ActiveValue, ConnectionTrait, DatabaseConnection, DatabaseTransaction,
39 DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect,
40 Statement, TransactionTrait,
41};
42use sea_query::{Alias, Expr, OnConflict, Query};
43use serde::{Deserialize, Serialize};
44pub use signup::{Invite, NewSignup, WaitlistSummary};
45use sqlx::migrate::{Migrate, Migration, MigrationSource};
46use sqlx::Connection;
47use std::fmt::Write as _;
48use std::ops::{Deref, DerefMut};
49use std::path::Path;
50use std::time::Duration;
51use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
52use tokio::sync::{Mutex, OwnedMutexGuard};
53pub use user::Model as User;
54
/// Handle to the collaboration server's database.
///
/// Bundles the sea-orm connection with the extra state the methods on this type use.
pub struct Database {
    /// Options the connection was opened with; `migrate` reads the URL from here to open
    /// its own raw `sqlx` connection.
    options: ConnectOptions,
    /// sea-orm connection established in `new` from `options`.
    pool: DatabaseConnection,
    /// One lock per room id. Presumably taken by `room_transaction` to serialize work on
    /// a single room — TODO confirm against its definition. Cleared by `reset` in tests.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// Random number generator; `new` seeds it deterministically with 0.
    rng: Mutex<StdRng>,
    /// Executor supplied by the caller of `new`.
    executor: Executor,
    /// Test-only tokio runtime; `new` always initializes it to `None`.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
64
65impl Database {
66 pub async fn new(options: ConnectOptions, executor: Executor) -> Result<Self> {
67 Ok(Self {
68 options: options.clone(),
69 pool: sea_orm::Database::connect(options).await?,
70 rooms: DashMap::with_capacity(16384),
71 rng: Mutex::new(StdRng::seed_from_u64(0)),
72 executor,
73 #[cfg(test)]
74 runtime: None,
75 })
76 }
77
78 #[cfg(test)]
79 pub fn reset(&self) {
80 self.rooms.clear();
81 }
82
83 pub async fn migrate(
84 &self,
85 migrations_path: &Path,
86 ignore_checksum_mismatch: bool,
87 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
88 let migrations = MigrationSource::resolve(migrations_path)
89 .await
90 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
91
92 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
93
94 connection.ensure_migrations_table().await?;
95 let applied_migrations: HashMap<_, _> = connection
96 .list_applied_migrations()
97 .await?
98 .into_iter()
99 .map(|m| (m.version, m))
100 .collect();
101
102 let mut new_migrations = Vec::new();
103 for migration in migrations {
104 match applied_migrations.get(&migration.version) {
105 Some(applied_migration) => {
106 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
107 {
108 Err(anyhow!(
109 "checksum mismatch for applied migration {}",
110 migration.description
111 ))?;
112 }
113 }
114 None => {
115 let elapsed = connection.apply(&migration).await?;
116 new_migrations.push((migration, elapsed));
117 }
118 }
119 }
120
121 Ok(new_migrations)
122 }
123
124 pub async fn create_server(&self, environment: &str) -> Result<ServerId> {
125 self.transaction(|tx| async move {
126 let server = server::ActiveModel {
127 environment: ActiveValue::set(environment.into()),
128 ..Default::default()
129 }
130 .insert(&*tx)
131 .await?;
132 Ok(server.id)
133 })
134 .await
135 }
136
137 pub async fn stale_room_ids(
138 &self,
139 environment: &str,
140 new_server_id: ServerId,
141 ) -> Result<Vec<RoomId>> {
142 self.transaction(|tx| async move {
143 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
144 enum QueryAs {
145 RoomId,
146 }
147
148 let stale_server_epochs = self
149 .stale_server_ids(environment, new_server_id, &tx)
150 .await?;
151 Ok(room_participant::Entity::find()
152 .select_only()
153 .column(room_participant::Column::RoomId)
154 .distinct()
155 .filter(
156 room_participant::Column::AnsweringConnectionServerId
157 .is_in(stale_server_epochs),
158 )
159 .into_values::<_, QueryAs>()
160 .all(&*tx)
161 .await?)
162 })
163 .await
164 }
165
166 pub async fn refresh_room(
167 &self,
168 room_id: RoomId,
169 new_server_id: ServerId,
170 ) -> Result<RoomGuard<RefreshedRoom>> {
171 self.room_transaction(room_id, |tx| async move {
172 let stale_participant_filter = Condition::all()
173 .add(room_participant::Column::RoomId.eq(room_id))
174 .add(room_participant::Column::AnsweringConnectionId.is_not_null())
175 .add(room_participant::Column::AnsweringConnectionServerId.ne(new_server_id));
176
177 let stale_participant_user_ids = room_participant::Entity::find()
178 .filter(stale_participant_filter.clone())
179 .all(&*tx)
180 .await?
181 .into_iter()
182 .map(|participant| participant.user_id)
183 .collect::<Vec<_>>();
184
185 // Delete participants who failed to reconnect and cancel their calls.
186 let mut canceled_calls_to_user_ids = Vec::new();
187 room_participant::Entity::delete_many()
188 .filter(stale_participant_filter)
189 .exec(&*tx)
190 .await?;
191 let called_participants = room_participant::Entity::find()
192 .filter(
193 Condition::all()
194 .add(
195 room_participant::Column::CallingUserId
196 .is_in(stale_participant_user_ids.iter().copied()),
197 )
198 .add(room_participant::Column::AnsweringConnectionId.is_null()),
199 )
200 .all(&*tx)
201 .await?;
202 room_participant::Entity::delete_many()
203 .filter(
204 room_participant::Column::Id
205 .is_in(called_participants.iter().map(|participant| participant.id)),
206 )
207 .exec(&*tx)
208 .await?;
209 canceled_calls_to_user_ids.extend(
210 called_participants
211 .into_iter()
212 .map(|participant| participant.user_id),
213 );
214
215 let (channel_id, room) = self.get_channel_room(room_id, &tx).await?;
216 let channel_members;
217 if let Some(channel_id) = channel_id {
218 channel_members = self.get_channel_members_internal(channel_id, &tx).await?;
219 } else {
220 channel_members = Vec::new();
221
222 // Delete the room if it becomes empty.
223 if room.participants.is_empty() {
224 project::Entity::delete_many()
225 .filter(project::Column::RoomId.eq(room_id))
226 .exec(&*tx)
227 .await?;
228 room::Entity::delete_by_id(room_id).exec(&*tx).await?;
229 }
230 };
231
232 Ok(RefreshedRoom {
233 room,
234 channel_id,
235 channel_members,
236 stale_participant_user_ids,
237 canceled_calls_to_user_ids,
238 })
239 })
240 .await
241 }
242
243 pub async fn delete_stale_servers(
244 &self,
245 environment: &str,
246 new_server_id: ServerId,
247 ) -> Result<()> {
248 self.transaction(|tx| async move {
249 server::Entity::delete_many()
250 .filter(
251 Condition::all()
252 .add(server::Column::Environment.eq(environment))
253 .add(server::Column::Id.ne(new_server_id)),
254 )
255 .exec(&*tx)
256 .await?;
257 Ok(())
258 })
259 .await
260 }
261
262 async fn stale_server_ids(
263 &self,
264 environment: &str,
265 new_server_id: ServerId,
266 tx: &DatabaseTransaction,
267 ) -> Result<Vec<ServerId>> {
268 let stale_servers = server::Entity::find()
269 .filter(
270 Condition::all()
271 .add(server::Column::Environment.eq(environment))
272 .add(server::Column::Id.ne(new_server_id)),
273 )
274 .all(&*tx)
275 .await?;
276 Ok(stale_servers.into_iter().map(|server| server.id).collect())
277 }
278
279 // users
280
281 pub async fn create_user(
282 &self,
283 email_address: &str,
284 admin: bool,
285 params: NewUserParams,
286 ) -> Result<NewUserResult> {
287 self.transaction(|tx| async {
288 let tx = tx;
289 let user = user::Entity::insert(user::ActiveModel {
290 email_address: ActiveValue::set(Some(email_address.into())),
291 github_login: ActiveValue::set(params.github_login.clone()),
292 github_user_id: ActiveValue::set(Some(params.github_user_id)),
293 admin: ActiveValue::set(admin),
294 metrics_id: ActiveValue::set(Uuid::new_v4()),
295 ..Default::default()
296 })
297 .on_conflict(
298 OnConflict::column(user::Column::GithubLogin)
299 .update_column(user::Column::GithubLogin)
300 .to_owned(),
301 )
302 .exec_with_returning(&*tx)
303 .await?;
304
305 Ok(NewUserResult {
306 user_id: user.id,
307 metrics_id: user.metrics_id.to_string(),
308 signup_device_id: None,
309 inviting_user_id: None,
310 })
311 })
312 .await
313 }
314
315 pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<user::Model>> {
316 self.transaction(|tx| async move { Ok(user::Entity::find_by_id(id).one(&*tx).await?) })
317 .await
318 }
319
320 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
321 self.transaction(|tx| async {
322 let tx = tx;
323 Ok(user::Entity::find()
324 .filter(user::Column::Id.is_in(ids.iter().copied()))
325 .all(&*tx)
326 .await?)
327 })
328 .await
329 }
330
331 pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
332 self.transaction(|tx| async move {
333 Ok(user::Entity::find()
334 .filter(user::Column::GithubLogin.eq(github_login))
335 .one(&*tx)
336 .await?)
337 })
338 .await
339 }
340
341 pub async fn get_or_create_user_by_github_account(
342 &self,
343 github_login: &str,
344 github_user_id: Option<i32>,
345 github_email: Option<&str>,
346 ) -> Result<Option<User>> {
347 self.transaction(|tx| async move {
348 let tx = &*tx;
349 if let Some(github_user_id) = github_user_id {
350 if let Some(user_by_github_user_id) = user::Entity::find()
351 .filter(user::Column::GithubUserId.eq(github_user_id))
352 .one(tx)
353 .await?
354 {
355 let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
356 user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
357 Ok(Some(user_by_github_user_id.update(tx).await?))
358 } else if let Some(user_by_github_login) = user::Entity::find()
359 .filter(user::Column::GithubLogin.eq(github_login))
360 .one(tx)
361 .await?
362 {
363 let mut user_by_github_login = user_by_github_login.into_active_model();
364 user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
365 Ok(Some(user_by_github_login.update(tx).await?))
366 } else {
367 let user = user::Entity::insert(user::ActiveModel {
368 email_address: ActiveValue::set(github_email.map(|email| email.into())),
369 github_login: ActiveValue::set(github_login.into()),
370 github_user_id: ActiveValue::set(Some(github_user_id)),
371 admin: ActiveValue::set(false),
372 invite_count: ActiveValue::set(0),
373 invite_code: ActiveValue::set(None),
374 metrics_id: ActiveValue::set(Uuid::new_v4()),
375 ..Default::default()
376 })
377 .exec_with_returning(&*tx)
378 .await?;
379 Ok(Some(user))
380 }
381 } else {
382 Ok(user::Entity::find()
383 .filter(user::Column::GithubLogin.eq(github_login))
384 .one(tx)
385 .await?)
386 }
387 })
388 .await
389 }
390
391 pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
392 self.transaction(|tx| async move {
393 Ok(user::Entity::find()
394 .order_by_asc(user::Column::GithubLogin)
395 .limit(limit as u64)
396 .offset(page as u64 * limit as u64)
397 .all(&*tx)
398 .await?)
399 })
400 .await
401 }
402
403 pub async fn get_users_with_no_invites(
404 &self,
405 invited_by_another_user: bool,
406 ) -> Result<Vec<User>> {
407 self.transaction(|tx| async move {
408 Ok(user::Entity::find()
409 .filter(
410 user::Column::InviteCount
411 .eq(0)
412 .and(if invited_by_another_user {
413 user::Column::InviterId.is_not_null()
414 } else {
415 user::Column::InviterId.is_null()
416 }),
417 )
418 .all(&*tx)
419 .await?)
420 })
421 .await
422 }
423
424 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
425 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
426 enum QueryAs {
427 MetricsId,
428 }
429
430 self.transaction(|tx| async move {
431 let metrics_id: Uuid = user::Entity::find_by_id(id)
432 .select_only()
433 .column(user::Column::MetricsId)
434 .into_values::<_, QueryAs>()
435 .one(&*tx)
436 .await?
437 .ok_or_else(|| anyhow!("could not find user"))?;
438 Ok(metrics_id.to_string())
439 })
440 .await
441 }
442
443 pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
444 self.transaction(|tx| async move {
445 user::Entity::update_many()
446 .filter(user::Column::Id.eq(id))
447 .set(user::ActiveModel {
448 admin: ActiveValue::set(is_admin),
449 ..Default::default()
450 })
451 .exec(&*tx)
452 .await?;
453 Ok(())
454 })
455 .await
456 }
457
458 pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
459 self.transaction(|tx| async move {
460 user::Entity::update_many()
461 .filter(user::Column::Id.eq(id))
462 .set(user::ActiveModel {
463 connected_once: ActiveValue::set(connected_once),
464 ..Default::default()
465 })
466 .exec(&*tx)
467 .await?;
468 Ok(())
469 })
470 .await
471 }
472
473 pub async fn destroy_user(&self, id: UserId) -> Result<()> {
474 self.transaction(|tx| async move {
475 access_token::Entity::delete_many()
476 .filter(access_token::Column::UserId.eq(id))
477 .exec(&*tx)
478 .await?;
479 user::Entity::delete_by_id(id).exec(&*tx).await?;
480 Ok(())
481 })
482 .await
483 }
484
485 // contacts
486
487 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
488 #[derive(Debug, FromQueryResult)]
489 struct ContactWithUserBusyStatuses {
490 user_id_a: UserId,
491 user_id_b: UserId,
492 a_to_b: bool,
493 accepted: bool,
494 should_notify: bool,
495 user_a_busy: bool,
496 user_b_busy: bool,
497 }
498
499 self.transaction(|tx| async move {
500 let user_a_participant = Alias::new("user_a_participant");
501 let user_b_participant = Alias::new("user_b_participant");
502 let mut db_contacts = contact::Entity::find()
503 .column_as(
504 Expr::tbl(user_a_participant.clone(), room_participant::Column::Id)
505 .is_not_null(),
506 "user_a_busy",
507 )
508 .column_as(
509 Expr::tbl(user_b_participant.clone(), room_participant::Column::Id)
510 .is_not_null(),
511 "user_b_busy",
512 )
513 .filter(
514 contact::Column::UserIdA
515 .eq(user_id)
516 .or(contact::Column::UserIdB.eq(user_id)),
517 )
518 .join_as(
519 JoinType::LeftJoin,
520 contact::Relation::UserARoomParticipant.def(),
521 user_a_participant,
522 )
523 .join_as(
524 JoinType::LeftJoin,
525 contact::Relation::UserBRoomParticipant.def(),
526 user_b_participant,
527 )
528 .into_model::<ContactWithUserBusyStatuses>()
529 .stream(&*tx)
530 .await?;
531
532 let mut contacts = Vec::new();
533 while let Some(db_contact) = db_contacts.next().await {
534 let db_contact = db_contact?;
535 if db_contact.user_id_a == user_id {
536 if db_contact.accepted {
537 contacts.push(Contact::Accepted {
538 user_id: db_contact.user_id_b,
539 should_notify: db_contact.should_notify && db_contact.a_to_b,
540 busy: db_contact.user_b_busy,
541 });
542 } else if db_contact.a_to_b {
543 contacts.push(Contact::Outgoing {
544 user_id: db_contact.user_id_b,
545 })
546 } else {
547 contacts.push(Contact::Incoming {
548 user_id: db_contact.user_id_b,
549 should_notify: db_contact.should_notify,
550 });
551 }
552 } else if db_contact.accepted {
553 contacts.push(Contact::Accepted {
554 user_id: db_contact.user_id_a,
555 should_notify: db_contact.should_notify && !db_contact.a_to_b,
556 busy: db_contact.user_a_busy,
557 });
558 } else if db_contact.a_to_b {
559 contacts.push(Contact::Incoming {
560 user_id: db_contact.user_id_a,
561 should_notify: db_contact.should_notify,
562 });
563 } else {
564 contacts.push(Contact::Outgoing {
565 user_id: db_contact.user_id_a,
566 });
567 }
568 }
569
570 contacts.sort_unstable_by_key(|contact| contact.user_id());
571
572 Ok(contacts)
573 })
574 .await
575 }
576
577 pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
578 self.transaction(|tx| async move {
579 let participant = room_participant::Entity::find()
580 .filter(room_participant::Column::UserId.eq(user_id))
581 .one(&*tx)
582 .await?;
583 Ok(participant.is_some())
584 })
585 .await
586 }
587
588 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
589 self.transaction(|tx| async move {
590 let (id_a, id_b) = if user_id_1 < user_id_2 {
591 (user_id_1, user_id_2)
592 } else {
593 (user_id_2, user_id_1)
594 };
595
596 Ok(contact::Entity::find()
597 .filter(
598 contact::Column::UserIdA
599 .eq(id_a)
600 .and(contact::Column::UserIdB.eq(id_b))
601 .and(contact::Column::Accepted.eq(true)),
602 )
603 .one(&*tx)
604 .await?
605 .is_some())
606 })
607 .await
608 }
609
    /// Records a contact request from `sender_id` to `receiver_id`.
    ///
    /// Contact rows store the smaller user id in `user_id_a`. `a_to_b` records whether the
    /// request went from `user_id_a` to `user_id_b`.
    ///
    /// If a row for the pair already exists, the upsert marks it accepted and clears
    /// `should_notify`. It only does so when the existing row is still unaccepted and
    /// matches the `action_and_where` predicate below (in practice: the stored direction
    /// is the opposite of this request, i.e. the receiver had already asked the sender).
    ///
    /// # Errors
    ///
    /// Fails with "contact already requested" unless exactly one row was affected.
    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
        self.transaction(|tx| async move {
            // Normalize so that `id_a` is the smaller id; `a_to_b` is true when the sender
            // is `id_a`.
            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
                (sender_id, receiver_id, true)
            } else {
                (receiver_id, sender_id, false)
            };

            let rows_affected = contact::Entity::insert(contact::ActiveModel {
                user_id_a: ActiveValue::set(id_a),
                user_id_b: ActiveValue::set(id_b),
                a_to_b: ActiveValue::set(a_to_b),
                accepted: ActiveValue::set(false),
                should_notify: ActiveValue::set(true),
                ..Default::default()
            })
            .on_conflict(
                OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB])
                    // On conflict, turn the existing request into an accepted contact.
                    .values([
                        (contact::Column::Accepted, true.into()),
                        (contact::Column::ShouldNotify, false.into()),
                    ])
                    // NOTE(review): a conflicting row always has `user_id_a == id_a`, so
                    // the first disjunct (`user_id_a == id_b`) can only hold when both ids
                    // are equal; effectively the update fires when the stored `a_to_b`
                    // differs from this request's direction.
                    .action_and_where(
                        contact::Column::Accepted.eq(false).and(
                            contact::Column::AToB
                                .eq(a_to_b)
                                .and(contact::Column::UserIdA.eq(id_b))
                                .or(contact::Column::AToB
                                    .ne(a_to_b)
                                    .and(contact::Column::UserIdA.eq(id_a))),
                        ),
                    )
                    .to_owned(),
            )
            .exec_without_returning(&*tx)
            .await?;

            // A fresh insert or an accepting update each affect one row; a conflict that
            // fails the predicate presumably reports zero — confirm for each backend.
            if rows_affected == 1 {
                Ok(())
            } else {
                Err(anyhow!("contact already requested"))?
            }
        })
        .await
    }
655
656 /// Returns a bool indicating whether the removed contact had originally accepted or not
657 ///
658 /// Deletes the contact identified by the requester and responder ids, and then returns
659 /// whether the deleted contact had originally accepted or was a pending contact request.
660 ///
661 /// # Arguments
662 ///
663 /// * `requester_id` - The user that initiates this request
664 /// * `responder_id` - The user that will be removed
665 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<bool> {
666 self.transaction(|tx| async move {
667 let (id_a, id_b) = if responder_id < requester_id {
668 (responder_id, requester_id)
669 } else {
670 (requester_id, responder_id)
671 };
672
673 let contact = contact::Entity::find()
674 .filter(
675 contact::Column::UserIdA
676 .eq(id_a)
677 .and(contact::Column::UserIdB.eq(id_b)),
678 )
679 .one(&*tx)
680 .await?
681 .ok_or_else(|| anyhow!("no such contact"))?;
682
683 contact::Entity::delete_by_id(contact.id).exec(&*tx).await?;
684 Ok(contact.accepted)
685 })
686 .await
687 }
688
689 pub async fn dismiss_contact_notification(
690 &self,
691 user_id: UserId,
692 contact_user_id: UserId,
693 ) -> Result<()> {
694 self.transaction(|tx| async move {
695 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
696 (user_id, contact_user_id, true)
697 } else {
698 (contact_user_id, user_id, false)
699 };
700
701 let result = contact::Entity::update_many()
702 .set(contact::ActiveModel {
703 should_notify: ActiveValue::set(false),
704 ..Default::default()
705 })
706 .filter(
707 contact::Column::UserIdA
708 .eq(id_a)
709 .and(contact::Column::UserIdB.eq(id_b))
710 .and(
711 contact::Column::AToB
712 .eq(a_to_b)
713 .and(contact::Column::Accepted.eq(true))
714 .or(contact::Column::AToB
715 .ne(a_to_b)
716 .and(contact::Column::Accepted.eq(false))),
717 ),
718 )
719 .exec(&*tx)
720 .await?;
721 if result.rows_affected == 0 {
722 Err(anyhow!("no such contact request"))?
723 } else {
724 Ok(())
725 }
726 })
727 .await
728 }
729
730 pub async fn respond_to_contact_request(
731 &self,
732 responder_id: UserId,
733 requester_id: UserId,
734 accept: bool,
735 ) -> Result<()> {
736 self.transaction(|tx| async move {
737 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
738 (responder_id, requester_id, false)
739 } else {
740 (requester_id, responder_id, true)
741 };
742 let rows_affected = if accept {
743 let result = contact::Entity::update_many()
744 .set(contact::ActiveModel {
745 accepted: ActiveValue::set(true),
746 should_notify: ActiveValue::set(true),
747 ..Default::default()
748 })
749 .filter(
750 contact::Column::UserIdA
751 .eq(id_a)
752 .and(contact::Column::UserIdB.eq(id_b))
753 .and(contact::Column::AToB.eq(a_to_b)),
754 )
755 .exec(&*tx)
756 .await?;
757 result.rows_affected
758 } else {
759 let result = contact::Entity::delete_many()
760 .filter(
761 contact::Column::UserIdA
762 .eq(id_a)
763 .and(contact::Column::UserIdB.eq(id_b))
764 .and(contact::Column::AToB.eq(a_to_b))
765 .and(contact::Column::Accepted.eq(false)),
766 )
767 .exec(&*tx)
768 .await?;
769
770 result.rows_affected
771 };
772
773 if rows_affected == 1 {
774 Ok(())
775 } else {
776 Err(anyhow!("no such contact request"))?
777 }
778 })
779 .await
780 }
781
782 pub fn fuzzy_like_string(string: &str) -> String {
783 let mut result = String::with_capacity(string.len() * 2 + 1);
784 for c in string.chars() {
785 if c.is_alphanumeric() {
786 result.push('%');
787 result.push(c);
788 }
789 }
790 result.push('%');
791 result
792 }
793
794 pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
795 self.transaction(|tx| async {
796 let tx = tx;
797 let like_string = Self::fuzzy_like_string(name_query);
798 let query = "
799 SELECT users.*
800 FROM users
801 WHERE github_login ILIKE $1
802 ORDER BY github_login <-> $2
803 LIMIT $3
804 ";
805
806 Ok(user::Entity::find()
807 .from_raw_sql(Statement::from_sql_and_values(
808 self.pool.get_database_backend(),
809 query.into(),
810 vec![like_string.into(), name_query.into(), limit.into()],
811 ))
812 .all(&*tx)
813 .await?)
814 })
815 .await
816 }
817
818 // signups
819
820 pub async fn create_signup(&self, signup: &NewSignup) -> Result<()> {
821 self.transaction(|tx| async move {
822 signup::Entity::insert(signup::ActiveModel {
823 email_address: ActiveValue::set(signup.email_address.clone()),
824 email_confirmation_code: ActiveValue::set(random_email_confirmation_code()),
825 email_confirmation_sent: ActiveValue::set(false),
826 platform_mac: ActiveValue::set(signup.platform_mac),
827 platform_windows: ActiveValue::set(signup.platform_windows),
828 platform_linux: ActiveValue::set(signup.platform_linux),
829 platform_unknown: ActiveValue::set(false),
830 editor_features: ActiveValue::set(Some(signup.editor_features.clone())),
831 programming_languages: ActiveValue::set(Some(signup.programming_languages.clone())),
832 device_id: ActiveValue::set(signup.device_id.clone()),
833 added_to_mailing_list: ActiveValue::set(signup.added_to_mailing_list),
834 ..Default::default()
835 })
836 .on_conflict(
837 OnConflict::column(signup::Column::EmailAddress)
838 .update_columns([
839 signup::Column::PlatformMac,
840 signup::Column::PlatformWindows,
841 signup::Column::PlatformLinux,
842 signup::Column::EditorFeatures,
843 signup::Column::ProgrammingLanguages,
844 signup::Column::DeviceId,
845 signup::Column::AddedToMailingList,
846 ])
847 .to_owned(),
848 )
849 .exec(&*tx)
850 .await?;
851 Ok(())
852 })
853 .await
854 }
855
856 pub async fn get_signup(&self, email_address: &str) -> Result<signup::Model> {
857 self.transaction(|tx| async move {
858 let signup = signup::Entity::find()
859 .filter(signup::Column::EmailAddress.eq(email_address))
860 .one(&*tx)
861 .await?
862 .ok_or_else(|| {
863 anyhow!("signup with email address {} doesn't exist", email_address)
864 })?;
865
866 Ok(signup)
867 })
868 .await
869 }
870
871 pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
872 self.transaction(|tx| async move {
873 let query = "
874 SELECT
875 COUNT(*) as count,
876 COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
877 COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
878 COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
879 COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
880 FROM (
881 SELECT *
882 FROM signups
883 WHERE
884 NOT email_confirmation_sent
885 ) AS unsent
886 ";
887 Ok(
888 WaitlistSummary::find_by_statement(Statement::from_sql_and_values(
889 self.pool.get_database_backend(),
890 query.into(),
891 vec![],
892 ))
893 .one(&*tx)
894 .await?
895 .ok_or_else(|| anyhow!("invalid result"))?,
896 )
897 })
898 .await
899 }
900
901 pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
902 let emails = invites
903 .iter()
904 .map(|s| s.email_address.as_str())
905 .collect::<Vec<_>>();
906 self.transaction(|tx| async {
907 let tx = tx;
908 signup::Entity::update_many()
909 .filter(signup::Column::EmailAddress.is_in(emails.iter().copied()))
910 .set(signup::ActiveModel {
911 email_confirmation_sent: ActiveValue::set(true),
912 ..Default::default()
913 })
914 .exec(&*tx)
915 .await?;
916 Ok(())
917 })
918 .await
919 }
920
921 pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
922 self.transaction(|tx| async move {
923 Ok(signup::Entity::find()
924 .select_only()
925 .column(signup::Column::EmailAddress)
926 .column(signup::Column::EmailConfirmationCode)
927 .filter(
928 signup::Column::EmailConfirmationSent.eq(false).and(
929 signup::Column::PlatformMac
930 .eq(true)
931 .or(signup::Column::PlatformUnknown.eq(true)),
932 ),
933 )
934 .order_by_asc(signup::Column::CreatedAt)
935 .limit(count as u64)
936 .into_model()
937 .all(&*tx)
938 .await?)
939 })
940 .await
941 }
942
943 // invite codes
944
945 pub async fn create_invite_from_code(
946 &self,
947 code: &str,
948 email_address: &str,
949 device_id: Option<&str>,
950 added_to_mailing_list: bool,
951 ) -> Result<Invite> {
952 self.transaction(|tx| async move {
953 let existing_user = user::Entity::find()
954 .filter(user::Column::EmailAddress.eq(email_address))
955 .one(&*tx)
956 .await?;
957
958 if existing_user.is_some() {
959 Err(anyhow!("email address is already in use"))?;
960 }
961
962 let inviting_user_with_invites = match user::Entity::find()
963 .filter(
964 user::Column::InviteCode
965 .eq(code)
966 .and(user::Column::InviteCount.gt(0)),
967 )
968 .one(&*tx)
969 .await?
970 {
971 Some(inviting_user) => inviting_user,
972 None => {
973 return Err(Error::Http(
974 StatusCode::UNAUTHORIZED,
975 "unable to find an invite code with invites remaining".to_string(),
976 ))?
977 }
978 };
979 user::Entity::update_many()
980 .filter(
981 user::Column::Id
982 .eq(inviting_user_with_invites.id)
983 .and(user::Column::InviteCount.gt(0)),
984 )
985 .col_expr(
986 user::Column::InviteCount,
987 Expr::col(user::Column::InviteCount).sub(1),
988 )
989 .exec(&*tx)
990 .await?;
991
992 let signup = signup::Entity::insert(signup::ActiveModel {
993 email_address: ActiveValue::set(email_address.into()),
994 email_confirmation_code: ActiveValue::set(random_email_confirmation_code()),
995 email_confirmation_sent: ActiveValue::set(false),
996 inviting_user_id: ActiveValue::set(Some(inviting_user_with_invites.id)),
997 platform_linux: ActiveValue::set(false),
998 platform_mac: ActiveValue::set(false),
999 platform_windows: ActiveValue::set(false),
1000 platform_unknown: ActiveValue::set(true),
1001 device_id: ActiveValue::set(device_id.map(|device_id| device_id.into())),
1002 added_to_mailing_list: ActiveValue::set(added_to_mailing_list),
1003 ..Default::default()
1004 })
1005 .on_conflict(
1006 OnConflict::column(signup::Column::EmailAddress)
1007 .update_column(signup::Column::InvitingUserId)
1008 .to_owned(),
1009 )
1010 .exec_with_returning(&*tx)
1011 .await?;
1012
1013 Ok(Invite {
1014 email_address: signup.email_address,
1015 email_confirmation_code: signup.email_confirmation_code,
1016 })
1017 })
1018 .await
1019 }
1020
1021 pub async fn create_user_from_invite(
1022 &self,
1023 invite: &Invite,
1024 user: NewUserParams,
1025 ) -> Result<Option<NewUserResult>> {
1026 self.transaction(|tx| async {
1027 let tx = tx;
1028 let signup = signup::Entity::find()
1029 .filter(
1030 signup::Column::EmailAddress
1031 .eq(invite.email_address.as_str())
1032 .and(
1033 signup::Column::EmailConfirmationCode
1034 .eq(invite.email_confirmation_code.as_str()),
1035 ),
1036 )
1037 .one(&*tx)
1038 .await?
1039 .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
1040
1041 if signup.user_id.is_some() {
1042 return Ok(None);
1043 }
1044
1045 let user = user::Entity::insert(user::ActiveModel {
1046 email_address: ActiveValue::set(Some(invite.email_address.clone())),
1047 github_login: ActiveValue::set(user.github_login.clone()),
1048 github_user_id: ActiveValue::set(Some(user.github_user_id)),
1049 admin: ActiveValue::set(false),
1050 invite_count: ActiveValue::set(user.invite_count),
1051 invite_code: ActiveValue::set(Some(random_invite_code())),
1052 metrics_id: ActiveValue::set(Uuid::new_v4()),
1053 ..Default::default()
1054 })
1055 .on_conflict(
1056 OnConflict::column(user::Column::GithubLogin)
1057 .update_columns([
1058 user::Column::EmailAddress,
1059 user::Column::GithubUserId,
1060 user::Column::Admin,
1061 ])
1062 .to_owned(),
1063 )
1064 .exec_with_returning(&*tx)
1065 .await?;
1066
1067 let mut signup = signup.into_active_model();
1068 signup.user_id = ActiveValue::set(Some(user.id));
1069 let signup = signup.update(&*tx).await?;
1070
1071 if let Some(inviting_user_id) = signup.inviting_user_id {
1072 let (user_id_a, user_id_b, a_to_b) = if inviting_user_id < user.id {
1073 (inviting_user_id, user.id, true)
1074 } else {
1075 (user.id, inviting_user_id, false)
1076 };
1077
1078 contact::Entity::insert(contact::ActiveModel {
1079 user_id_a: ActiveValue::set(user_id_a),
1080 user_id_b: ActiveValue::set(user_id_b),
1081 a_to_b: ActiveValue::set(a_to_b),
1082 should_notify: ActiveValue::set(true),
1083 accepted: ActiveValue::set(true),
1084 ..Default::default()
1085 })
1086 .on_conflict(OnConflict::new().do_nothing().to_owned())
1087 .exec_without_returning(&*tx)
1088 .await?;
1089 }
1090
1091 Ok(Some(NewUserResult {
1092 user_id: user.id,
1093 metrics_id: user.metrics_id.to_string(),
1094 inviting_user_id: signup.inviting_user_id,
1095 signup_device_id: signup.device_id,
1096 }))
1097 })
1098 .await
1099 }
1100
1101 pub async fn set_invite_count_for_user(&self, id: UserId, count: i32) -> Result<()> {
1102 self.transaction(|tx| async move {
1103 if count > 0 {
1104 user::Entity::update_many()
1105 .filter(
1106 user::Column::Id
1107 .eq(id)
1108 .and(user::Column::InviteCode.is_null()),
1109 )
1110 .set(user::ActiveModel {
1111 invite_code: ActiveValue::set(Some(random_invite_code())),
1112 ..Default::default()
1113 })
1114 .exec(&*tx)
1115 .await?;
1116 }
1117
1118 user::Entity::update_many()
1119 .filter(user::Column::Id.eq(id))
1120 .set(user::ActiveModel {
1121 invite_count: ActiveValue::set(count),
1122 ..Default::default()
1123 })
1124 .exec(&*tx)
1125 .await?;
1126 Ok(())
1127 })
1128 .await
1129 }
1130
1131 pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, i32)>> {
1132 self.transaction(|tx| async move {
1133 match user::Entity::find_by_id(id).one(&*tx).await? {
1134 Some(user) if user.invite_code.is_some() => {
1135 Ok(Some((user.invite_code.unwrap(), user.invite_count)))
1136 }
1137 _ => Ok(None),
1138 }
1139 })
1140 .await
1141 }
1142
1143 pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
1144 self.transaction(|tx| async move {
1145 user::Entity::find()
1146 .filter(user::Column::InviteCode.eq(code))
1147 .one(&*tx)
1148 .await?
1149 .ok_or_else(|| {
1150 Error::Http(
1151 StatusCode::NOT_FOUND,
1152 "that invite code does not exist".to_string(),
1153 )
1154 })
1155 })
1156 .await
1157 }
1158
1159 // rooms
1160
1161 pub async fn incoming_call_for_user(
1162 &self,
1163 user_id: UserId,
1164 ) -> Result<Option<proto::IncomingCall>> {
1165 self.transaction(|tx| async move {
1166 let pending_participant = room_participant::Entity::find()
1167 .filter(
1168 room_participant::Column::UserId
1169 .eq(user_id)
1170 .and(room_participant::Column::AnsweringConnectionId.is_null()),
1171 )
1172 .one(&*tx)
1173 .await?;
1174
1175 if let Some(pending_participant) = pending_participant {
1176 let room = self.get_room(pending_participant.room_id, &tx).await?;
1177 Ok(Self::build_incoming_call(&room, user_id))
1178 } else {
1179 Ok(None)
1180 }
1181 })
1182 .await
1183 }
1184
1185 pub async fn create_room(
1186 &self,
1187 user_id: UserId,
1188 connection: ConnectionId,
1189 live_kit_room: &str,
1190 ) -> Result<proto::Room> {
1191 self.transaction(|tx| async move {
1192 let room = room::ActiveModel {
1193 live_kit_room: ActiveValue::set(live_kit_room.into()),
1194 ..Default::default()
1195 }
1196 .insert(&*tx)
1197 .await?;
1198 room_participant::ActiveModel {
1199 room_id: ActiveValue::set(room.id),
1200 user_id: ActiveValue::set(user_id),
1201 answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
1202 answering_connection_server_id: ActiveValue::set(Some(ServerId(
1203 connection.owner_id as i32,
1204 ))),
1205 answering_connection_lost: ActiveValue::set(false),
1206 calling_user_id: ActiveValue::set(user_id),
1207 calling_connection_id: ActiveValue::set(connection.id as i32),
1208 calling_connection_server_id: ActiveValue::set(Some(ServerId(
1209 connection.owner_id as i32,
1210 ))),
1211 ..Default::default()
1212 }
1213 .insert(&*tx)
1214 .await?;
1215
1216 let room = self.get_room(room.id, &tx).await?;
1217 Ok(room)
1218 })
1219 .await
1220 }
1221
1222 pub async fn call(
1223 &self,
1224 room_id: RoomId,
1225 calling_user_id: UserId,
1226 calling_connection: ConnectionId,
1227 called_user_id: UserId,
1228 initial_project_id: Option<ProjectId>,
1229 ) -> Result<RoomGuard<(proto::Room, proto::IncomingCall)>> {
1230 self.room_transaction(room_id, |tx| async move {
1231 room_participant::ActiveModel {
1232 room_id: ActiveValue::set(room_id),
1233 user_id: ActiveValue::set(called_user_id),
1234 answering_connection_lost: ActiveValue::set(false),
1235 calling_user_id: ActiveValue::set(calling_user_id),
1236 calling_connection_id: ActiveValue::set(calling_connection.id as i32),
1237 calling_connection_server_id: ActiveValue::set(Some(ServerId(
1238 calling_connection.owner_id as i32,
1239 ))),
1240 initial_project_id: ActiveValue::set(initial_project_id),
1241 ..Default::default()
1242 }
1243 .insert(&*tx)
1244 .await?;
1245
1246 let room = self.get_room(room_id, &tx).await?;
1247 let incoming_call = Self::build_incoming_call(&room, called_user_id)
1248 .ok_or_else(|| anyhow!("failed to build incoming call"))?;
1249 Ok((room, incoming_call))
1250 })
1251 .await
1252 }
1253
1254 pub async fn call_failed(
1255 &self,
1256 room_id: RoomId,
1257 called_user_id: UserId,
1258 ) -> Result<RoomGuard<proto::Room>> {
1259 self.room_transaction(room_id, |tx| async move {
1260 room_participant::Entity::delete_many()
1261 .filter(
1262 room_participant::Column::RoomId
1263 .eq(room_id)
1264 .and(room_participant::Column::UserId.eq(called_user_id)),
1265 )
1266 .exec(&*tx)
1267 .await?;
1268 let room = self.get_room(room_id, &tx).await?;
1269 Ok(room)
1270 })
1271 .await
1272 }
1273
1274 pub async fn decline_call(
1275 &self,
1276 expected_room_id: Option<RoomId>,
1277 user_id: UserId,
1278 ) -> Result<Option<RoomGuard<proto::Room>>> {
1279 self.optional_room_transaction(|tx| async move {
1280 let mut filter = Condition::all()
1281 .add(room_participant::Column::UserId.eq(user_id))
1282 .add(room_participant::Column::AnsweringConnectionId.is_null());
1283 if let Some(room_id) = expected_room_id {
1284 filter = filter.add(room_participant::Column::RoomId.eq(room_id));
1285 }
1286 let participant = room_participant::Entity::find()
1287 .filter(filter)
1288 .one(&*tx)
1289 .await?;
1290
1291 let participant = if let Some(participant) = participant {
1292 participant
1293 } else if expected_room_id.is_some() {
1294 return Err(anyhow!("could not find call to decline"))?;
1295 } else {
1296 return Ok(None);
1297 };
1298
1299 let room_id = participant.room_id;
1300 room_participant::Entity::delete(participant.into_active_model())
1301 .exec(&*tx)
1302 .await?;
1303
1304 let room = self.get_room(room_id, &tx).await?;
1305 Ok(Some((room_id, room)))
1306 })
1307 .await
1308 }
1309
1310 pub async fn cancel_call(
1311 &self,
1312 room_id: RoomId,
1313 calling_connection: ConnectionId,
1314 called_user_id: UserId,
1315 ) -> Result<RoomGuard<proto::Room>> {
1316 self.room_transaction(room_id, |tx| async move {
1317 let participant = room_participant::Entity::find()
1318 .filter(
1319 Condition::all()
1320 .add(room_participant::Column::UserId.eq(called_user_id))
1321 .add(room_participant::Column::RoomId.eq(room_id))
1322 .add(
1323 room_participant::Column::CallingConnectionId
1324 .eq(calling_connection.id as i32),
1325 )
1326 .add(
1327 room_participant::Column::CallingConnectionServerId
1328 .eq(calling_connection.owner_id as i32),
1329 )
1330 .add(room_participant::Column::AnsweringConnectionId.is_null()),
1331 )
1332 .one(&*tx)
1333 .await?
1334 .ok_or_else(|| anyhow!("no call to cancel"))?;
1335
1336 room_participant::Entity::delete(participant.into_active_model())
1337 .exec(&*tx)
1338 .await?;
1339
1340 let room = self.get_room(room_id, &tx).await?;
1341 Ok(room)
1342 })
1343 .await
1344 }
1345
1346 pub async fn is_current_room_different_channel(
1347 &self,
1348 user_id: UserId,
1349 channel_id: ChannelId,
1350 ) -> Result<bool> {
1351 self.transaction(|tx| async move {
1352 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
1353 enum QueryAs {
1354 ChannelId,
1355 }
1356
1357 let channel_id_model: Option<ChannelId> = room_participant::Entity::find()
1358 .select_only()
1359 .column_as(room::Column::ChannelId, QueryAs::ChannelId)
1360 .inner_join(room::Entity)
1361 .filter(room_participant::Column::UserId.eq(user_id))
1362 .into_values::<_, QueryAs>()
1363 .one(&*tx)
1364 .await?;
1365
1366 let result = channel_id_model
1367 .map(|channel_id_model| channel_id_model != channel_id)
1368 .unwrap_or(false);
1369
1370 Ok(result)
1371 })
1372 .await
1373 }
1374
1375 pub async fn join_room(
1376 &self,
1377 room_id: RoomId,
1378 user_id: UserId,
1379 channel_id: Option<ChannelId>,
1380 connection: ConnectionId,
1381 ) -> Result<RoomGuard<JoinRoom>> {
1382 self.room_transaction(room_id, |tx| async move {
1383 if let Some(channel_id) = channel_id {
1384 self.check_user_is_channel_member(channel_id, user_id, &*tx)
1385 .await?;
1386
1387 room_participant::ActiveModel {
1388 room_id: ActiveValue::set(room_id),
1389 user_id: ActiveValue::set(user_id),
1390 answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
1391 answering_connection_server_id: ActiveValue::set(Some(ServerId(
1392 connection.owner_id as i32,
1393 ))),
1394 answering_connection_lost: ActiveValue::set(false),
1395 // Redundant for the channel join use case, used for channel and call invitations
1396 calling_user_id: ActiveValue::set(user_id),
1397 calling_connection_id: ActiveValue::set(connection.id as i32),
1398 calling_connection_server_id: ActiveValue::set(Some(ServerId(
1399 connection.owner_id as i32,
1400 ))),
1401 ..Default::default()
1402 }
1403 .insert(&*tx)
1404 .await?;
1405 } else {
1406 let result = room_participant::Entity::update_many()
1407 .filter(
1408 Condition::all()
1409 .add(room_participant::Column::RoomId.eq(room_id))
1410 .add(room_participant::Column::UserId.eq(user_id))
1411 .add(room_participant::Column::AnsweringConnectionId.is_null()),
1412 )
1413 .set(room_participant::ActiveModel {
1414 answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
1415 answering_connection_server_id: ActiveValue::set(Some(ServerId(
1416 connection.owner_id as i32,
1417 ))),
1418 answering_connection_lost: ActiveValue::set(false),
1419 ..Default::default()
1420 })
1421 .exec(&*tx)
1422 .await?;
1423 if result.rows_affected == 0 {
1424 Err(anyhow!("room does not exist or was already joined"))?;
1425 }
1426 }
1427
1428 let room = self.get_room(room_id, &tx).await?;
1429 let channel_members = if let Some(channel_id) = channel_id {
1430 self.get_channel_members_internal(channel_id, &tx).await?
1431 } else {
1432 Vec::new()
1433 };
1434 Ok(JoinRoom {
1435 room,
1436 channel_id,
1437 channel_members,
1438 })
1439 })
1440 .await
1441 }
1442
    /// Re-attaches `user_id` to room `rejoin_room.id` on a new `connection`, e.g. after
    /// the client reconnected.
    ///
    /// Steps, all inside one room transaction:
    /// 1. Moves the user's answered participant row onto `connection`. This is only
    ///    allowed when the old answering connection was marked lost or lives on a
    ///    different server.
    /// 2. For every project in `rejoin_room.reshared_projects`:
    ///    - checks the user is its host;
    ///    - repoints the project and the host collaborator row to `connection`;
    ///    - replaces its worktrees via `update_project_worktrees`.
    /// 3. Deletes the user's other hosted projects in this room that were not reshared.
    /// 4. For every project in `rejoin_room.rejoined_projects` that still exists and
    ///    lists the user as a collaborator:
    ///    - gathers worktree entry/repository changes, language servers and settings
    ///      files;
    ///    - repoints the user's collaborator row to `connection`.
    ///
    /// # Errors
    /// - "room does not exist or was already joined" when step 1 updates no row.
    /// - "project does not exist" / "no such project" when a reshared project is
    ///   missing or not hosted by `user_id`.
    /// - "host not found among collaborators" when the host has no collaborator row.
    pub async fn rejoin_room(
        &self,
        rejoin_room: proto::RejoinRoom,
        user_id: UserId,
        connection: ConnectionId,
    ) -> Result<RoomGuard<RejoinedRoom>> {
        let room_id = RoomId::from_proto(rejoin_room.id);
        self.room_transaction(room_id, |tx| async {
            // Move `tx` into the (non-`move`) async block while other captures stay borrowed.
            let tx = tx;
            let participant_update = room_participant::Entity::update_many()
                .filter(
                    Condition::all()
                        .add(room_participant::Column::RoomId.eq(room_id))
                        .add(room_participant::Column::UserId.eq(user_id))
                        .add(room_participant::Column::AnsweringConnectionId.is_not_null())
                        // Only take over a connection that was lost or belongs to another server.
                        .add(
                            Condition::any()
                                .add(room_participant::Column::AnsweringConnectionLost.eq(true))
                                .add(
                                    room_participant::Column::AnsweringConnectionServerId
                                        .ne(connection.owner_id as i32),
                                ),
                        ),
                )
                .set(room_participant::ActiveModel {
                    answering_connection_id: ActiveValue::set(Some(connection.id as i32)),
                    answering_connection_server_id: ActiveValue::set(Some(ServerId(
                        connection.owner_id as i32,
                    ))),
                    answering_connection_lost: ActiveValue::set(false),
                    ..Default::default()
                })
                .exec(&*tx)
                .await?;
            if participant_update.rows_affected == 0 {
                return Err(anyhow!("room does not exist or was already joined"))?;
            }

            // Projects the user hosts and is sharing again from the new connection.
            let mut reshared_projects = Vec::new();
            for reshared_project in &rejoin_room.reshared_projects {
                let project_id = ProjectId::from_proto(reshared_project.project_id);
                let project = project::Entity::find_by_id(project_id)
                    .one(&*tx)
                    .await?
                    .ok_or_else(|| anyhow!("project does not exist"))?;
                if project.host_user_id != user_id {
                    return Err(anyhow!("no such project"))?;
                }

                let mut collaborators = project
                    .find_related(project_collaborator::Entity)
                    .all(&*tx)
                    .await?;
                let host_ix = collaborators
                    .iter()
                    .position(|collaborator| {
                        collaborator.user_id == user_id && collaborator.is_host
                    })
                    .ok_or_else(|| anyhow!("host not found among collaborators"))?;
                // Remove the host so `collaborators` holds only the guests reported below.
                let host = collaborators.swap_remove(host_ix);
                let old_connection_id = host.connection();

                project::Entity::update(project::ActiveModel {
                    host_connection_id: ActiveValue::set(Some(connection.id as i32)),
                    host_connection_server_id: ActiveValue::set(Some(ServerId(
                        connection.owner_id as i32,
                    ))),
                    ..project.into_active_model()
                })
                .exec(&*tx)
                .await?;
                project_collaborator::Entity::update(project_collaborator::ActiveModel {
                    connection_id: ActiveValue::set(connection.id as i32),
                    connection_server_id: ActiveValue::set(ServerId(connection.owner_id as i32)),
                    ..host.into_active_model()
                })
                .exec(&*tx)
                .await?;

                self.update_project_worktrees(project_id, &reshared_project.worktrees, &tx)
                    .await?;

                reshared_projects.push(ResharedProject {
                    id: project_id,
                    old_connection_id,
                    collaborators: collaborators
                        .iter()
                        .map(|collaborator| ProjectCollaborator {
                            connection_id: collaborator.connection(),
                            user_id: collaborator.user_id,
                            replica_id: collaborator.replica_id,
                            is_host: collaborator.is_host,
                        })
                        .collect(),
                    worktrees: reshared_project.worktrees.clone(),
                });
            }

            // Unshare every project this user hosts in the room that was not reshared above.
            project::Entity::delete_many()
                .filter(
                    Condition::all()
                        .add(project::Column::RoomId.eq(room_id))
                        .add(project::Column::HostUserId.eq(user_id))
                        .add(
                            project::Column::Id
                                .is_not_in(reshared_projects.iter().map(|project| project.id)),
                        ),
                )
                .exec(&*tx)
                .await?;

            // Projects the user was a guest in; projects that no longer exist are skipped.
            let mut rejoined_projects = Vec::new();
            for rejoined_project in &rejoin_room.rejoined_projects {
                let project_id = ProjectId::from_proto(rejoined_project.id);
                let Some(project) = project::Entity::find_by_id(project_id)
                    .one(&*tx)
                    .await? else { continue };

                let mut worktrees = Vec::new();
                let db_worktrees = project.find_related(worktree::Entity).all(&*tx).await?;
                for db_worktree in db_worktrees {
                    let mut worktree = RejoinedWorktree {
                        id: db_worktree.id as u64,
                        abs_path: db_worktree.abs_path,
                        root_name: db_worktree.root_name,
                        visible: db_worktree.visible,
                        updated_entries: Default::default(),
                        removed_entries: Default::default(),
                        updated_repositories: Default::default(),
                        removed_repositories: Default::default(),
                        diagnostic_summaries: Default::default(),
                        settings_files: Default::default(),
                        scan_id: db_worktree.scan_id as u64,
                        completed_scan_id: db_worktree.completed_scan_id as u64,
                    };

                    // The client's last known state of this worktree, if it reported one.
                    let rejoined_worktree = rejoined_project
                        .worktrees
                        .iter()
                        .find(|worktree| worktree.id == db_worktree.id as u64);

                    // File entries
                    {
                        // Known worktree: send rows changed after the client's scan_id
                        // (including deletions). Unknown worktree: send all live entries.
                        let entry_filter = if let Some(rejoined_worktree) = rejoined_worktree {
                            worktree_entry::Column::ScanId.gt(rejoined_worktree.scan_id)
                        } else {
                            worktree_entry::Column::IsDeleted.eq(false)
                        };

                        let mut db_entries = worktree_entry::Entity::find()
                            .filter(
                                Condition::all()
                                    .add(worktree_entry::Column::ProjectId.eq(project.id))
                                    .add(worktree_entry::Column::WorktreeId.eq(worktree.id))
                                    .add(entry_filter),
                            )
                            .stream(&*tx)
                            .await?;

                        while let Some(db_entry) = db_entries.next().await {
                            let db_entry = db_entry?;
                            if db_entry.is_deleted {
                                worktree.removed_entries.push(db_entry.id as u64);
                            } else {
                                worktree.updated_entries.push(proto::Entry {
                                    id: db_entry.id as u64,
                                    is_dir: db_entry.is_dir,
                                    path: db_entry.path,
                                    inode: db_entry.inode as u64,
                                    mtime: Some(proto::Timestamp {
                                        seconds: db_entry.mtime_seconds as u64,
                                        nanos: db_entry.mtime_nanos as u32,
                                    }),
                                    is_symlink: db_entry.is_symlink,
                                    is_ignored: db_entry.is_ignored,
                                    is_external: db_entry.is_external,
                                    git_status: db_entry.git_status.map(|status| status as i32),
                                });
                            }
                        }
                    }

                    // Repository Entries
                    {
                        // Same scan_id-based delta rule as for file entries above.
                        let repository_entry_filter =
                            if let Some(rejoined_worktree) = rejoined_worktree {
                                worktree_repository::Column::ScanId.gt(rejoined_worktree.scan_id)
                            } else {
                                worktree_repository::Column::IsDeleted.eq(false)
                            };

                        let mut db_repositories = worktree_repository::Entity::find()
                            .filter(
                                Condition::all()
                                    .add(worktree_repository::Column::ProjectId.eq(project.id))
                                    .add(worktree_repository::Column::WorktreeId.eq(worktree.id))
                                    .add(repository_entry_filter),
                            )
                            .stream(&*tx)
                            .await?;

                        while let Some(db_repository) = db_repositories.next().await {
                            let db_repository = db_repository?;
                            if db_repository.is_deleted {
                                worktree
                                    .removed_repositories
                                    .push(db_repository.work_directory_id as u64);
                            } else {
                                worktree.updated_repositories.push(proto::RepositoryEntry {
                                    work_directory_id: db_repository.work_directory_id as u64,
                                    branch: db_repository.branch,
                                });
                            }
                        }
                    }

                    worktrees.push(worktree);
                }

                let language_servers = project
                    .find_related(language_server::Entity)
                    .all(&*tx)
                    .await?
                    .into_iter()
                    .map(|language_server| proto::LanguageServer {
                        id: language_server.id as u64,
                        name: language_server.name,
                    })
                    .collect::<Vec<_>>();

                // Attach each settings file to its worktree; files whose worktree is not
                // in `worktrees` are dropped.
                {
                    let mut db_settings_files = worktree_settings_file::Entity::find()
                        .filter(worktree_settings_file::Column::ProjectId.eq(project_id))
                        .stream(&*tx)
                        .await?;
                    while let Some(db_settings_file) = db_settings_files.next().await {
                        let db_settings_file = db_settings_file?;
                        if let Some(worktree) = worktrees
                            .iter_mut()
                            .find(|w| w.id == db_settings_file.worktree_id as u64)
                        {
                            worktree.settings_files.push(WorktreeSettingsFile {
                                path: db_settings_file.path,
                                content: db_settings_file.content,
                            });
                        }
                    }
                }

                let mut collaborators = project
                    .find_related(project_collaborator::Entity)
                    .all(&*tx)
                    .await?;
                // Skip projects the user is no longer a collaborator in.
                let self_collaborator = if let Some(self_collaborator_ix) = collaborators
                    .iter()
                    .position(|collaborator| collaborator.user_id == user_id)
                {
                    collaborators.swap_remove(self_collaborator_ix)
                } else {
                    continue;
                };
                let old_connection_id = self_collaborator.connection();
                project_collaborator::Entity::update(project_collaborator::ActiveModel {
                    connection_id: ActiveValue::set(connection.id as i32),
                    connection_server_id: ActiveValue::set(ServerId(connection.owner_id as i32)),
                    ..self_collaborator.into_active_model()
                })
                .exec(&*tx)
                .await?;

                // The remaining collaborators, i.e. everyone except the rejoining user.
                let collaborators = collaborators
                    .into_iter()
                    .map(|collaborator| ProjectCollaborator {
                        connection_id: collaborator.connection(),
                        user_id: collaborator.user_id,
                        replica_id: collaborator.replica_id,
                        is_host: collaborator.is_host,
                    })
                    .collect::<Vec<_>>();

                rejoined_projects.push(RejoinedProject {
                    id: project_id,
                    old_connection_id,
                    collaborators,
                    worktrees,
                    language_servers,
                });
            }

            let (channel_id, room) = self.get_channel_room(room_id, &tx).await?;
            let channel_members = if let Some(channel_id) = channel_id {
                self.get_channel_members_internal(channel_id, &tx).await?
            } else {
                Vec::new()
            };

            Ok(RejoinedRoom {
                room,
                channel_id,
                channel_members,
                rejoined_projects,
                reshared_projects,
            })
        })
        .await
    }
1749
    /// Removes the participant answering on `connection` from its room.
    ///
    /// In order, this:
    /// 1. deletes the participant row;
    /// 2. cancels pending calls the leaving user placed;
    /// 3. removes the connection from every project it collaborates on, recording
    ///    who remains in each project;
    /// 4. unshares projects hosted from this connection in the room;
    /// 5. deletes the room if no participants remain and it has no channel.
    ///
    /// Returns `Ok(None)` when `connection` is not answering in any room.
    pub async fn leave_room(
        &self,
        connection: ConnectionId,
    ) -> Result<Option<RoomGuard<LeftRoom>>> {
        self.optional_room_transaction(|tx| async move {
            let leaving_participant = room_participant::Entity::find()
                .filter(
                    Condition::all()
                        .add(
                            room_participant::Column::AnsweringConnectionId
                                .eq(connection.id as i32),
                        )
                        .add(
                            room_participant::Column::AnsweringConnectionServerId
                                .eq(connection.owner_id as i32),
                        ),
                )
                .one(&*tx)
                .await?;

            if let Some(leaving_participant) = leaving_participant {
                // Leave room.
                let room_id = leaving_participant.room_id;
                room_participant::Entity::delete_by_id(leaving_participant.id)
                    .exec(&*tx)
                    .await?;

                // Cancel pending calls initiated by the leaving user.
                // NOTE(review): this filter is not restricted to `room_id`, so unanswered
                // invitations this user placed in any room are removed — confirm intended.
                let called_participants = room_participant::Entity::find()
                    .filter(
                        Condition::all()
                            .add(
                                room_participant::Column::CallingUserId
                                    .eq(leaving_participant.user_id),
                            )
                            .add(room_participant::Column::AnsweringConnectionId.is_null()),
                    )
                    .all(&*tx)
                    .await?;
                room_participant::Entity::delete_many()
                    .filter(
                        room_participant::Column::Id
                            .is_in(called_participants.iter().map(|participant| participant.id)),
                    )
                    .exec(&*tx)
                    .await?;
                let canceled_calls_to_user_ids = called_participants
                    .into_iter()
                    .map(|participant| participant.user_id)
                    .collect();

                // Detect left projects.
                #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
                enum QueryProjectIds {
                    ProjectId,
                }
                // Ids of every project this connection collaborates on.
                let project_ids: Vec<ProjectId> = project_collaborator::Entity::find()
                    .select_only()
                    .column_as(
                        project_collaborator::Column::ProjectId,
                        QueryProjectIds::ProjectId,
                    )
                    .filter(
                        Condition::all()
                            .add(
                                project_collaborator::Column::ConnectionId.eq(connection.id as i32),
                            )
                            .add(
                                project_collaborator::Column::ConnectionServerId
                                    .eq(connection.owner_id as i32),
                            ),
                    )
                    .into_values::<_, QueryProjectIds>()
                    .all(&*tx)
                    .await?;
                // For each such project, record the other collaborators' connections and the host.
                let mut left_projects = HashMap::default();
                let mut collaborators = project_collaborator::Entity::find()
                    .filter(project_collaborator::Column::ProjectId.is_in(project_ids))
                    .stream(&*tx)
                    .await?;
                while let Some(collaborator) = collaborators.next().await {
                    let collaborator = collaborator?;
                    let left_project =
                        left_projects
                            .entry(collaborator.project_id)
                            .or_insert(LeftProject {
                                id: collaborator.project_id,
                                host_user_id: Default::default(),
                                connection_ids: Default::default(),
                                host_connection_id: Default::default(),
                            });

                    let collaborator_connection_id = collaborator.connection();
                    if collaborator_connection_id != connection {
                        left_project.connection_ids.push(collaborator_connection_id);
                    }

                    if collaborator.is_host {
                        left_project.host_user_id = collaborator.user_id;
                        left_project.host_connection_id = collaborator_connection_id;
                    }
                }
                // Release the stream (and its borrow of `tx`) before issuing more queries.
                drop(collaborators);

                // Leave projects.
                project_collaborator::Entity::delete_many()
                    .filter(
                        Condition::all()
                            .add(
                                project_collaborator::Column::ConnectionId.eq(connection.id as i32),
                            )
                            .add(
                                project_collaborator::Column::ConnectionServerId
                                    .eq(connection.owner_id as i32),
                            ),
                    )
                    .exec(&*tx)
                    .await?;

                // Unshare projects.
                project::Entity::delete_many()
                    .filter(
                        Condition::all()
                            .add(project::Column::RoomId.eq(room_id))
                            .add(project::Column::HostConnectionId.eq(connection.id as i32))
                            .add(
                                project::Column::HostConnectionServerId
                                    .eq(connection.owner_id as i32),
                            ),
                    )
                    .exec(&*tx)
                    .await?;

                // Empty rooms are deleted, but only when they are not tied to a channel.
                let (channel_id, room) = self.get_channel_room(room_id, &tx).await?;
                let deleted = if room.participants.is_empty() {
                    let result = room::Entity::delete_by_id(room_id)
                        .filter(room::Column::ChannelId.is_null())
                        .exec(&*tx)
                        .await?;
                    result.rows_affected > 0
                } else {
                    false
                };

                let channel_members = if let Some(channel_id) = channel_id {
                    self.get_channel_members_internal(channel_id, &tx).await?
                } else {
                    Vec::new()
                };
                let left_room = LeftRoom {
                    room,
                    channel_id,
                    channel_members,
                    left_projects,
                    canceled_calls_to_user_ids,
                    deleted,
                };

                // NOTE(review): drops the per-room entry from `self.rooms` while still inside
                // `optional_room_transaction`; confirm this cannot let a concurrent waiter on
                // the old lock race with a new lock for the same room.
                if left_room.room.participants.is_empty() {
                    self.rooms.remove(&room_id);
                }

                Ok(Some((room_id, left_room)))
            } else {
                Ok(None)
            }
        })
        .await
    }
1919
1920 pub async fn follow(
1921 &self,
1922 project_id: ProjectId,
1923 leader_connection: ConnectionId,
1924 follower_connection: ConnectionId,
1925 ) -> Result<RoomGuard<proto::Room>> {
1926 let room_id = self.room_id_for_project(project_id).await?;
1927 self.room_transaction(room_id, |tx| async move {
1928 follower::ActiveModel {
1929 room_id: ActiveValue::set(room_id),
1930 project_id: ActiveValue::set(project_id),
1931 leader_connection_server_id: ActiveValue::set(ServerId(
1932 leader_connection.owner_id as i32,
1933 )),
1934 leader_connection_id: ActiveValue::set(leader_connection.id as i32),
1935 follower_connection_server_id: ActiveValue::set(ServerId(
1936 follower_connection.owner_id as i32,
1937 )),
1938 follower_connection_id: ActiveValue::set(follower_connection.id as i32),
1939 ..Default::default()
1940 }
1941 .insert(&*tx)
1942 .await?;
1943
1944 let room = self.get_room(room_id, &*tx).await?;
1945 Ok(room)
1946 })
1947 .await
1948 }
1949
1950 pub async fn unfollow(
1951 &self,
1952 project_id: ProjectId,
1953 leader_connection: ConnectionId,
1954 follower_connection: ConnectionId,
1955 ) -> Result<RoomGuard<proto::Room>> {
1956 let room_id = self.room_id_for_project(project_id).await?;
1957 self.room_transaction(room_id, |tx| async move {
1958 follower::Entity::delete_many()
1959 .filter(
1960 Condition::all()
1961 .add(follower::Column::ProjectId.eq(project_id))
1962 .add(
1963 follower::Column::LeaderConnectionServerId
1964 .eq(leader_connection.owner_id),
1965 )
1966 .add(follower::Column::LeaderConnectionId.eq(leader_connection.id))
1967 .add(
1968 follower::Column::FollowerConnectionServerId
1969 .eq(follower_connection.owner_id),
1970 )
1971 .add(follower::Column::FollowerConnectionId.eq(follower_connection.id)),
1972 )
1973 .exec(&*tx)
1974 .await?;
1975
1976 let room = self.get_room(room_id, &*tx).await?;
1977 Ok(room)
1978 })
1979 .await
1980 }
1981
1982 pub async fn update_room_participant_location(
1983 &self,
1984 room_id: RoomId,
1985 connection: ConnectionId,
1986 location: proto::ParticipantLocation,
1987 ) -> Result<RoomGuard<proto::Room>> {
1988 self.room_transaction(room_id, |tx| async {
1989 let tx = tx;
1990 let location_kind;
1991 let location_project_id;
1992 match location
1993 .variant
1994 .as_ref()
1995 .ok_or_else(|| anyhow!("invalid location"))?
1996 {
1997 proto::participant_location::Variant::SharedProject(project) => {
1998 location_kind = 0;
1999 location_project_id = Some(ProjectId::from_proto(project.id));
2000 }
2001 proto::participant_location::Variant::UnsharedProject(_) => {
2002 location_kind = 1;
2003 location_project_id = None;
2004 }
2005 proto::participant_location::Variant::External(_) => {
2006 location_kind = 2;
2007 location_project_id = None;
2008 }
2009 }
2010
2011 let result = room_participant::Entity::update_many()
2012 .filter(
2013 Condition::all()
2014 .add(room_participant::Column::RoomId.eq(room_id))
2015 .add(
2016 room_participant::Column::AnsweringConnectionId
2017 .eq(connection.id as i32),
2018 )
2019 .add(
2020 room_participant::Column::AnsweringConnectionServerId
2021 .eq(connection.owner_id as i32),
2022 ),
2023 )
2024 .set(room_participant::ActiveModel {
2025 location_kind: ActiveValue::set(Some(location_kind)),
2026 location_project_id: ActiveValue::set(location_project_id),
2027 ..Default::default()
2028 })
2029 .exec(&*tx)
2030 .await?;
2031
2032 if result.rows_affected == 1 {
2033 let room = self.get_room(room_id, &tx).await?;
2034 Ok(room)
2035 } else {
2036 Err(anyhow!("could not update room participant location"))?
2037 }
2038 })
2039 .await
2040 }
2041
2042 pub async fn connection_lost(&self, connection: ConnectionId) -> Result<()> {
2043 self.transaction(|tx| async move {
2044 let participant = room_participant::Entity::find()
2045 .filter(
2046 Condition::all()
2047 .add(
2048 room_participant::Column::AnsweringConnectionId
2049 .eq(connection.id as i32),
2050 )
2051 .add(
2052 room_participant::Column::AnsweringConnectionServerId
2053 .eq(connection.owner_id as i32),
2054 ),
2055 )
2056 .one(&*tx)
2057 .await?
2058 .ok_or_else(|| anyhow!("not a participant in any room"))?;
2059
2060 room_participant::Entity::update(room_participant::ActiveModel {
2061 answering_connection_lost: ActiveValue::set(true),
2062 ..participant.into_active_model()
2063 })
2064 .exec(&*tx)
2065 .await?;
2066
2067 Ok(())
2068 })
2069 .await
2070 }
2071
2072 fn build_incoming_call(
2073 room: &proto::Room,
2074 called_user_id: UserId,
2075 ) -> Option<proto::IncomingCall> {
2076 let pending_participant = room
2077 .pending_participants
2078 .iter()
2079 .find(|participant| participant.user_id == called_user_id.to_proto())?;
2080
2081 Some(proto::IncomingCall {
2082 room_id: room.id,
2083 calling_user_id: pending_participant.calling_user_id,
2084 participant_user_ids: room
2085 .participants
2086 .iter()
2087 .map(|participant| participant.user_id)
2088 .collect(),
2089 initial_project: room.participants.iter().find_map(|participant| {
2090 let initial_project_id = pending_participant.initial_project_id?;
2091 participant
2092 .projects
2093 .iter()
2094 .find(|project| project.id == initial_project_id)
2095 .cloned()
2096 }),
2097 })
2098 }
2099 async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
2100 let (_, room) = self.get_channel_room(room_id, tx).await?;
2101 Ok(room)
2102 }
2103
    /// Loads a full `proto::Room` snapshot for `room_id` within `tx`, together with the
    /// room's channel id (`None` for rooms without a channel).
    ///
    /// - Participants with both an answering connection id and server id become
    ///   `participants`, keyed by connection. All others become
    ///   `pending_participants`.
    /// - Each project in the room is attached to its host's participant entry, with
    ///   the root names of its visible worktrees. Projects whose host connection is
    ///   not a joined participant are omitted.
    /// - Followers are listed as (leader, follower, project) triples.
    ///
    /// # Errors
    /// Returns "could not find room" when `room_id` does not exist, or propagates
    /// database/decoding errors.
    async fn get_channel_room(
        &self,
        room_id: RoomId,
        tx: &DatabaseTransaction,
    ) -> Result<(Option<ChannelId>, proto::Room)> {
        let db_room = room::Entity::find_by_id(room_id)
            .one(tx)
            .await?
            .ok_or_else(|| anyhow!("could not find room"))?;

        let mut db_participants = db_room
            .find_related(room_participant::Entity)
            .stream(tx)
            .await?;
        // Keyed by answering connection so projects can be attached to their host below.
        let mut participants = HashMap::default();
        let mut pending_participants = Vec::new();
        while let Some(db_participant) = db_participants.next().await {
            let db_participant = db_participant?;
            if let Some((answering_connection_id, answering_connection_server_id)) = db_participant
                .answering_connection_id
                .zip(db_participant.answering_connection_server_id)
            {
                // Stored location kinds: 0 = shared project (needs a project id),
                // 1 = unshared project; anything else is reported as external.
                let location = match (
                    db_participant.location_kind,
                    db_participant.location_project_id,
                ) {
                    (Some(0), Some(project_id)) => {
                        Some(proto::participant_location::Variant::SharedProject(
                            proto::participant_location::SharedProject {
                                id: project_id.to_proto(),
                            },
                        ))
                    }
                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
                        Default::default(),
                    )),
                    _ => Some(proto::participant_location::Variant::External(
                        Default::default(),
                    )),
                };

                let answering_connection = ConnectionId {
                    owner_id: answering_connection_server_id.0 as u32,
                    id: answering_connection_id as u32,
                };
                participants.insert(
                    answering_connection,
                    proto::Participant {
                        user_id: db_participant.user_id.to_proto(),
                        peer_id: Some(answering_connection.into()),
                        projects: Default::default(),
                        location: Some(proto::ParticipantLocation { variant: location }),
                    },
                );
            } else {
                // No answering connection yet: the user has been called but not joined.
                pending_participants.push(proto::PendingParticipant {
                    user_id: db_participant.user_id.to_proto(),
                    calling_user_id: db_participant.calling_user_id.to_proto(),
                    initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()),
                });
            }
        }
        // Finish the participant stream before starting the next query on `tx`.
        drop(db_participants);

        // One row per (project, worktree) pair; a project without worktrees yields `None`.
        let mut db_projects = db_room
            .find_related(project::Entity)
            .find_with_related(worktree::Entity)
            .stream(tx)
            .await?;

        while let Some(row) = db_projects.next().await {
            let (db_project, db_worktree) = row?;
            let host_connection = db_project.host_connection()?;
            if let Some(participant) = participants.get_mut(&host_connection) {
                // Reuse the project entry if an earlier row already created it.
                let project = if let Some(project) = participant
                    .projects
                    .iter_mut()
                    .find(|project| project.id == db_project.id.to_proto())
                {
                    project
                } else {
                    participant.projects.push(proto::ParticipantProject {
                        id: db_project.id.to_proto(),
                        worktree_root_names: Default::default(),
                    });
                    participant.projects.last_mut().unwrap()
                };

                if let Some(db_worktree) = db_worktree {
                    if db_worktree.visible {
                        project.worktree_root_names.push(db_worktree.root_name);
                    }
                }
            }
        }
        drop(db_projects);

        let mut db_followers = db_room.find_related(follower::Entity).stream(tx).await?;
        let mut followers = Vec::new();
        while let Some(db_follower) = db_followers.next().await {
            let db_follower = db_follower?;
            followers.push(proto::Follower {
                leader_id: Some(db_follower.leader_connection().into()),
                follower_id: Some(db_follower.follower_connection().into()),
                project_id: db_follower.project_id.to_proto(),
            });
        }

        // NOTE(review): `participants` order follows the map's iteration order; it is not
        // sorted here.
        Ok((
            db_room.channel_id,
            proto::Room {
                id: db_room.id.to_proto(),
                live_kit_room: db_room.live_kit_room,
                participants: participants.into_values().collect(),
                pending_participants,
                followers,
            },
        ))
    }
2223
2224 // projects
2225
2226 pub async fn project_count_excluding_admins(&self) -> Result<usize> {
2227 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
2228 enum QueryAs {
2229 Count,
2230 }
2231
2232 self.transaction(|tx| async move {
2233 Ok(project::Entity::find()
2234 .select_only()
2235 .column_as(project::Column::Id.count(), QueryAs::Count)
2236 .inner_join(user::Entity)
2237 .filter(user::Column::Admin.eq(false))
2238 .into_values::<_, QueryAs>()
2239 .one(&*tx)
2240 .await?
2241 .unwrap_or(0i64) as usize)
2242 })
2243 .await
2244 }
2245
2246 pub async fn share_project(
2247 &self,
2248 room_id: RoomId,
2249 connection: ConnectionId,
2250 worktrees: &[proto::WorktreeMetadata],
2251 ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
2252 self.room_transaction(room_id, |tx| async move {
2253 let participant = room_participant::Entity::find()
2254 .filter(
2255 Condition::all()
2256 .add(
2257 room_participant::Column::AnsweringConnectionId
2258 .eq(connection.id as i32),
2259 )
2260 .add(
2261 room_participant::Column::AnsweringConnectionServerId
2262 .eq(connection.owner_id as i32),
2263 ),
2264 )
2265 .one(&*tx)
2266 .await?
2267 .ok_or_else(|| anyhow!("could not find participant"))?;
2268 if participant.room_id != room_id {
2269 return Err(anyhow!("shared project on unexpected room"))?;
2270 }
2271
2272 let project = project::ActiveModel {
2273 room_id: ActiveValue::set(participant.room_id),
2274 host_user_id: ActiveValue::set(participant.user_id),
2275 host_connection_id: ActiveValue::set(Some(connection.id as i32)),
2276 host_connection_server_id: ActiveValue::set(Some(ServerId(
2277 connection.owner_id as i32,
2278 ))),
2279 ..Default::default()
2280 }
2281 .insert(&*tx)
2282 .await?;
2283
2284 if !worktrees.is_empty() {
2285 worktree::Entity::insert_many(worktrees.iter().map(|worktree| {
2286 worktree::ActiveModel {
2287 id: ActiveValue::set(worktree.id as i64),
2288 project_id: ActiveValue::set(project.id),
2289 abs_path: ActiveValue::set(worktree.abs_path.clone()),
2290 root_name: ActiveValue::set(worktree.root_name.clone()),
2291 visible: ActiveValue::set(worktree.visible),
2292 scan_id: ActiveValue::set(0),
2293 completed_scan_id: ActiveValue::set(0),
2294 }
2295 }))
2296 .exec(&*tx)
2297 .await?;
2298 }
2299
2300 project_collaborator::ActiveModel {
2301 project_id: ActiveValue::set(project.id),
2302 connection_id: ActiveValue::set(connection.id as i32),
2303 connection_server_id: ActiveValue::set(ServerId(connection.owner_id as i32)),
2304 user_id: ActiveValue::set(participant.user_id),
2305 replica_id: ActiveValue::set(ReplicaId(0)),
2306 is_host: ActiveValue::set(true),
2307 ..Default::default()
2308 }
2309 .insert(&*tx)
2310 .await?;
2311
2312 let room = self.get_room(room_id, &tx).await?;
2313 Ok((project.id, room))
2314 })
2315 .await
2316 }
2317
2318 pub async fn unshare_project(
2319 &self,
2320 project_id: ProjectId,
2321 connection: ConnectionId,
2322 ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
2323 let room_id = self.room_id_for_project(project_id).await?;
2324 self.room_transaction(room_id, |tx| async move {
2325 let guest_connection_ids = self.project_guest_connection_ids(project_id, &tx).await?;
2326
2327 let project = project::Entity::find_by_id(project_id)
2328 .one(&*tx)
2329 .await?
2330 .ok_or_else(|| anyhow!("project not found"))?;
2331 if project.host_connection()? == connection {
2332 project::Entity::delete(project.into_active_model())
2333 .exec(&*tx)
2334 .await?;
2335 let room = self.get_room(room_id, &tx).await?;
2336 Ok((room, guest_connection_ids))
2337 } else {
2338 Err(anyhow!("cannot unshare a project hosted by another user"))?
2339 }
2340 })
2341 .await
2342 }
2343
2344 pub async fn update_project(
2345 &self,
2346 project_id: ProjectId,
2347 connection: ConnectionId,
2348 worktrees: &[proto::WorktreeMetadata],
2349 ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
2350 let room_id = self.room_id_for_project(project_id).await?;
2351 self.room_transaction(room_id, |tx| async move {
2352 let project = project::Entity::find_by_id(project_id)
2353 .filter(
2354 Condition::all()
2355 .add(project::Column::HostConnectionId.eq(connection.id as i32))
2356 .add(
2357 project::Column::HostConnectionServerId.eq(connection.owner_id as i32),
2358 ),
2359 )
2360 .one(&*tx)
2361 .await?
2362 .ok_or_else(|| anyhow!("no such project"))?;
2363
2364 self.update_project_worktrees(project.id, worktrees, &tx)
2365 .await?;
2366
2367 let guest_connection_ids = self.project_guest_connection_ids(project.id, &tx).await?;
2368 let room = self.get_room(project.room_id, &tx).await?;
2369 Ok((room, guest_connection_ids))
2370 })
2371 .await
2372 }
2373
2374 async fn update_project_worktrees(
2375 &self,
2376 project_id: ProjectId,
2377 worktrees: &[proto::WorktreeMetadata],
2378 tx: &DatabaseTransaction,
2379 ) -> Result<()> {
2380 if !worktrees.is_empty() {
2381 worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel {
2382 id: ActiveValue::set(worktree.id as i64),
2383 project_id: ActiveValue::set(project_id),
2384 abs_path: ActiveValue::set(worktree.abs_path.clone()),
2385 root_name: ActiveValue::set(worktree.root_name.clone()),
2386 visible: ActiveValue::set(worktree.visible),
2387 scan_id: ActiveValue::set(0),
2388 completed_scan_id: ActiveValue::set(0),
2389 }))
2390 .on_conflict(
2391 OnConflict::columns([worktree::Column::ProjectId, worktree::Column::Id])
2392 .update_column(worktree::Column::RootName)
2393 .to_owned(),
2394 )
2395 .exec(&*tx)
2396 .await?;
2397 }
2398
2399 worktree::Entity::delete_many()
2400 .filter(worktree::Column::ProjectId.eq(project_id).and(
2401 worktree::Column::Id.is_not_in(worktrees.iter().map(|worktree| worktree.id as i64)),
2402 ))
2403 .exec(&*tx)
2404 .await?;
2405
2406 Ok(())
2407 }
2408
2409 pub async fn update_worktree(
2410 &self,
2411 update: &proto::UpdateWorktree,
2412 connection: ConnectionId,
2413 ) -> Result<RoomGuard<Vec<ConnectionId>>> {
2414 let project_id = ProjectId::from_proto(update.project_id);
2415 let worktree_id = update.worktree_id as i64;
2416 let room_id = self.room_id_for_project(project_id).await?;
2417 self.room_transaction(room_id, |tx| async move {
2418 // Ensure the update comes from the host.
2419 let _project = project::Entity::find_by_id(project_id)
2420 .filter(
2421 Condition::all()
2422 .add(project::Column::HostConnectionId.eq(connection.id as i32))
2423 .add(
2424 project::Column::HostConnectionServerId.eq(connection.owner_id as i32),
2425 ),
2426 )
2427 .one(&*tx)
2428 .await?
2429 .ok_or_else(|| anyhow!("no such project"))?;
2430
2431 // Update metadata.
2432 worktree::Entity::update(worktree::ActiveModel {
2433 id: ActiveValue::set(worktree_id),
2434 project_id: ActiveValue::set(project_id),
2435 root_name: ActiveValue::set(update.root_name.clone()),
2436 scan_id: ActiveValue::set(update.scan_id as i64),
2437 completed_scan_id: if update.is_last_update {
2438 ActiveValue::set(update.scan_id as i64)
2439 } else {
2440 ActiveValue::default()
2441 },
2442 abs_path: ActiveValue::set(update.abs_path.clone()),
2443 ..Default::default()
2444 })
2445 .exec(&*tx)
2446 .await?;
2447
2448 if !update.updated_entries.is_empty() {
2449 worktree_entry::Entity::insert_many(update.updated_entries.iter().map(|entry| {
2450 let mtime = entry.mtime.clone().unwrap_or_default();
2451 worktree_entry::ActiveModel {
2452 project_id: ActiveValue::set(project_id),
2453 worktree_id: ActiveValue::set(worktree_id),
2454 id: ActiveValue::set(entry.id as i64),
2455 is_dir: ActiveValue::set(entry.is_dir),
2456 path: ActiveValue::set(entry.path.clone()),
2457 inode: ActiveValue::set(entry.inode as i64),
2458 mtime_seconds: ActiveValue::set(mtime.seconds as i64),
2459 mtime_nanos: ActiveValue::set(mtime.nanos as i32),
2460 is_symlink: ActiveValue::set(entry.is_symlink),
2461 is_ignored: ActiveValue::set(entry.is_ignored),
2462 is_external: ActiveValue::set(entry.is_external),
2463 git_status: ActiveValue::set(entry.git_status.map(|status| status as i64)),
2464 is_deleted: ActiveValue::set(false),
2465 scan_id: ActiveValue::set(update.scan_id as i64),
2466 }
2467 }))
2468 .on_conflict(
2469 OnConflict::columns([
2470 worktree_entry::Column::ProjectId,
2471 worktree_entry::Column::WorktreeId,
2472 worktree_entry::Column::Id,
2473 ])
2474 .update_columns([
2475 worktree_entry::Column::IsDir,
2476 worktree_entry::Column::Path,
2477 worktree_entry::Column::Inode,
2478 worktree_entry::Column::MtimeSeconds,
2479 worktree_entry::Column::MtimeNanos,
2480 worktree_entry::Column::IsSymlink,
2481 worktree_entry::Column::IsIgnored,
2482 worktree_entry::Column::GitStatus,
2483 worktree_entry::Column::ScanId,
2484 ])
2485 .to_owned(),
2486 )
2487 .exec(&*tx)
2488 .await?;
2489 }
2490
2491 if !update.removed_entries.is_empty() {
2492 worktree_entry::Entity::update_many()
2493 .filter(
2494 worktree_entry::Column::ProjectId
2495 .eq(project_id)
2496 .and(worktree_entry::Column::WorktreeId.eq(worktree_id))
2497 .and(
2498 worktree_entry::Column::Id
2499 .is_in(update.removed_entries.iter().map(|id| *id as i64)),
2500 ),
2501 )
2502 .set(worktree_entry::ActiveModel {
2503 is_deleted: ActiveValue::Set(true),
2504 scan_id: ActiveValue::Set(update.scan_id as i64),
2505 ..Default::default()
2506 })
2507 .exec(&*tx)
2508 .await?;
2509 }
2510
2511 if !update.updated_repositories.is_empty() {
2512 worktree_repository::Entity::insert_many(update.updated_repositories.iter().map(
2513 |repository| worktree_repository::ActiveModel {
2514 project_id: ActiveValue::set(project_id),
2515 worktree_id: ActiveValue::set(worktree_id),
2516 work_directory_id: ActiveValue::set(repository.work_directory_id as i64),
2517 scan_id: ActiveValue::set(update.scan_id as i64),
2518 branch: ActiveValue::set(repository.branch.clone()),
2519 is_deleted: ActiveValue::set(false),
2520 },
2521 ))
2522 .on_conflict(
2523 OnConflict::columns([
2524 worktree_repository::Column::ProjectId,
2525 worktree_repository::Column::WorktreeId,
2526 worktree_repository::Column::WorkDirectoryId,
2527 ])
2528 .update_columns([
2529 worktree_repository::Column::ScanId,
2530 worktree_repository::Column::Branch,
2531 ])
2532 .to_owned(),
2533 )
2534 .exec(&*tx)
2535 .await?;
2536 }
2537
2538 if !update.removed_repositories.is_empty() {
2539 worktree_repository::Entity::update_many()
2540 .filter(
2541 worktree_repository::Column::ProjectId
2542 .eq(project_id)
2543 .and(worktree_repository::Column::WorktreeId.eq(worktree_id))
2544 .and(
2545 worktree_repository::Column::WorkDirectoryId
2546 .is_in(update.removed_repositories.iter().map(|id| *id as i64)),
2547 ),
2548 )
2549 .set(worktree_repository::ActiveModel {
2550 is_deleted: ActiveValue::Set(true),
2551 scan_id: ActiveValue::Set(update.scan_id as i64),
2552 ..Default::default()
2553 })
2554 .exec(&*tx)
2555 .await?;
2556 }
2557
2558 let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?;
2559 Ok(connection_ids)
2560 })
2561 .await
2562 }
2563
2564 pub async fn update_diagnostic_summary(
2565 &self,
2566 update: &proto::UpdateDiagnosticSummary,
2567 connection: ConnectionId,
2568 ) -> Result<RoomGuard<Vec<ConnectionId>>> {
2569 let project_id = ProjectId::from_proto(update.project_id);
2570 let worktree_id = update.worktree_id as i64;
2571 let room_id = self.room_id_for_project(project_id).await?;
2572 self.room_transaction(room_id, |tx| async move {
2573 let summary = update
2574 .summary
2575 .as_ref()
2576 .ok_or_else(|| anyhow!("invalid summary"))?;
2577
2578 // Ensure the update comes from the host.
2579 let project = project::Entity::find_by_id(project_id)
2580 .one(&*tx)
2581 .await?
2582 .ok_or_else(|| anyhow!("no such project"))?;
2583 if project.host_connection()? != connection {
2584 return Err(anyhow!("can't update a project hosted by someone else"))?;
2585 }
2586
2587 // Update summary.
2588 worktree_diagnostic_summary::Entity::insert(worktree_diagnostic_summary::ActiveModel {
2589 project_id: ActiveValue::set(project_id),
2590 worktree_id: ActiveValue::set(worktree_id),
2591 path: ActiveValue::set(summary.path.clone()),
2592 language_server_id: ActiveValue::set(summary.language_server_id as i64),
2593 error_count: ActiveValue::set(summary.error_count as i32),
2594 warning_count: ActiveValue::set(summary.warning_count as i32),
2595 ..Default::default()
2596 })
2597 .on_conflict(
2598 OnConflict::columns([
2599 worktree_diagnostic_summary::Column::ProjectId,
2600 worktree_diagnostic_summary::Column::WorktreeId,
2601 worktree_diagnostic_summary::Column::Path,
2602 ])
2603 .update_columns([
2604 worktree_diagnostic_summary::Column::LanguageServerId,
2605 worktree_diagnostic_summary::Column::ErrorCount,
2606 worktree_diagnostic_summary::Column::WarningCount,
2607 ])
2608 .to_owned(),
2609 )
2610 .exec(&*tx)
2611 .await?;
2612
2613 let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?;
2614 Ok(connection_ids)
2615 })
2616 .await
2617 }
2618
2619 pub async fn start_language_server(
2620 &self,
2621 update: &proto::StartLanguageServer,
2622 connection: ConnectionId,
2623 ) -> Result<RoomGuard<Vec<ConnectionId>>> {
2624 let project_id = ProjectId::from_proto(update.project_id);
2625 let room_id = self.room_id_for_project(project_id).await?;
2626 self.room_transaction(room_id, |tx| async move {
2627 let server = update
2628 .server
2629 .as_ref()
2630 .ok_or_else(|| anyhow!("invalid language server"))?;
2631
2632 // Ensure the update comes from the host.
2633 let project = project::Entity::find_by_id(project_id)
2634 .one(&*tx)
2635 .await?
2636 .ok_or_else(|| anyhow!("no such project"))?;
2637 if project.host_connection()? != connection {
2638 return Err(anyhow!("can't update a project hosted by someone else"))?;
2639 }
2640
2641 // Add the newly-started language server.
2642 language_server::Entity::insert(language_server::ActiveModel {
2643 project_id: ActiveValue::set(project_id),
2644 id: ActiveValue::set(server.id as i64),
2645 name: ActiveValue::set(server.name.clone()),
2646 ..Default::default()
2647 })
2648 .on_conflict(
2649 OnConflict::columns([
2650 language_server::Column::ProjectId,
2651 language_server::Column::Id,
2652 ])
2653 .update_column(language_server::Column::Name)
2654 .to_owned(),
2655 )
2656 .exec(&*tx)
2657 .await?;
2658
2659 let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?;
2660 Ok(connection_ids)
2661 })
2662 .await
2663 }
2664
2665 pub async fn update_worktree_settings(
2666 &self,
2667 update: &proto::UpdateWorktreeSettings,
2668 connection: ConnectionId,
2669 ) -> Result<RoomGuard<Vec<ConnectionId>>> {
2670 let project_id = ProjectId::from_proto(update.project_id);
2671 let room_id = self.room_id_for_project(project_id).await?;
2672 self.room_transaction(room_id, |tx| async move {
2673 // Ensure the update comes from the host.
2674 let project = project::Entity::find_by_id(project_id)
2675 .one(&*tx)
2676 .await?
2677 .ok_or_else(|| anyhow!("no such project"))?;
2678 if project.host_connection()? != connection {
2679 return Err(anyhow!("can't update a project hosted by someone else"))?;
2680 }
2681
2682 if let Some(content) = &update.content {
2683 worktree_settings_file::Entity::insert(worktree_settings_file::ActiveModel {
2684 project_id: ActiveValue::Set(project_id),
2685 worktree_id: ActiveValue::Set(update.worktree_id as i64),
2686 path: ActiveValue::Set(update.path.clone()),
2687 content: ActiveValue::Set(content.clone()),
2688 })
2689 .on_conflict(
2690 OnConflict::columns([
2691 worktree_settings_file::Column::ProjectId,
2692 worktree_settings_file::Column::WorktreeId,
2693 worktree_settings_file::Column::Path,
2694 ])
2695 .update_column(worktree_settings_file::Column::Content)
2696 .to_owned(),
2697 )
2698 .exec(&*tx)
2699 .await?;
2700 } else {
2701 worktree_settings_file::Entity::delete(worktree_settings_file::ActiveModel {
2702 project_id: ActiveValue::Set(project_id),
2703 worktree_id: ActiveValue::Set(update.worktree_id as i64),
2704 path: ActiveValue::Set(update.path.clone()),
2705 ..Default::default()
2706 })
2707 .exec(&*tx)
2708 .await?;
2709 }
2710
2711 let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?;
2712 Ok(connection_ids)
2713 })
2714 .await
2715 }
2716
2717 pub async fn join_project(
2718 &self,
2719 project_id: ProjectId,
2720 connection: ConnectionId,
2721 ) -> Result<RoomGuard<(Project, ReplicaId)>> {
2722 let room_id = self.room_id_for_project(project_id).await?;
2723 self.room_transaction(room_id, |tx| async move {
2724 let participant = room_participant::Entity::find()
2725 .filter(
2726 Condition::all()
2727 .add(
2728 room_participant::Column::AnsweringConnectionId
2729 .eq(connection.id as i32),
2730 )
2731 .add(
2732 room_participant::Column::AnsweringConnectionServerId
2733 .eq(connection.owner_id as i32),
2734 ),
2735 )
2736 .one(&*tx)
2737 .await?
2738 .ok_or_else(|| anyhow!("must join a room first"))?;
2739
2740 let project = project::Entity::find_by_id(project_id)
2741 .one(&*tx)
2742 .await?
2743 .ok_or_else(|| anyhow!("no such project"))?;
2744 if project.room_id != participant.room_id {
2745 return Err(anyhow!("no such project"))?;
2746 }
2747
2748 let mut collaborators = project
2749 .find_related(project_collaborator::Entity)
2750 .all(&*tx)
2751 .await?;
2752 let replica_ids = collaborators
2753 .iter()
2754 .map(|c| c.replica_id)
2755 .collect::<HashSet<_>>();
2756 let mut replica_id = ReplicaId(1);
2757 while replica_ids.contains(&replica_id) {
2758 replica_id.0 += 1;
2759 }
2760 let new_collaborator = project_collaborator::ActiveModel {
2761 project_id: ActiveValue::set(project_id),
2762 connection_id: ActiveValue::set(connection.id as i32),
2763 connection_server_id: ActiveValue::set(ServerId(connection.owner_id as i32)),
2764 user_id: ActiveValue::set(participant.user_id),
2765 replica_id: ActiveValue::set(replica_id),
2766 is_host: ActiveValue::set(false),
2767 ..Default::default()
2768 }
2769 .insert(&*tx)
2770 .await?;
2771 collaborators.push(new_collaborator);
2772
2773 let db_worktrees = project.find_related(worktree::Entity).all(&*tx).await?;
2774 let mut worktrees = db_worktrees
2775 .into_iter()
2776 .map(|db_worktree| {
2777 (
2778 db_worktree.id as u64,
2779 Worktree {
2780 id: db_worktree.id as u64,
2781 abs_path: db_worktree.abs_path,
2782 root_name: db_worktree.root_name,
2783 visible: db_worktree.visible,
2784 entries: Default::default(),
2785 repository_entries: Default::default(),
2786 diagnostic_summaries: Default::default(),
2787 settings_files: Default::default(),
2788 scan_id: db_worktree.scan_id as u64,
2789 completed_scan_id: db_worktree.completed_scan_id as u64,
2790 },
2791 )
2792 })
2793 .collect::<BTreeMap<_, _>>();
2794
2795 // Populate worktree entries.
2796 {
2797 let mut db_entries = worktree_entry::Entity::find()
2798 .filter(
2799 Condition::all()
2800 .add(worktree_entry::Column::ProjectId.eq(project_id))
2801 .add(worktree_entry::Column::IsDeleted.eq(false)),
2802 )
2803 .stream(&*tx)
2804 .await?;
2805 while let Some(db_entry) = db_entries.next().await {
2806 let db_entry = db_entry?;
2807 if let Some(worktree) = worktrees.get_mut(&(db_entry.worktree_id as u64)) {
2808 worktree.entries.push(proto::Entry {
2809 id: db_entry.id as u64,
2810 is_dir: db_entry.is_dir,
2811 path: db_entry.path,
2812 inode: db_entry.inode as u64,
2813 mtime: Some(proto::Timestamp {
2814 seconds: db_entry.mtime_seconds as u64,
2815 nanos: db_entry.mtime_nanos as u32,
2816 }),
2817 is_symlink: db_entry.is_symlink,
2818 is_ignored: db_entry.is_ignored,
2819 is_external: db_entry.is_external,
2820 git_status: db_entry.git_status.map(|status| status as i32),
2821 });
2822 }
2823 }
2824 }
2825
2826 // Populate repository entries.
2827 {
2828 let mut db_repository_entries = worktree_repository::Entity::find()
2829 .filter(
2830 Condition::all()
2831 .add(worktree_repository::Column::ProjectId.eq(project_id))
2832 .add(worktree_repository::Column::IsDeleted.eq(false)),
2833 )
2834 .stream(&*tx)
2835 .await?;
2836 while let Some(db_repository_entry) = db_repository_entries.next().await {
2837 let db_repository_entry = db_repository_entry?;
2838 if let Some(worktree) =
2839 worktrees.get_mut(&(db_repository_entry.worktree_id as u64))
2840 {
2841 worktree.repository_entries.insert(
2842 db_repository_entry.work_directory_id as u64,
2843 proto::RepositoryEntry {
2844 work_directory_id: db_repository_entry.work_directory_id as u64,
2845 branch: db_repository_entry.branch,
2846 },
2847 );
2848 }
2849 }
2850 }
2851
2852 // Populate worktree diagnostic summaries.
2853 {
2854 let mut db_summaries = worktree_diagnostic_summary::Entity::find()
2855 .filter(worktree_diagnostic_summary::Column::ProjectId.eq(project_id))
2856 .stream(&*tx)
2857 .await?;
2858 while let Some(db_summary) = db_summaries.next().await {
2859 let db_summary = db_summary?;
2860 if let Some(worktree) = worktrees.get_mut(&(db_summary.worktree_id as u64)) {
2861 worktree
2862 .diagnostic_summaries
2863 .push(proto::DiagnosticSummary {
2864 path: db_summary.path,
2865 language_server_id: db_summary.language_server_id as u64,
2866 error_count: db_summary.error_count as u32,
2867 warning_count: db_summary.warning_count as u32,
2868 });
2869 }
2870 }
2871 }
2872
2873 // Populate worktree settings files
2874 {
2875 let mut db_settings_files = worktree_settings_file::Entity::find()
2876 .filter(worktree_settings_file::Column::ProjectId.eq(project_id))
2877 .stream(&*tx)
2878 .await?;
2879 while let Some(db_settings_file) = db_settings_files.next().await {
2880 let db_settings_file = db_settings_file?;
2881 if let Some(worktree) =
2882 worktrees.get_mut(&(db_settings_file.worktree_id as u64))
2883 {
2884 worktree.settings_files.push(WorktreeSettingsFile {
2885 path: db_settings_file.path,
2886 content: db_settings_file.content,
2887 });
2888 }
2889 }
2890 }
2891
2892 // Populate language servers.
2893 let language_servers = project
2894 .find_related(language_server::Entity)
2895 .all(&*tx)
2896 .await?;
2897
2898 let project = Project {
2899 collaborators: collaborators
2900 .into_iter()
2901 .map(|collaborator| ProjectCollaborator {
2902 connection_id: collaborator.connection(),
2903 user_id: collaborator.user_id,
2904 replica_id: collaborator.replica_id,
2905 is_host: collaborator.is_host,
2906 })
2907 .collect(),
2908 worktrees,
2909 language_servers: language_servers
2910 .into_iter()
2911 .map(|language_server| proto::LanguageServer {
2912 id: language_server.id as u64,
2913 name: language_server.name,
2914 })
2915 .collect(),
2916 };
2917 Ok((project, replica_id as ReplicaId))
2918 })
2919 .await
2920 }
2921
2922 pub async fn leave_project(
2923 &self,
2924 project_id: ProjectId,
2925 connection: ConnectionId,
2926 ) -> Result<RoomGuard<(proto::Room, LeftProject)>> {
2927 let room_id = self.room_id_for_project(project_id).await?;
2928 self.room_transaction(room_id, |tx| async move {
2929 let result = project_collaborator::Entity::delete_many()
2930 .filter(
2931 Condition::all()
2932 .add(project_collaborator::Column::ProjectId.eq(project_id))
2933 .add(project_collaborator::Column::ConnectionId.eq(connection.id as i32))
2934 .add(
2935 project_collaborator::Column::ConnectionServerId
2936 .eq(connection.owner_id as i32),
2937 ),
2938 )
2939 .exec(&*tx)
2940 .await?;
2941 if result.rows_affected == 0 {
2942 Err(anyhow!("not a collaborator on this project"))?;
2943 }
2944
2945 let project = project::Entity::find_by_id(project_id)
2946 .one(&*tx)
2947 .await?
2948 .ok_or_else(|| anyhow!("no such project"))?;
2949 let collaborators = project
2950 .find_related(project_collaborator::Entity)
2951 .all(&*tx)
2952 .await?;
2953 let connection_ids = collaborators
2954 .into_iter()
2955 .map(|collaborator| collaborator.connection())
2956 .collect();
2957
2958 follower::Entity::delete_many()
2959 .filter(
2960 Condition::any()
2961 .add(
2962 Condition::all()
2963 .add(follower::Column::ProjectId.eq(project_id))
2964 .add(
2965 follower::Column::LeaderConnectionServerId
2966 .eq(connection.owner_id),
2967 )
2968 .add(follower::Column::LeaderConnectionId.eq(connection.id)),
2969 )
2970 .add(
2971 Condition::all()
2972 .add(follower::Column::ProjectId.eq(project_id))
2973 .add(
2974 follower::Column::FollowerConnectionServerId
2975 .eq(connection.owner_id),
2976 )
2977 .add(follower::Column::FollowerConnectionId.eq(connection.id)),
2978 ),
2979 )
2980 .exec(&*tx)
2981 .await?;
2982
2983 let room = self.get_room(project.room_id, &tx).await?;
2984 let left_project = LeftProject {
2985 id: project_id,
2986 host_user_id: project.host_user_id,
2987 host_connection_id: project.host_connection()?,
2988 connection_ids,
2989 };
2990 Ok((room, left_project))
2991 })
2992 .await
2993 }
2994
2995 pub async fn project_collaborators(
2996 &self,
2997 project_id: ProjectId,
2998 connection_id: ConnectionId,
2999 ) -> Result<RoomGuard<Vec<ProjectCollaborator>>> {
3000 let room_id = self.room_id_for_project(project_id).await?;
3001 self.room_transaction(room_id, |tx| async move {
3002 let collaborators = project_collaborator::Entity::find()
3003 .filter(project_collaborator::Column::ProjectId.eq(project_id))
3004 .all(&*tx)
3005 .await?
3006 .into_iter()
3007 .map(|collaborator| ProjectCollaborator {
3008 connection_id: collaborator.connection(),
3009 user_id: collaborator.user_id,
3010 replica_id: collaborator.replica_id,
3011 is_host: collaborator.is_host,
3012 })
3013 .collect::<Vec<_>>();
3014
3015 if collaborators
3016 .iter()
3017 .any(|collaborator| collaborator.connection_id == connection_id)
3018 {
3019 Ok(collaborators)
3020 } else {
3021 Err(anyhow!("no such project"))?
3022 }
3023 })
3024 .await
3025 }
3026
3027 pub async fn project_connection_ids(
3028 &self,
3029 project_id: ProjectId,
3030 connection_id: ConnectionId,
3031 ) -> Result<RoomGuard<HashSet<ConnectionId>>> {
3032 let room_id = self.room_id_for_project(project_id).await?;
3033 self.room_transaction(room_id, |tx| async move {
3034 let mut collaborators = project_collaborator::Entity::find()
3035 .filter(project_collaborator::Column::ProjectId.eq(project_id))
3036 .stream(&*tx)
3037 .await?;
3038
3039 let mut connection_ids = HashSet::default();
3040 while let Some(collaborator) = collaborators.next().await {
3041 let collaborator = collaborator?;
3042 connection_ids.insert(collaborator.connection());
3043 }
3044
3045 if connection_ids.contains(&connection_id) {
3046 Ok(connection_ids)
3047 } else {
3048 Err(anyhow!("no such project"))?
3049 }
3050 })
3051 .await
3052 }
3053
3054 async fn project_guest_connection_ids(
3055 &self,
3056 project_id: ProjectId,
3057 tx: &DatabaseTransaction,
3058 ) -> Result<Vec<ConnectionId>> {
3059 let mut collaborators = project_collaborator::Entity::find()
3060 .filter(
3061 project_collaborator::Column::ProjectId
3062 .eq(project_id)
3063 .and(project_collaborator::Column::IsHost.eq(false)),
3064 )
3065 .stream(tx)
3066 .await?;
3067
3068 let mut guest_connection_ids = Vec::new();
3069 while let Some(collaborator) = collaborators.next().await {
3070 let collaborator = collaborator?;
3071 guest_connection_ids.push(collaborator.connection());
3072 }
3073 Ok(guest_connection_ids)
3074 }
3075
3076 async fn room_id_for_project(&self, project_id: ProjectId) -> Result<RoomId> {
3077 self.transaction(|tx| async move {
3078 let project = project::Entity::find_by_id(project_id)
3079 .one(&*tx)
3080 .await?
3081 .ok_or_else(|| anyhow!("project {} not found", project_id))?;
3082 Ok(project.room_id)
3083 })
3084 .await
3085 }
3086
3087 // access tokens
3088
3089 pub async fn create_access_token(
3090 &self,
3091 user_id: UserId,
3092 access_token_hash: &str,
3093 max_access_token_count: usize,
3094 ) -> Result<AccessTokenId> {
3095 self.transaction(|tx| async {
3096 let tx = tx;
3097
3098 let token = access_token::ActiveModel {
3099 user_id: ActiveValue::set(user_id),
3100 hash: ActiveValue::set(access_token_hash.into()),
3101 ..Default::default()
3102 }
3103 .insert(&*tx)
3104 .await?;
3105
3106 access_token::Entity::delete_many()
3107 .filter(
3108 access_token::Column::Id.in_subquery(
3109 Query::select()
3110 .column(access_token::Column::Id)
3111 .from(access_token::Entity)
3112 .and_where(access_token::Column::UserId.eq(user_id))
3113 .order_by(access_token::Column::Id, sea_orm::Order::Desc)
3114 .limit(10000)
3115 .offset(max_access_token_count as u64)
3116 .to_owned(),
3117 ),
3118 )
3119 .exec(&*tx)
3120 .await?;
3121 Ok(token.id)
3122 })
3123 .await
3124 }
3125
3126 pub async fn get_access_token(
3127 &self,
3128 access_token_id: AccessTokenId,
3129 ) -> Result<access_token::Model> {
3130 self.transaction(|tx| async move {
3131 Ok(access_token::Entity::find_by_id(access_token_id)
3132 .one(&*tx)
3133 .await?
3134 .ok_or_else(|| anyhow!("no such access token"))?)
3135 })
3136 .await
3137 }
3138
3139 // channels
3140
3141 pub async fn create_root_channel(
3142 &self,
3143 name: &str,
3144 live_kit_room: &str,
3145 creator_id: UserId,
3146 ) -> Result<ChannelId> {
3147 self.create_channel(name, None, live_kit_room, creator_id)
3148 .await
3149 }
3150
3151 pub async fn create_channel(
3152 &self,
3153 name: &str,
3154 parent: Option<ChannelId>,
3155 live_kit_room: &str,
3156 creator_id: UserId,
3157 ) -> Result<ChannelId> {
3158 let name = Self::sanitize_channel_name(name)?;
3159 self.transaction(move |tx| async move {
3160 if let Some(parent) = parent {
3161 self.check_user_is_channel_admin(parent, creator_id, &*tx)
3162 .await?;
3163 }
3164
3165 let channel = channel::ActiveModel {
3166 name: ActiveValue::Set(name.to_string()),
3167 ..Default::default()
3168 }
3169 .insert(&*tx)
3170 .await?;
3171
3172 let channel_paths_stmt;
3173 if let Some(parent) = parent {
3174 let sql = r#"
3175 INSERT INTO channel_paths
3176 (id_path, channel_id)
3177 SELECT
3178 id_path || $1 || '/', $2
3179 FROM
3180 channel_paths
3181 WHERE
3182 channel_id = $3
3183 "#;
3184 channel_paths_stmt = Statement::from_sql_and_values(
3185 self.pool.get_database_backend(),
3186 sql,
3187 [
3188 channel.id.to_proto().into(),
3189 channel.id.to_proto().into(),
3190 parent.to_proto().into(),
3191 ],
3192 );
3193 tx.execute(channel_paths_stmt).await?;
3194 } else {
3195 channel_path::Entity::insert(channel_path::ActiveModel {
3196 channel_id: ActiveValue::Set(channel.id),
3197 id_path: ActiveValue::Set(format!("/{}/", channel.id)),
3198 })
3199 .exec(&*tx)
3200 .await?;
3201 }
3202
3203 channel_member::ActiveModel {
3204 channel_id: ActiveValue::Set(channel.id),
3205 user_id: ActiveValue::Set(creator_id),
3206 accepted: ActiveValue::Set(true),
3207 admin: ActiveValue::Set(true),
3208 ..Default::default()
3209 }
3210 .insert(&*tx)
3211 .await?;
3212
3213 room::ActiveModel {
3214 channel_id: ActiveValue::Set(Some(channel.id)),
3215 live_kit_room: ActiveValue::Set(live_kit_room.to_string()),
3216 ..Default::default()
3217 }
3218 .insert(&*tx)
3219 .await?;
3220
3221 Ok(channel.id)
3222 })
3223 .await
3224 }
3225
3226 pub async fn remove_channel(
3227 &self,
3228 channel_id: ChannelId,
3229 user_id: UserId,
3230 ) -> Result<(Vec<ChannelId>, Vec<UserId>)> {
3231 self.transaction(move |tx| async move {
3232 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
3233 .await?;
3234
3235 // Don't remove descendant channels that have additional parents.
3236 let mut channels_to_remove = self.get_channel_descendants([channel_id], &*tx).await?;
3237 {
3238 let mut channels_to_keep = channel_path::Entity::find()
3239 .filter(
3240 channel_path::Column::ChannelId
3241 .is_in(
3242 channels_to_remove
3243 .keys()
3244 .copied()
3245 .filter(|&id| id != channel_id),
3246 )
3247 .and(
3248 channel_path::Column::IdPath
3249 .not_like(&format!("%/{}/%", channel_id)),
3250 ),
3251 )
3252 .stream(&*tx)
3253 .await?;
3254 while let Some(row) = channels_to_keep.next().await {
3255 let row = row?;
3256 channels_to_remove.remove(&row.channel_id);
3257 }
3258 }
3259
3260 let channel_ancestors = self.get_channel_ancestors(channel_id, &*tx).await?;
3261 let members_to_notify: Vec<UserId> = channel_member::Entity::find()
3262 .filter(channel_member::Column::ChannelId.is_in(channel_ancestors))
3263 .select_only()
3264 .column(channel_member::Column::UserId)
3265 .distinct()
3266 .into_values::<_, QueryUserIds>()
3267 .all(&*tx)
3268 .await?;
3269
3270 channel::Entity::delete_many()
3271 .filter(channel::Column::Id.is_in(channels_to_remove.keys().copied()))
3272 .exec(&*tx)
3273 .await?;
3274
3275 Ok((channels_to_remove.into_keys().collect(), members_to_notify))
3276 })
3277 .await
3278 }
3279
3280 pub async fn invite_channel_member(
3281 &self,
3282 channel_id: ChannelId,
3283 invitee_id: UserId,
3284 inviter_id: UserId,
3285 is_admin: bool,
3286 ) -> Result<()> {
3287 self.transaction(move |tx| async move {
3288 self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
3289 .await?;
3290
3291 channel_member::ActiveModel {
3292 channel_id: ActiveValue::Set(channel_id),
3293 user_id: ActiveValue::Set(invitee_id),
3294 accepted: ActiveValue::Set(false),
3295 admin: ActiveValue::Set(is_admin),
3296 ..Default::default()
3297 }
3298 .insert(&*tx)
3299 .await?;
3300
3301 Ok(())
3302 })
3303 .await
3304 }
3305
3306 fn sanitize_channel_name(name: &str) -> Result<&str> {
3307 let new_name = name.trim().trim_start_matches('#');
3308 if new_name == "" {
3309 Err(anyhow!("channel name can't be blank"))?;
3310 }
3311 Ok(new_name)
3312 }
3313
3314 pub async fn rename_channel(
3315 &self,
3316 channel_id: ChannelId,
3317 user_id: UserId,
3318 new_name: &str,
3319 ) -> Result<String> {
3320 self.transaction(move |tx| async move {
3321 let new_name = Self::sanitize_channel_name(new_name)?.to_string();
3322
3323 self.check_user_is_channel_admin(channel_id, user_id, &*tx)
3324 .await?;
3325
3326 channel::ActiveModel {
3327 id: ActiveValue::Unchanged(channel_id),
3328 name: ActiveValue::Set(new_name.clone()),
3329 ..Default::default()
3330 }
3331 .update(&*tx)
3332 .await?;
3333
3334 Ok(new_name)
3335 })
3336 .await
3337 }
3338
3339 pub async fn respond_to_channel_invite(
3340 &self,
3341 channel_id: ChannelId,
3342 user_id: UserId,
3343 accept: bool,
3344 ) -> Result<()> {
3345 self.transaction(move |tx| async move {
3346 let rows_affected = if accept {
3347 channel_member::Entity::update_many()
3348 .set(channel_member::ActiveModel {
3349 accepted: ActiveValue::Set(accept),
3350 ..Default::default()
3351 })
3352 .filter(
3353 channel_member::Column::ChannelId
3354 .eq(channel_id)
3355 .and(channel_member::Column::UserId.eq(user_id))
3356 .and(channel_member::Column::Accepted.eq(false)),
3357 )
3358 .exec(&*tx)
3359 .await?
3360 .rows_affected
3361 } else {
3362 channel_member::ActiveModel {
3363 channel_id: ActiveValue::Unchanged(channel_id),
3364 user_id: ActiveValue::Unchanged(user_id),
3365 ..Default::default()
3366 }
3367 .delete(&*tx)
3368 .await?
3369 .rows_affected
3370 };
3371
3372 if rows_affected == 0 {
3373 Err(anyhow!("no such invitation"))?;
3374 }
3375
3376 Ok(())
3377 })
3378 .await
3379 }
3380
3381 pub async fn remove_channel_member(
3382 &self,
3383 channel_id: ChannelId,
3384 member_id: UserId,
3385 remover_id: UserId,
3386 ) -> Result<()> {
3387 self.transaction(|tx| async move {
3388 self.check_user_is_channel_admin(channel_id, remover_id, &*tx)
3389 .await?;
3390
3391 let result = channel_member::Entity::delete_many()
3392 .filter(
3393 channel_member::Column::ChannelId
3394 .eq(channel_id)
3395 .and(channel_member::Column::UserId.eq(member_id)),
3396 )
3397 .exec(&*tx)
3398 .await?;
3399
3400 if result.rows_affected == 0 {
3401 Err(anyhow!("no such member"))?;
3402 }
3403
3404 Ok(())
3405 })
3406 .await
3407 }
3408
3409 pub async fn get_channel_invites_for_user(&self, user_id: UserId) -> Result<Vec<Channel>> {
3410 self.transaction(|tx| async move {
3411 let channel_invites = channel_member::Entity::find()
3412 .filter(
3413 channel_member::Column::UserId
3414 .eq(user_id)
3415 .and(channel_member::Column::Accepted.eq(false)),
3416 )
3417 .all(&*tx)
3418 .await?;
3419
3420 let channels = channel::Entity::find()
3421 .filter(
3422 channel::Column::Id.is_in(
3423 channel_invites
3424 .into_iter()
3425 .map(|channel_member| channel_member.channel_id),
3426 ),
3427 )
3428 .all(&*tx)
3429 .await?;
3430
3431 let channels = channels
3432 .into_iter()
3433 .map(|channel| Channel {
3434 id: channel.id,
3435 name: channel.name,
3436 parent_id: None,
3437 })
3438 .collect();
3439
3440 Ok(channels)
3441 })
3442 .await
3443 }
3444
    /// Returns every channel visible to `user_id`, the users currently in each of
    /// those channels' rooms, and the channels where the user is an admin.
    ///
    /// Visible channels are those where the user has an *accepted* membership, plus
    /// all of their descendants (via `get_channel_descendants`). Each returned
    /// `Channel` carries the parent id reported by that lookup.
    ///
    /// `channels_with_admin_privileges` only lists channels whose accepted membership
    /// row for this user has `admin = true`; descendants of those channels are not
    /// added to it.
    pub async fn get_channels_for_user(&self, user_id: UserId) -> Result<ChannelsForUser> {
        self.transaction(|tx| async move {
            let tx = tx;

            // Accepted memberships only; pending invites are returned by
            // `get_channel_invites_for_user` instead.
            let channel_memberships = channel_member::Entity::find()
                .filter(
                    channel_member::Column::UserId
                        .eq(user_id)
                        .and(channel_member::Column::Accepted.eq(true)),
                )
                .all(&*tx)
                .await?;

            // Every visible channel (membership channels and their descendants),
            // mapped to its parent id.
            let parents_by_child_id = self
                .get_channel_descendants(channel_memberships.iter().map(|m| m.channel_id), &*tx)
                .await?;

            // Admin rights come only from the user's own membership rows.
            let channels_with_admin_privileges = channel_memberships
                .iter()
                .filter_map(|membership| membership.admin.then_some(membership.channel_id))
                .collect();

            let mut channels = Vec::with_capacity(parents_by_child_id.len());
            {
                let mut rows = channel::Entity::find()
                    .filter(channel::Column::Id.is_in(parents_by_child_id.keys().copied()))
                    .stream(&*tx)
                    .await?;
                while let Some(row) = rows.next().await {
                    let row = row?;
                    channels.push(Channel {
                        id: row.id,
                        name: row.name,
                        parent_id: parents_by_child_id.get(&row.id).copied().flatten(),
                    });
                }
            }

            // Column selector for the `(ChannelId, UserId)` pairs read below.
            // NOTE(review): presumably the variant order must match the selected
            // column order / tuple order — keep them in sync.
            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
            enum QueryUserIdsAndChannelIds {
                ChannelId,
                UserId,
            }

            // Group the participants of each visible channel's room by channel id.
            let mut channel_participants: HashMap<ChannelId, Vec<UserId>> = HashMap::default();
            {
                let mut rows = room_participant::Entity::find()
                    .inner_join(room::Entity)
                    .filter(room::Column::ChannelId.is_in(channels.iter().map(|c| c.id)))
                    .select_only()
                    .column(room::Column::ChannelId)
                    .column(room_participant::Column::UserId)
                    .into_values::<_, QueryUserIdsAndChannelIds>()
                    .stream(&*tx)
                    .await?;
                while let Some(row) = rows.next().await {
                    let row: (ChannelId, UserId) = row?;
                    channel_participants.entry(row.0).or_default().push(row.1)
                }
            }

            Ok(ChannelsForUser {
                channels,
                channel_participants,
                channels_with_admin_privileges,
            })
        })
        .await
    }
3514
3515 pub async fn get_channel_members(&self, id: ChannelId) -> Result<Vec<UserId>> {
3516 self.transaction(|tx| async move { self.get_channel_members_internal(id, &*tx).await })
3517 .await
3518 }
3519
3520 pub async fn set_channel_member_admin(
3521 &self,
3522 channel_id: ChannelId,
3523 from: UserId,
3524 for_user: UserId,
3525 admin: bool,
3526 ) -> Result<()> {
3527 self.transaction(|tx| async move {
3528 self.check_user_is_channel_admin(channel_id, from, &*tx)
3529 .await?;
3530
3531 let result = channel_member::Entity::update_many()
3532 .filter(
3533 channel_member::Column::ChannelId
3534 .eq(channel_id)
3535 .and(channel_member::Column::UserId.eq(for_user)),
3536 )
3537 .set(channel_member::ActiveModel {
3538 admin: ActiveValue::set(admin),
3539 ..Default::default()
3540 })
3541 .exec(&*tx)
3542 .await?;
3543
3544 if result.rows_affected == 0 {
3545 Err(anyhow!("no such member"))?;
3546 }
3547
3548 Ok(())
3549 })
3550 .await
3551 }
3552
    /// Lists the members of `channel_id` for an admin view, including pending
    /// invitees and users who belong only through an ancestor channel.
    ///
    /// `user_id` must be an admin of the channel or of one of its ancestors. Each
    /// user appears at most once, ordered by user id, classified as:
    /// - `Member`: accepted membership directly in this channel;
    /// - `Invitee`: pending invitation directly to this channel;
    /// - `AncestorMember`: accepted membership in an ancestor channel only.
    ///
    /// Pending invitations to ancestor channels are not reported.
    pub async fn get_channel_member_details(
        &self,
        channel_id: ChannelId,
        user_id: UserId,
    ) -> Result<Vec<proto::ChannelMember>> {
        self.transaction(|tx| async move {
            self.check_user_is_channel_admin(channel_id, user_id, &*tx)
                .await?;

            // Column selector for the `(UserId, bool, bool, bool)` rows read below.
            // NOTE(review): presumably the variant order must match the selected
            // column order — keep them in sync.
            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
            enum QueryMemberDetails {
                UserId,
                Admin,
                IsDirectMember,
                Accepted,
            }

            let tx = tx;
            // Memberships in this channel and in every ancestor channel.
            let ancestor_ids = self.get_channel_ancestors(channel_id, &*tx).await?;
            let mut stream = channel_member::Entity::find()
                .distinct()
                .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
                .select_only()
                .column(channel_member::Column::UserId)
                .column(channel_member::Column::Admin)
                .column_as(
                    channel_member::Column::ChannelId.eq(channel_id),
                    QueryMemberDetails::IsDirectMember,
                )
                .column(channel_member::Column::Accepted)
                // Sorting by user id keeps all of a user's rows adjacent, which the
                // de-duplication in the loop below relies on.
                .order_by_asc(channel_member::Column::UserId)
                .into_values::<_, QueryMemberDetails>()
                .stream(&*tx)
                .await?;

            let mut rows = Vec::<proto::ChannelMember>::new();
            while let Some(row) = stream.next().await {
                let (user_id, is_admin, is_direct_member, is_invite_accepted): (
                    UserId,
                    bool,
                    bool,
                    bool,
                ) = row?;
                let kind = match (is_direct_member, is_invite_accepted) {
                    (true, true) => proto::channel_member::Kind::Member,
                    (true, false) => proto::channel_member::Kind::Invitee,
                    (false, true) => proto::channel_member::Kind::AncestorMember,
                    // Pending invitation to an ancestor channel: not reported.
                    (false, false) => continue,
                };
                let user_id = user_id.to_proto();
                let kind = kind.into();
                // Same user as the previous entry: a direct-membership row overrides
                // the kind and admin flag already recorded.
                // NOTE(review): an admin flag that comes only from an ancestor row is
                // replaced by the direct row's flag — confirm this is intended.
                if let Some(last_row) = rows.last_mut() {
                    if last_row.user_id == user_id {
                        if is_direct_member {
                            last_row.kind = kind;
                            last_row.admin = is_admin;
                        }
                        continue;
                    }
                }
                rows.push(proto::ChannelMember {
                    user_id,
                    kind,
                    admin: is_admin,
                });
            }

            Ok(rows)
        })
        .await
    }
3624
3625 pub async fn get_channel_members_internal(
3626 &self,
3627 id: ChannelId,
3628 tx: &DatabaseTransaction,
3629 ) -> Result<Vec<UserId>> {
3630 let ancestor_ids = self.get_channel_ancestors(id, tx).await?;
3631 let user_ids = channel_member::Entity::find()
3632 .distinct()
3633 .filter(channel_member::Column::ChannelId.is_in(ancestor_ids.iter().copied()))
3634 .select_only()
3635 .column(channel_member::Column::UserId)
3636 .into_values::<_, QueryUserIds>()
3637 .all(&*tx)
3638 .await?;
3639 Ok(user_ids)
3640 }
3641
3642 async fn check_user_is_channel_member(
3643 &self,
3644 channel_id: ChannelId,
3645 user_id: UserId,
3646 tx: &DatabaseTransaction,
3647 ) -> Result<()> {
3648 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
3649 channel_member::Entity::find()
3650 .filter(
3651 channel_member::Column::ChannelId
3652 .is_in(channel_ids)
3653 .and(channel_member::Column::UserId.eq(user_id)),
3654 )
3655 .one(&*tx)
3656 .await?
3657 .ok_or_else(|| anyhow!("user is not a channel member or channel does not exist"))?;
3658 Ok(())
3659 }
3660
3661 async fn check_user_is_channel_admin(
3662 &self,
3663 channel_id: ChannelId,
3664 user_id: UserId,
3665 tx: &DatabaseTransaction,
3666 ) -> Result<()> {
3667 let channel_ids = self.get_channel_ancestors(channel_id, tx).await?;
3668 channel_member::Entity::find()
3669 .filter(
3670 channel_member::Column::ChannelId
3671 .is_in(channel_ids)
3672 .and(channel_member::Column::UserId.eq(user_id))
3673 .and(channel_member::Column::Admin.eq(true)),
3674 )
3675 .one(&*tx)
3676 .await?
3677 .ok_or_else(|| anyhow!("user is not a channel admin or channel does not exist"))?;
3678 Ok(())
3679 }
3680
3681 async fn get_channel_ancestors(
3682 &self,
3683 channel_id: ChannelId,
3684 tx: &DatabaseTransaction,
3685 ) -> Result<Vec<ChannelId>> {
3686 let paths = channel_path::Entity::find()
3687 .filter(channel_path::Column::ChannelId.eq(channel_id))
3688 .all(tx)
3689 .await?;
3690 let mut channel_ids = Vec::new();
3691 for path in paths {
3692 for id in path.id_path.trim_matches('/').split('/') {
3693 if let Ok(id) = id.parse() {
3694 let id = ChannelId::from_proto(id);
3695 if let Err(ix) = channel_ids.binary_search(&id) {
3696 channel_ids.insert(ix, id);
3697 }
3698 }
3699 }
3700 }
3701 Ok(channel_ids)
3702 }
3703
    /// Returns every channel at or below any of `channel_ids`, mapped to its parent id.
    ///
    /// A stored path is selected when its `id_path` starts with the `id_path` of one
    /// of the given channels, so the given channels themselves are included.
    /// - The value is the id just before the channel in the selected path, or `None`
    ///   if the channel is first in it.
    /// - A given channel's parent may lie outside the returned set.
    /// - When a channel has several selected paths, the last row streamed wins.
    ///
    /// Returns an empty map when `channel_ids` is empty.
    async fn get_channel_descendants(
        &self,
        channel_ids: impl IntoIterator<Item = ChannelId>,
        tx: &DatabaseTransaction,
    ) -> Result<HashMap<ChannelId, Option<ChannelId>>> {
        // Build a `(1), (2), ...` list for the `IN` clause. The ids are formatted
        // from integers, so splicing them into the SQL text cannot inject anything.
        let mut values = String::new();
        for id in channel_ids {
            if !values.is_empty() {
                values.push_str(", ");
            }
            write!(&mut values, "({})", id).unwrap();
        }

        if values.is_empty() {
            return Ok(HashMap::default());
        }

        // Self-join: each descendant path starts with the path of a requested channel.
        let sql = format!(
            r#"
            SELECT
                descendant_paths.*
            FROM
                channel_paths parent_paths, channel_paths descendant_paths
            WHERE
                parent_paths.channel_id IN ({values}) AND
                descendant_paths.id_path LIKE (parent_paths.id_path || '%')
        "#
        );

        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);

        let mut parents_by_child_id = HashMap::default();
        let mut paths = channel_path::Entity::find()
            .from_raw_sql(stmt)
            .stream(tx)
            .await?;

        while let Some(path) = paths.next().await {
            let path = path?;
            let ids = path.id_path.trim_matches('/').split('/');
            // The parent is the last id seen before reaching the channel itself.
            let mut parent_id = None;
            for id in ids {
                if let Ok(id) = id.parse() {
                    let id = ChannelId::from_proto(id);
                    if id == path.channel_id {
                        break;
                    }
                    parent_id = Some(id);
                }
            }
            parents_by_child_id.insert(path.channel_id, parent_id);
        }

        Ok(parents_by_child_id)
    }
3759
3760 /// Returns the channel with the given ID and:
3761 /// - true if the user is a member
3762 /// - false if the user hasn't accepted the invitation yet
3763 pub async fn get_channel(
3764 &self,
3765 channel_id: ChannelId,
3766 user_id: UserId,
3767 ) -> Result<Option<(Channel, bool)>> {
3768 self.transaction(|tx| async move {
3769 let tx = tx;
3770
3771 let channel = channel::Entity::find_by_id(channel_id).one(&*tx).await?;
3772
3773 if let Some(channel) = channel {
3774 if self
3775 .check_user_is_channel_member(channel_id, user_id, &*tx)
3776 .await
3777 .is_err()
3778 {
3779 return Ok(None);
3780 }
3781
3782 let channel_membership = channel_member::Entity::find()
3783 .filter(
3784 channel_member::Column::ChannelId
3785 .eq(channel_id)
3786 .and(channel_member::Column::UserId.eq(user_id)),
3787 )
3788 .one(&*tx)
3789 .await?;
3790
3791 let is_accepted = channel_membership
3792 .map(|membership| membership.accepted)
3793 .unwrap_or(false);
3794
3795 Ok(Some((
3796 Channel {
3797 id: channel.id,
3798 name: channel.name,
3799 parent_id: None,
3800 },
3801 is_accepted,
3802 )))
3803 } else {
3804 Ok(None)
3805 }
3806 })
3807 .await
3808 }
3809
3810 pub async fn room_id_for_channel(&self, channel_id: ChannelId) -> Result<RoomId> {
3811 self.transaction(|tx| async move {
3812 let tx = tx;
3813 let room = channel::Model {
3814 id: channel_id,
3815 ..Default::default()
3816 }
3817 .find_related(room::Entity)
3818 .one(&*tx)
3819 .await?
3820 .ok_or_else(|| anyhow!("invalid channel"))?;
3821 Ok(room.id)
3822 })
3823 .await
3824 }
3825
    /// Runs `f` inside a fresh `SERIALIZABLE` transaction. The transaction is
    /// committed if `f` succeeds and rolled back if it fails.
    ///
    /// If `f` or the commit fails with a serialization error, the whole attempt is
    /// retried after a randomized, growing delay (see `retry_on_serialization_error`).
    /// That is why `f` is `Fn`: it may be invoked several times, and the loop puts no
    /// cap on the number of attempts. Other errors, including failures to begin or
    /// roll back, are returned to the caller.
    async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
    where
        F: Send + Fn(TransactionHandle) -> Fut,
        Fut: Send + Future<Output = Result<T>>,
    {
        let body = async {
            // Number of failed attempts so far; drives the retry backoff.
            let mut i = 0;
            loop {
                let (tx, result) = self.with_transaction(&f).await?;
                match result {
                    Ok(result) => match tx.commit().await.map_err(Into::into) {
                        Ok(()) => return Ok(result),
                        // `tx` is not touched again after a failed commit; the
                        // attempt restarts from scratch if the error is retryable.
                        Err(error) => {
                            if !self.retry_on_serialization_error(&error, i).await {
                                return Err(error);
                            }
                        }
                    },
                    Err(error) => {
                        tx.rollback().await?;
                        if !self.retry_on_serialization_error(&error, i).await {
                            return Err(error);
                        }
                    }
                }
                i += 1;
            }
        };

        self.run(body).await
    }
3857
    /// Like `transaction`, but `f` may also report which room it modified.
    ///
    /// - When `f` returns `Some((room_id, data))`, the per-room mutex for `room_id`
    ///   is acquired *before* committing. It stays locked inside the returned
    ///   `RoomGuard` until the caller drops it.
    /// - When `f` returns `None`, the transaction is committed and `None` is returned.
    ///
    /// Serialization failures are retried in the same way as in `transaction`.
    async fn optional_room_transaction<F, Fut, T>(&self, f: F) -> Result<Option<RoomGuard<T>>>
    where
        F: Send + Fn(TransactionHandle) -> Fut,
        Fut: Send + Future<Output = Result<Option<(RoomId, T)>>>,
    {
        let body = async {
            // Number of failed attempts so far; drives the retry backoff.
            let mut i = 0;
            loop {
                let (tx, result) = self.with_transaction(&f).await?;
                match result {
                    Ok(Some((room_id, data))) => {
                        // Lock the room before committing so the caller sees the
                        // committed state while still holding the room lock.
                        let lock = self.rooms.entry(room_id).or_default().clone();
                        let _guard = lock.lock_owned().await;
                        match tx.commit().await.map_err(Into::into) {
                            Ok(()) => {
                                return Ok(Some(RoomGuard {
                                    data,
                                    _guard,
                                    _not_send: PhantomData,
                                }));
                            }
                            // The room lock is released at the end of this attempt.
                            Err(error) => {
                                if !self.retry_on_serialization_error(&error, i).await {
                                    return Err(error);
                                }
                            }
                        }
                    }
                    Ok(None) => match tx.commit().await.map_err(Into::into) {
                        Ok(()) => return Ok(None),
                        Err(error) => {
                            if !self.retry_on_serialization_error(&error, i).await {
                                return Err(error);
                            }
                        }
                    },
                    Err(error) => {
                        tx.rollback().await?;
                        if !self.retry_on_serialization_error(&error, i).await {
                            return Err(error);
                        }
                    }
                }
                i += 1;
            }
        };

        self.run(body).await
    }
3907
    /// Like `transaction`, but serializes work on `room_id` through a per-room mutex.
    ///
    /// The room lock is taken before each attempt begins and held through `f` and
    /// the commit. On success it is handed to the caller inside the returned
    /// `RoomGuard`. When an attempt fails and is retried, the lock is released at the
    /// end of that attempt.
    async fn room_transaction<F, Fut, T>(&self, room_id: RoomId, f: F) -> Result<RoomGuard<T>>
    where
        F: Send + Fn(TransactionHandle) -> Fut,
        Fut: Send + Future<Output = Result<T>>,
    {
        let body = async {
            // Number of failed attempts so far; drives the retry backoff.
            let mut i = 0;
            loop {
                let lock = self.rooms.entry(room_id).or_default().clone();
                let _guard = lock.lock_owned().await;
                let (tx, result) = self.with_transaction(&f).await?;
                match result {
                    Ok(data) => match tx.commit().await.map_err(Into::into) {
                        Ok(()) => {
                            return Ok(RoomGuard {
                                data,
                                _guard,
                                _not_send: PhantomData,
                            });
                        }
                        Err(error) => {
                            if !self.retry_on_serialization_error(&error, i).await {
                                return Err(error);
                            }
                        }
                    },
                    Err(error) => {
                        tx.rollback().await?;
                        if !self.retry_on_serialization_error(&error, i).await {
                            return Err(error);
                        }
                    }
                }
                i += 1;
            }
        };

        self.run(body).await
    }
3947
3948 async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
3949 where
3950 F: Send + Fn(TransactionHandle) -> Fut,
3951 Fut: Send + Future<Output = Result<T>>,
3952 {
3953 let tx = self
3954 .pool
3955 .begin_with_config(Some(IsolationLevel::Serializable), None)
3956 .await?;
3957
3958 let mut tx = Arc::new(Some(tx));
3959 let result = f(TransactionHandle(tx.clone())).await;
3960 let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
3961 return Err(anyhow!("couldn't complete transaction because it's still in use"))?;
3962 };
3963
3964 Ok((tx, result))
3965 }
3966
    /// Drives `future` to completion.
    ///
    /// In tests, a deterministic executor first injects a random delay. The future
    /// is then run with `block_on` on the Tokio runtime stored in `self.runtime`,
    /// which must be set or this panics. Outside tests the future is simply awaited.
    async fn run<F, T>(&self, future: F) -> Result<T>
    where
        F: Future<Output = Result<T>>,
    {
        #[cfg(test)]
        {
            if let Executor::Deterministic(executor) = &self.executor {
                executor.simulate_random_delay().await;
            }

            // Blocks the current thread until the database work completes.
            self.runtime.as_ref().unwrap().block_on(future)
        }

        #[cfg(not(test))]
        {
            future.await
        }
    }
3985
3986 async fn retry_on_serialization_error(&self, error: &Error, prev_attempt_count: u32) -> bool {
3987 // If the error is due to a failure to serialize concurrent transactions, then retry
3988 // this transaction after a delay. With each subsequent retry, double the delay duration.
3989 // Also vary the delay randomly in order to ensure different database connections retry
3990 // at different times.
3991 if is_serialization_error(error) {
3992 let base_delay = 4_u64 << prev_attempt_count.min(16);
3993 let randomized_delay = base_delay as f32 * self.rng.lock().await.gen_range(0.5..=2.0);
3994 log::info!(
3995 "retrying transaction after serialization error. delay: {} ms.",
3996 randomized_delay
3997 );
3998 self.executor
3999 .sleep(Duration::from_millis(randomized_delay as u64))
4000 .await;
4001 true
4002 } else {
4003 false
4004 }
4005 }
4006}
4007
4008fn is_serialization_error(error: &Error) -> bool {
4009 const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
4010 match error {
4011 Error::Database(
4012 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
4013 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
4014 ) if error
4015 .as_database_error()
4016 .and_then(|error| error.code())
4017 .as_deref()
4018 == Some(SERIALIZATION_FAILURE_CODE) =>
4019 {
4020 true
4021 }
4022 _ => false,
4023 }
4024}
4025
4026struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
4027
4028impl Deref for TransactionHandle {
4029 type Target = DatabaseTransaction;
4030
4031 fn deref(&self) -> &Self::Target {
4032 self.0.as_ref().as_ref().unwrap()
4033 }
4034}
4035
/// A room transaction's result bundled with the guard of that room's mutex.
///
/// The room stays locked for as long as this value is alive. The
/// `PhantomData<Rc<()>>` marker makes the type `!Send`, so the guard cannot be
/// moved to another thread.
pub struct RoomGuard<T> {
    data: T,
    _guard: OwnedMutexGuard<()>,
    _not_send: PhantomData<Rc<()>>,
}

/// Gives read access to the wrapped result.
impl<T> Deref for RoomGuard<T> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.data
    }
}

/// Gives mutable access to the wrapped result.
impl<T> DerefMut for RoomGuard<T> {
    fn deref_mut(&mut self) -> &mut T {
        &mut self.data
    }
}
4055
/// Attributes supplied when creating a new user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}

/// Identifiers produced for a newly created user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}

/// A channel as returned by the channel queries.
///
/// `parent_id` is `None` unless the query fills it in from the channel hierarchy.
#[derive(FromQueryResult, Debug, PartialEq)]
pub struct Channel {
    pub id: ChannelId,
    pub name: String,
    pub parent_id: Option<ChannelId>,
}

/// Result of `Database::get_channels_for_user`.
#[derive(Debug, PartialEq)]
pub struct ChannelsForUser {
    /// Channels the user has accepted membership in, plus their descendants.
    pub channels: Vec<Channel>,
    /// User ids of the participants in each channel's room, keyed by channel.
    pub channel_participants: HashMap<ChannelId, Vec<UserId>>,
    /// Channels where the user's own membership row has `admin = true`.
    pub channels_with_admin_privileges: HashSet<ChannelId>,
}
4084
/// Generates a random 16-character code via `nanoid`.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}

/// Generates a random 64-character code via `nanoid`.
fn random_email_confirmation_code() -> String {
    nanoid::nanoid!(64)
}
4092
4093macro_rules! id_type {
4094 ($name:ident) => {
4095 #[derive(
4096 Clone,
4097 Copy,
4098 Debug,
4099 Default,
4100 PartialEq,
4101 Eq,
4102 PartialOrd,
4103 Ord,
4104 Hash,
4105 Serialize,
4106 Deserialize,
4107 )]
4108 #[serde(transparent)]
4109 pub struct $name(pub i32);
4110
4111 impl $name {
4112 #[allow(unused)]
4113 pub const MAX: Self = Self(i32::MAX);
4114
4115 #[allow(unused)]
4116 pub fn from_proto(value: u64) -> Self {
4117 Self(value as i32)
4118 }
4119
4120 #[allow(unused)]
4121 pub fn to_proto(self) -> u64 {
4122 self.0 as u64
4123 }
4124 }
4125
4126 impl std::fmt::Display for $name {
4127 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
4128 self.0.fmt(f)
4129 }
4130 }
4131
4132 impl From<$name> for sea_query::Value {
4133 fn from(value: $name) -> Self {
4134 sea_query::Value::Int(Some(value.0))
4135 }
4136 }
4137
4138 impl sea_orm::TryGetable for $name {
4139 fn try_get(
4140 res: &sea_orm::QueryResult,
4141 pre: &str,
4142 col: &str,
4143 ) -> Result<Self, sea_orm::TryGetError> {
4144 Ok(Self(i32::try_get(res, pre, col)?))
4145 }
4146 }
4147
4148 impl sea_query::ValueType for $name {
4149 fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
4150 match v {
4151 Value::TinyInt(Some(int)) => {
4152 Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
4153 }
4154 Value::SmallInt(Some(int)) => {
4155 Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
4156 }
4157 Value::Int(Some(int)) => {
4158 Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
4159 }
4160 Value::BigInt(Some(int)) => {
4161 Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
4162 }
4163 Value::TinyUnsigned(Some(int)) => {
4164 Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
4165 }
4166 Value::SmallUnsigned(Some(int)) => {
4167 Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
4168 }
4169 Value::Unsigned(Some(int)) => {
4170 Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
4171 }
4172 Value::BigUnsigned(Some(int)) => {
4173 Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
4174 }
4175 _ => Err(sea_query::ValueTypeErr),
4176 }
4177 }
4178
4179 fn type_name() -> String {
4180 stringify!($name).into()
4181 }
4182
4183 fn array_type() -> sea_query::ArrayType {
4184 sea_query::ArrayType::Int
4185 }
4186
4187 fn column_type() -> sea_query::ColumnType {
4188 sea_query::ColumnType::Integer(None)
4189 }
4190 }
4191
4192 impl sea_orm::TryFromU64 for $name {
4193 fn try_from_u64(n: u64) -> Result<Self, DbErr> {
4194 Ok(Self(n.try_into().map_err(|_| {
4195 DbErr::ConvertFromU64(concat!(
4196 "error converting ",
4197 stringify!($name),
4198 " to u64"
4199 ))
4200 })?))
4201 }
4202 }
4203
4204 impl sea_query::Nullable for $name {
4205 fn null() -> Value {
4206 Value::Int(None)
4207 }
4208 }
4209 };
4210}
4211
4212id_type!(AccessTokenId);
4213id_type!(ChannelId);
4214id_type!(ChannelMemberId);
4215id_type!(ContactId);
4216id_type!(FollowerId);
4217id_type!(RoomId);
4218id_type!(RoomParticipantId);
4219id_type!(ProjectId);
4220id_type!(ProjectCollaboratorId);
4221id_type!(ReplicaId);
4222id_type!(ServerId);
4223id_type!(SignupId);
4224id_type!(UserId);
4225
/// Room state together with its channel, if the room belongs to one.
#[derive(Clone)]
pub struct JoinRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}

/// Room state plus the projects that were rejoined or reshared.
pub struct RejoinedRoom {
    pub room: proto::Room,
    pub rejoined_projects: Vec<RejoinedProject>,
    pub reshared_projects: Vec<ResharedProject>,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
}

/// A project re-shared from a new connection, with the connection it replaced.
pub struct ResharedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<proto::WorktreeMetadata>,
}

/// A project rejoined from a new connection, with the connection it replaced.
pub struct RejoinedProject {
    pub id: ProjectId,
    pub old_connection_id: ConnectionId,
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: Vec<RejoinedWorktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

/// Worktree state for a rejoined project, expressed as updated and removed
/// entries and repositories.
#[derive(Debug)]
pub struct RejoinedWorktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub updated_entries: Vec<proto::Entry>,
    pub removed_entries: Vec<u64>,
    pub updated_repositories: Vec<proto::RepositoryEntry>,
    pub removed_repositories: Vec<u64>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}

/// Room state after a participant left, with the projects and calls affected.
pub struct LeftRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub left_projects: HashMap<ProjectId, LeftProject>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
    pub deleted: bool,
}

/// Room state after a refresh, with stale participants and canceled calls.
pub struct RefreshedRoom {
    pub room: proto::Room,
    pub channel_id: Option<ChannelId>,
    pub channel_members: Vec<UserId>,
    pub stale_participant_user_ids: Vec<UserId>,
    pub canceled_calls_to_user_ids: Vec<UserId>,
}

/// A project's collaborators, worktrees (keyed by worktree id) and language servers.
pub struct Project {
    pub collaborators: Vec<ProjectCollaborator>,
    pub worktrees: BTreeMap<u64, Worktree>,
    pub language_servers: Vec<proto::LanguageServer>,
}

/// A participant in a project, identified by connection and user.
pub struct ProjectCollaborator {
    pub connection_id: ConnectionId,
    pub user_id: UserId,
    pub replica_id: ReplicaId,
    pub is_host: bool,
}

impl ProjectCollaborator {
    /// Converts to the protobuf form. `is_host` is not included in the message.
    pub fn to_proto(&self) -> proto::Collaborator {
        proto::Collaborator {
            peer_id: Some(self.connection_id.into()),
            replica_id: self.replica_id.0 as u32,
            user_id: self.user_id.to_proto(),
        }
    }
}

/// A project a connection left, with its host and the remaining connection ids.
#[derive(Debug)]
pub struct LeftProject {
    pub id: ProjectId,
    pub host_user_id: UserId,
    pub host_connection_id: ConnectionId,
    pub connection_ids: Vec<ConnectionId>,
}

/// Full stored state of a worktree.
pub struct Worktree {
    pub id: u64,
    pub abs_path: String,
    pub root_name: String,
    pub visible: bool,
    pub entries: Vec<proto::Entry>,
    pub repository_entries: BTreeMap<u64, proto::RepositoryEntry>,
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    pub settings_files: Vec<WorktreeSettingsFile>,
    pub scan_id: u64,
    pub completed_scan_id: u64,
}

/// Path and contents of a settings file stored for a worktree.
#[derive(Debug)]
pub struct WorktreeSettingsFile {
    pub path: String,
    pub content: String,
}

/// Column selector used with `into_values` to read a single `UserId` column.
#[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
enum QueryUserIds {
    UserId,
}
4343
#[cfg(test)]
pub use test::*;

/// Test-only helpers for creating throwaway databases.
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use parking_lot::Mutex;
    use sea_orm::ConnectionTrait;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A `Database` backed by a temporary SQLite or Postgres database.
    ///
    /// For Postgres, the database is dropped when this value is dropped.
    pub struct TestDb {
        pub db: Option<Arc<Database>>,
        pub connection: Option<sqlx::AnyConnection>,
    }

    impl TestDb {
        /// Creates an in-memory SQLite database initialized from the SQLite test schema.
        pub fn sqlite(background: Arc<Background>) -> Self {
            let url = "sqlite::memory:".to_string();
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let mut db = runtime.block_on(async {
                let mut options = ConnectOptions::new(url);
                options.max_connections(5);
                let db = Database::new(options, Executor::Deterministic(background))
                    .await
                    .unwrap();
                let sql = include_str!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/migrations.sqlite/20221109000000_test_schema.sql"
                ));
                db.pool
                    .execute(sea_orm::Statement::from_string(
                        db.pool.get_database_backend(),
                        sql.into(),
                    ))
                    .await
                    .unwrap();
                db
            });

            // `Database::run` drives queries on this runtime in tests.
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                connection: None,
            }
        }

        /// Creates a uniquely named local Postgres database and runs all migrations on it.
        pub fn postgres(background: Arc<Background>) -> Self {
            // Serializes database creation across tests in this process.
            static LOCK: Mutex<()> = Mutex::new(());

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let mut options = ConnectOptions::new(url);
                options
                    .max_connections(5)
                    .idle_timeout(Duration::from_secs(0));
                let db = Database::new(options, Executor::Deterministic(background))
                    .await
                    .unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            // `Database::run` drives queries on this runtime in tests.
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                connection: None,
            }
        }

        /// Returns the wrapped database.
        pub fn db(&self) -> &Arc<Database> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        /// For Postgres, terminates other connections to the test database and then
        /// drops it. Failures are logged rather than raised.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
                db.runtime.as_ref().unwrap().block_on(async {
                    use util::ResultExt;
                    let query = "
                        SELECT pg_terminate_backend(pg_stat_activity.pid)
                        FROM pg_stat_activity
                        WHERE
                            pg_stat_activity.datname = current_database() AND
                            pid <> pg_backend_pid();
                    ";
                    db.pool
                        .execute(sea_orm::Statement::from_string(
                            db.pool.get_database_backend(),
                            query.into(),
                        ))
                        .await
                        .log_err();
                    sqlx::Postgres::drop_database(db.options.get_url())
                        .await
                        .log_err();
                })
            }
        }
    }
}