1mod access_token;
2mod contact;
3mod project;
4mod project_collaborator;
5mod room;
6mod room_participant;
7#[cfg(test)]
8mod tests;
9mod user;
10mod worktree;
11
12use crate::{Error, Result};
13use anyhow::anyhow;
14use collections::HashMap;
15use dashmap::DashMap;
16use futures::StreamExt;
17use rpc::{proto, ConnectionId};
18use sea_orm::{
19 entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
20 TransactionTrait,
21};
22use sea_orm::{
23 ActiveValue, ConnectionTrait, FromQueryResult, IntoActiveModel, JoinType, QueryOrder,
24 QuerySelect,
25};
26use sea_query::{Alias, Expr, OnConflict, Query};
27use serde::{Deserialize, Serialize};
28use sqlx::migrate::{Migrate, Migration, MigrationSource};
29use sqlx::Connection;
30use std::ops::{Deref, DerefMut};
31use std::path::Path;
32use std::time::Duration;
33use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
34use tokio::sync::{Mutex, OwnedMutexGuard};
35
36pub use contact::Contact;
37pub use user::Model as User;
38
/// Handle to the collab server's database.
///
/// Most query methods run their work through `transact`, which opens a
/// transaction on `pool` (switched to SERIALIZABLE on Postgres) and retries it
/// on serialization failures.
///
/// NOTE(review): field declaration order is also drop order, so `pool` is
/// dropped before the test-only `runtime`; keep that in mind before reordering.
pub struct Database {
    /// Options the pool was opened with; `migrate` reuses their URL to open a
    /// separate sqlx connection.
    options: ConnectOptions,
    /// sea-orm connection pool used for queries and transactions.
    pool: DatabaseConnection,
    /// One async mutex per room, created on demand by `commit_room_transaction`,
    /// which holds it while committing a room-changing transaction.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// Test-only executor used by `transact` to inject a random delay.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only tokio runtime on which `transact` drives its body via `block_on`.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
48
49impl Database {
50 pub async fn new(options: ConnectOptions) -> Result<Self> {
51 Ok(Self {
52 options: options.clone(),
53 pool: sea_orm::Database::connect(options).await?,
54 rooms: DashMap::with_capacity(16384),
55 #[cfg(test)]
56 background: None,
57 #[cfg(test)]
58 runtime: None,
59 })
60 }
61
62 pub async fn migrate(
63 &self,
64 migrations_path: &Path,
65 ignore_checksum_mismatch: bool,
66 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
67 let migrations = MigrationSource::resolve(migrations_path)
68 .await
69 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
70
71 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
72
73 connection.ensure_migrations_table().await?;
74 let applied_migrations: HashMap<_, _> = connection
75 .list_applied_migrations()
76 .await?
77 .into_iter()
78 .map(|m| (m.version, m))
79 .collect();
80
81 let mut new_migrations = Vec::new();
82 for migration in migrations {
83 match applied_migrations.get(&migration.version) {
84 Some(applied_migration) => {
85 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
86 {
87 Err(anyhow!(
88 "checksum mismatch for applied migration {}",
89 migration.description
90 ))?;
91 }
92 }
93 None => {
94 let elapsed = connection.apply(&migration).await?;
95 new_migrations.push((migration, elapsed));
96 }
97 }
98 }
99
100 Ok(new_migrations)
101 }
102
103 // users
104
105 pub async fn create_user(
106 &self,
107 email_address: &str,
108 admin: bool,
109 params: NewUserParams,
110 ) -> Result<NewUserResult> {
111 self.transact(|tx| async {
112 let user = user::Entity::insert(user::ActiveModel {
113 email_address: ActiveValue::set(Some(email_address.into())),
114 github_login: ActiveValue::set(params.github_login.clone()),
115 github_user_id: ActiveValue::set(Some(params.github_user_id)),
116 admin: ActiveValue::set(admin),
117 metrics_id: ActiveValue::set(Uuid::new_v4()),
118 ..Default::default()
119 })
120 .on_conflict(
121 OnConflict::column(user::Column::GithubLogin)
122 .update_column(user::Column::GithubLogin)
123 .to_owned(),
124 )
125 .exec_with_returning(&tx)
126 .await?;
127
128 tx.commit().await?;
129
130 Ok(NewUserResult {
131 user_id: user.id,
132 metrics_id: user.metrics_id.to_string(),
133 signup_device_id: None,
134 inviting_user_id: None,
135 })
136 })
137 .await
138 }
139
140 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
141 self.transact(|tx| async {
142 let tx = tx;
143 Ok(user::Entity::find()
144 .filter(user::Column::Id.is_in(ids.iter().copied()))
145 .all(&tx)
146 .await?)
147 })
148 .await
149 }
150
151 pub async fn get_user_by_github_account(
152 &self,
153 github_login: &str,
154 github_user_id: Option<i32>,
155 ) -> Result<Option<User>> {
156 self.transact(|tx| async {
157 let tx = tx;
158 if let Some(github_user_id) = github_user_id {
159 if let Some(user_by_github_user_id) = user::Entity::find()
160 .filter(user::Column::GithubUserId.eq(github_user_id))
161 .one(&tx)
162 .await?
163 {
164 let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
165 user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
166 Ok(Some(user_by_github_user_id.update(&tx).await?))
167 } else if let Some(user_by_github_login) = user::Entity::find()
168 .filter(user::Column::GithubLogin.eq(github_login))
169 .one(&tx)
170 .await?
171 {
172 let mut user_by_github_login = user_by_github_login.into_active_model();
173 user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
174 Ok(Some(user_by_github_login.update(&tx).await?))
175 } else {
176 Ok(None)
177 }
178 } else {
179 Ok(user::Entity::find()
180 .filter(user::Column::GithubLogin.eq(github_login))
181 .one(&tx)
182 .await?)
183 }
184 })
185 .await
186 }
187
188 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
189 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
190 enum QueryAs {
191 MetricsId,
192 }
193
194 self.transact(|tx| async move {
195 let metrics_id: Uuid = user::Entity::find_by_id(id)
196 .select_only()
197 .column(user::Column::MetricsId)
198 .into_values::<_, QueryAs>()
199 .one(&tx)
200 .await?
201 .ok_or_else(|| anyhow!("could not find user"))?;
202 Ok(metrics_id.to_string())
203 })
204 .await
205 }
206
207 // contacts
208
    /// Returns every contact row involving `user_id`, from that user's point of view.
    ///
    /// Each row is classified as `Accepted`, `Outgoing` or `Incoming` relative to
    /// `user_id`, using `accepted` and `a_to_b`. Elsewhere in this file `a_to_b` is
    /// true when the smaller id (`user_id_a`) sent the request. For accepted
    /// contacts, `should_notify` is only surfaced to the user who sent the
    /// original request. The result is sorted by the contact's user id.
    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
        /// One contact row plus whether each side has a matching room participant row.
        #[derive(Debug, FromQueryResult)]
        struct ContactWithUserBusyStatuses {
            user_id_a: UserId,
            user_id_b: UserId,
            a_to_b: bool,
            accepted: bool,
            should_notify: bool,
            user_a_busy: bool,
            user_b_busy: bool,
        }

        self.transact(|tx| async move {
            // The room_participant table is joined twice (once per side of the
            // contact), so each join needs its own alias.
            let user_a_participant = Alias::new("user_a_participant");
            let user_b_participant = Alias::new("user_b_participant");
            // "busy" is true when the LEFT JOIN found a room_participant row for
            // that side (presumably: the user is currently in a room — confirm
            // against the relation definitions in `contact`).
            let mut db_contacts = contact::Entity::find()
                .column_as(
                    Expr::tbl(user_a_participant.clone(), room_participant::Column::Id)
                        .is_not_null(),
                    "user_a_busy",
                )
                .column_as(
                    Expr::tbl(user_b_participant.clone(), room_participant::Column::Id)
                        .is_not_null(),
                    "user_b_busy",
                )
                .filter(
                    contact::Column::UserIdA
                        .eq(user_id)
                        .or(contact::Column::UserIdB.eq(user_id)),
                )
                .join_as(
                    JoinType::LeftJoin,
                    contact::Relation::UserARoomParticipant.def(),
                    user_a_participant,
                )
                .join_as(
                    JoinType::LeftJoin,
                    contact::Relation::UserBRoomParticipant.def(),
                    user_b_participant,
                )
                .into_model::<ContactWithUserBusyStatuses>()
                .stream(&tx)
                .await?;

            let mut contacts = Vec::new();
            while let Some(db_contact) = db_contacts.next().await {
                let db_contact = db_contact?;
                if db_contact.user_id_a == user_id {
                    // `user_id` is side A; the other user is B.
                    if db_contact.accepted {
                        contacts.push(Contact::Accepted {
                            user_id: db_contact.user_id_b,
                            // Only notify A if A sent the (now accepted) request.
                            should_notify: db_contact.should_notify && db_contact.a_to_b,
                            busy: db_contact.user_b_busy,
                        });
                    } else if db_contact.a_to_b {
                        contacts.push(Contact::Outgoing {
                            user_id: db_contact.user_id_b,
                        })
                    } else {
                        contacts.push(Contact::Incoming {
                            user_id: db_contact.user_id_b,
                            should_notify: db_contact.should_notify,
                        });
                    }
                } else if db_contact.accepted {
                    // `user_id` is side B; the other user is A.
                    contacts.push(Contact::Accepted {
                        user_id: db_contact.user_id_a,
                        // Only notify B if B sent the (now accepted) request.
                        should_notify: db_contact.should_notify && !db_contact.a_to_b,
                        busy: db_contact.user_a_busy,
                    });
                } else if db_contact.a_to_b {
                    contacts.push(Contact::Incoming {
                        user_id: db_contact.user_id_a,
                        should_notify: db_contact.should_notify,
                    });
                } else {
                    contacts.push(Contact::Outgoing {
                        user_id: db_contact.user_id_a,
                    });
                }
            }

            contacts.sort_unstable_by_key(|contact| contact.user_id());

            Ok(contacts)
        })
        .await
    }
298
299 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
300 self.transact(|tx| async move {
301 let (id_a, id_b) = if user_id_1 < user_id_2 {
302 (user_id_1, user_id_2)
303 } else {
304 (user_id_2, user_id_1)
305 };
306
307 Ok(contact::Entity::find()
308 .filter(
309 contact::Column::UserIdA
310 .eq(id_a)
311 .and(contact::Column::UserIdB.eq(id_b))
312 .and(contact::Column::Accepted.eq(true)),
313 )
314 .one(&tx)
315 .await?
316 .is_some())
317 })
318 .await
319 }
320
321 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
322 self.transact(|mut tx| async move {
323 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
324 (sender_id, receiver_id, true)
325 } else {
326 (receiver_id, sender_id, false)
327 };
328
329 let rows_affected = contact::Entity::insert(contact::ActiveModel {
330 user_id_a: ActiveValue::set(id_a),
331 user_id_b: ActiveValue::set(id_b),
332 a_to_b: ActiveValue::set(a_to_b),
333 accepted: ActiveValue::set(false),
334 should_notify: ActiveValue::set(true),
335 ..Default::default()
336 })
337 .on_conflict(
338 OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB])
339 .values([
340 (contact::Column::Accepted, true.into()),
341 (contact::Column::ShouldNotify, false.into()),
342 ])
343 .action_and_where(
344 contact::Column::Accepted.eq(false).and(
345 contact::Column::AToB
346 .eq(a_to_b)
347 .and(contact::Column::UserIdA.eq(id_b))
348 .or(contact::Column::AToB
349 .ne(a_to_b)
350 .and(contact::Column::UserIdA.eq(id_a))),
351 ),
352 )
353 .to_owned(),
354 )
355 .exec_without_returning(&tx)
356 .await?;
357
358 if rows_affected == 1 {
359 tx.commit().await?;
360 Ok(())
361 } else {
362 Err(anyhow!("contact already requested"))?
363 }
364 })
365 .await
366 }
367
368 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
369 self.transact(|mut tx| async move {
370 // let (id_a, id_b) = if responder_id < requester_id {
371 // (responder_id, requester_id)
372 // } else {
373 // (requester_id, responder_id)
374 // };
375 // let query = "
376 // DELETE FROM contacts
377 // WHERE user_id_a = $1 AND user_id_b = $2;
378 // ";
379 // let result = sqlx::query(query)
380 // .bind(id_a.0)
381 // .bind(id_b.0)
382 // .execute(&mut tx)
383 // .await?;
384
385 // if result.rows_affected() == 1 {
386 // tx.commit().await?;
387 // Ok(())
388 // } else {
389 // Err(anyhow!("no such contact"))?
390 // }
391 todo!()
392 })
393 .await
394 }
395
396 pub async fn dismiss_contact_notification(
397 &self,
398 user_id: UserId,
399 contact_user_id: UserId,
400 ) -> Result<()> {
401 self.transact(|tx| async move {
402 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
403 (user_id, contact_user_id, true)
404 } else {
405 (contact_user_id, user_id, false)
406 };
407
408 let result = contact::Entity::update_many()
409 .set(contact::ActiveModel {
410 should_notify: ActiveValue::set(false),
411 ..Default::default()
412 })
413 .filter(
414 contact::Column::UserIdA
415 .eq(id_a)
416 .and(contact::Column::UserIdB.eq(id_b))
417 .and(
418 contact::Column::AToB
419 .eq(a_to_b)
420 .and(contact::Column::Accepted.eq(true))
421 .or(contact::Column::AToB
422 .ne(a_to_b)
423 .and(contact::Column::Accepted.eq(false))),
424 ),
425 )
426 .exec(&tx)
427 .await?;
428 if result.rows_affected == 0 {
429 Err(anyhow!("no such contact request"))?
430 } else {
431 tx.commit().await?;
432 Ok(())
433 }
434 })
435 .await
436 }
437
438 pub async fn respond_to_contact_request(
439 &self,
440 responder_id: UserId,
441 requester_id: UserId,
442 accept: bool,
443 ) -> Result<()> {
444 self.transact(|tx| async move {
445 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
446 (responder_id, requester_id, false)
447 } else {
448 (requester_id, responder_id, true)
449 };
450 let rows_affected = if accept {
451 let result = contact::Entity::update_many()
452 .set(contact::ActiveModel {
453 accepted: ActiveValue::set(true),
454 should_notify: ActiveValue::set(true),
455 ..Default::default()
456 })
457 .filter(
458 contact::Column::UserIdA
459 .eq(id_a)
460 .and(contact::Column::UserIdB.eq(id_b))
461 .and(contact::Column::AToB.eq(a_to_b)),
462 )
463 .exec(&tx)
464 .await?;
465 result.rows_affected
466 } else {
467 let result = contact::Entity::delete_many()
468 .filter(
469 contact::Column::UserIdA
470 .eq(id_a)
471 .and(contact::Column::UserIdB.eq(id_b))
472 .and(contact::Column::AToB.eq(a_to_b))
473 .and(contact::Column::Accepted.eq(false)),
474 )
475 .exec(&tx)
476 .await?;
477
478 result.rows_affected
479 };
480
481 if rows_affected == 1 {
482 tx.commit().await?;
483 Ok(())
484 } else {
485 Err(anyhow!("no such contact request"))?
486 }
487 })
488 .await
489 }
490
491 // projects
492
493 pub async fn share_project(
494 &self,
495 room_id: RoomId,
496 connection_id: ConnectionId,
497 worktrees: &[proto::WorktreeMetadata],
498 ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
499 self.transact(|tx| async move {
500 let participant = room_participant::Entity::find()
501 .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
502 .one(&tx)
503 .await?
504 .ok_or_else(|| anyhow!("could not find participant"))?;
505 if participant.room_id != room_id {
506 return Err(anyhow!("shared project on unexpected room"))?;
507 }
508
509 let project = project::ActiveModel {
510 room_id: ActiveValue::set(participant.room_id),
511 host_user_id: ActiveValue::set(participant.user_id),
512 host_connection_id: ActiveValue::set(connection_id.0 as i32),
513 ..Default::default()
514 }
515 .insert(&tx)
516 .await?;
517
518 worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel {
519 id: ActiveValue::set(worktree.id as i32),
520 project_id: ActiveValue::set(project.id),
521 abs_path: ActiveValue::set(worktree.abs_path.clone()),
522 root_name: ActiveValue::set(worktree.root_name.clone()),
523 visible: ActiveValue::set(worktree.visible),
524 scan_id: ActiveValue::set(0),
525 is_complete: ActiveValue::set(false),
526 }))
527 .exec(&tx)
528 .await?;
529
530 project_collaborator::ActiveModel {
531 project_id: ActiveValue::set(project.id),
532 connection_id: ActiveValue::set(connection_id.0 as i32),
533 user_id: ActiveValue::set(participant.user_id),
534 replica_id: ActiveValue::set(0),
535 is_host: ActiveValue::set(true),
536 ..Default::default()
537 }
538 .insert(&tx)
539 .await?;
540
541 let room = self.get_room(room_id, &tx).await?;
542 self.commit_room_transaction(room_id, tx, (project.id, room))
543 .await
544 })
545 .await
546 }
547
    /// Builds the `proto::Room` snapshot for `room_id` using the open transaction `tx`.
    ///
    /// Participants with an answering connection become joined participants,
    /// keyed by that connection id. All others become pending participants.
    /// Projects are attached to the joined participant whose connection id
    /// equals the project's `host_connection_id`; projects whose host is not a
    /// joined participant are left out. Fails with "could not find room" if the
    /// room does not exist.
    async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
        let db_room = room::Entity::find_by_id(room_id)
            .one(tx)
            .await?
            .ok_or_else(|| anyhow!("could not find room"))?;

        let mut db_participants = db_room
            .find_related(room_participant::Entity)
            .stream(tx)
            .await?;
        let mut participants = HashMap::default();
        let mut pending_participants = Vec::new();
        while let Some(db_participant) = db_participants.next().await {
            let db_participant = db_participant?;
            if let Some(answering_connection_id) = db_participant.answering_connection_id {
                // location_kind: 0 = shared project (needs a project id),
                // 1 = unshared project, anything else (including NULL, or 0
                // without a project id) = external.
                let location = match (
                    db_participant.location_kind,
                    db_participant.location_project_id,
                ) {
                    (Some(0), Some(project_id)) => {
                        Some(proto::participant_location::Variant::SharedProject(
                            proto::participant_location::SharedProject {
                                id: project_id.to_proto(),
                            },
                        ))
                    }
                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
                        Default::default(),
                    )),
                    _ => Some(proto::participant_location::Variant::External(
                        Default::default(),
                    )),
                };
                participants.insert(
                    answering_connection_id,
                    proto::Participant {
                        user_id: db_participant.user_id.to_proto(),
                        // The answering connection id doubles as the peer id.
                        peer_id: answering_connection_id as u32,
                        projects: Default::default(),
                        location: Some(proto::ParticipantLocation { variant: location }),
                    },
                );
            } else {
                // No answering connection yet: the user has been called but has
                // not joined.
                pending_participants.push(proto::PendingParticipant {
                    user_id: db_participant.user_id.to_proto(),
                    calling_user_id: db_participant.calling_user_id.to_proto(),
                    initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()),
                });
            }
        }

        // Each streamed row pairs a project with at most one worktree, so the
        // same project can appear on several rows; hence the find-or-push below.
        let mut db_projects = db_room
            .find_related(project::Entity)
            .find_with_related(worktree::Entity)
            .stream(tx)
            .await?;

        while let Some(row) = db_projects.next().await {
            let (db_project, db_worktree) = row?;
            if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
                let project = if let Some(project) = participant
                    .projects
                    .iter_mut()
                    .find(|project| project.id == db_project.id.to_proto())
                {
                    project
                } else {
                    participant.projects.push(proto::ParticipantProject {
                        id: db_project.id.to_proto(),
                        worktree_root_names: Default::default(),
                    });
                    participant.projects.last_mut().unwrap()
                };

                if let Some(db_worktree) = db_worktree {
                    project.worktree_root_names.push(db_worktree.root_name);
                }
            }
        }

        // Participant order follows the map's iteration order.
        Ok(proto::Room {
            id: db_room.id.to_proto(),
            live_kit_room: db_room.live_kit_room,
            participants: participants.into_values().collect(),
            pending_participants,
        })
    }
635
636 async fn commit_room_transaction<T>(
637 &self,
638 room_id: RoomId,
639 tx: DatabaseTransaction,
640 data: T,
641 ) -> Result<RoomGuard<T>> {
642 let lock = self.rooms.entry(room_id).or_default().clone();
643 let _guard = lock.lock_owned().await;
644 tx.commit().await?;
645 Ok(RoomGuard {
646 data,
647 _guard,
648 _not_send: PhantomData,
649 })
650 }
651
652 pub async fn create_access_token_hash(
653 &self,
654 user_id: UserId,
655 access_token_hash: &str,
656 max_access_token_count: usize,
657 ) -> Result<()> {
658 self.transact(|tx| async {
659 let tx = tx;
660
661 access_token::ActiveModel {
662 user_id: ActiveValue::set(user_id),
663 hash: ActiveValue::set(access_token_hash.into()),
664 ..Default::default()
665 }
666 .insert(&tx)
667 .await?;
668
669 access_token::Entity::delete_many()
670 .filter(
671 access_token::Column::Id.in_subquery(
672 Query::select()
673 .column(access_token::Column::Id)
674 .from(access_token::Entity)
675 .and_where(access_token::Column::UserId.eq(user_id))
676 .order_by(access_token::Column::Id, sea_orm::Order::Desc)
677 .limit(10000)
678 .offset(max_access_token_count as u64)
679 .to_owned(),
680 ),
681 )
682 .exec(&tx)
683 .await?;
684 tx.commit().await?;
685 Ok(())
686 })
687 .await
688 }
689
690 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
691 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
692 enum QueryAs {
693 Hash,
694 }
695
696 self.transact(|tx| async move {
697 Ok(access_token::Entity::find()
698 .select_only()
699 .column(access_token::Column::Hash)
700 .filter(access_token::Column::UserId.eq(user_id))
701 .order_by_desc(access_token::Column::Id)
702 .into_values::<_, QueryAs>()
703 .all(&tx)
704 .await?)
705 })
706 .await
707 }
708
709 async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
710 where
711 F: Send + Fn(DatabaseTransaction) -> Fut,
712 Fut: Send + Future<Output = Result<T>>,
713 {
714 let body = async {
715 loop {
716 let tx = self.pool.begin().await?;
717
718 // In Postgres, serializable transactions are opt-in
719 if let sea_orm::DatabaseBackend::Postgres = self.pool.get_database_backend() {
720 tx.execute(sea_orm::Statement::from_string(
721 sea_orm::DatabaseBackend::Postgres,
722 "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(),
723 ))
724 .await?;
725 }
726
727 match f(tx).await {
728 Ok(result) => return Ok(result),
729 Err(error) => match error {
730 Error::Database2(
731 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
732 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
733 ) if error
734 .as_database_error()
735 .and_then(|error| error.code())
736 .as_deref()
737 == Some("40001") =>
738 {
739 // Retry (don't break the loop)
740 }
741 error @ _ => return Err(error),
742 },
743 }
744 }
745 };
746
747 #[cfg(test)]
748 {
749 if let Some(background) = self.background.as_ref() {
750 background.simulate_random_delay().await;
751 }
752
753 self.runtime.as_ref().unwrap().block_on(body)
754 }
755
756 #[cfg(not(test))]
757 {
758 body.await
759 }
760 }
761}
762
/// The result of a room transaction, bundled with that room's lock.
///
/// `commit_room_transaction` acquires the room's mutex before committing and
/// moves the owned guard in here, so the lock stays held until this value is
/// dropped.
pub struct RoomGuard<T> {
    /// The value produced by the transaction.
    data: T,
    /// Owned guard of the per-room mutex; released when the `RoomGuard` drops.
    _guard: OwnedMutexGuard<()>,
    /// `Rc` is `!Send`, which makes `RoomGuard` `!Send` too — presumably so the
    /// held room lock cannot be moved to another thread; confirm intent.
    _not_send: PhantomData<Rc<()>>,
}
768
769impl<T> Deref for RoomGuard<T> {
770 type Target = T;
771
772 fn deref(&self) -> &T {
773 &self.data
774 }
775}
776
777impl<T> DerefMut for RoomGuard<T> {
778 fn deref_mut(&mut self) -> &mut T {
779 &mut self.data
780 }
781}
782
/// GitHub account details used by `Database::create_user` to create a user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    /// GitHub login; `create_user` treats a conflicting login as an existing user.
    pub github_login: String,
    /// Numeric GitHub user id.
    pub github_user_id: i32,
    /// NOTE(review): not read by `create_user` in this file.
    pub invite_count: i32,
}
789
/// Outcome of `Database::create_user`.
#[derive(Debug)]
pub struct NewUserResult {
    /// Id of the created (or already existing) user.
    pub user_id: UserId,
    /// The user's metrics UUID rendered as a string.
    pub metrics_id: String,
    /// Always `None` when produced by `create_user` in this file.
    pub inviting_user_id: Option<UserId>,
    /// Always `None` when produced by `create_user` in this file.
    pub signup_device_id: Option<String>,
}
797
798fn random_invite_code() -> String {
799 nanoid::nanoid!(16)
800}
801
802fn random_email_confirmation_code() -> String {
803 nanoid::nanoid!(64)
804}
805
/// Declares `$name` as a newtype over an `i32` database id.
///
/// The generated type is usable with sqlx, serde, sea-query and sea-orm, and
/// converts to and from the `u64` ids used in protobuf messages.
macro_rules! id_type {
    ($name:ident) => {
        /// Strongly typed `i32` database id generated by `id_type!`.
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        // `#[allow(unused)]`: not every id type uses every helper.
        impl $name {
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Converts a proto id. NOTE(review): `as` truncates out-of-range values.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts to a proto id. NOTE(review): negative ids sign-extend.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }

        impl From<$name> for sea_query::Value {
            fn from(value: $name) -> Self {
                sea_query::Value::Int(Some(value.0))
            }
        }

        impl sea_orm::TryGetable for $name {
            fn try_get(
                res: &sea_orm::QueryResult,
                pre: &str,
                col: &str,
            ) -> Result<Self, sea_orm::TryGetError> {
                Ok(Self(i32::try_get(res, pre, col)?))
            }
        }

        impl sea_query::ValueType for $name {
            /// Accepts any non-null integer `Value` that fits in an `i32`.
            fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
                match v {
                    Value::TinyInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Int(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::TinyUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Unsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    _ => Err(sea_query::ValueTypeErr),
                }
            }

            fn type_name() -> String {
                stringify!($name).into()
            }

            fn array_type() -> sea_query::ArrayType {
                sea_query::ArrayType::Int
            }

            fn column_type() -> sea_query::ColumnType {
                sea_query::ColumnType::Integer(None)
            }
        }

        impl sea_orm::TryFromU64 for $name {
            fn try_from_u64(n: u64) -> Result<Self, DbErr> {
                Ok(Self(n.try_into().map_err(|_| {
                    // The conversion goes u64 -> $name, so the message says so.
                    DbErr::ConvertFromU64(concat!(
                        "error converting u64 to ",
                        stringify!($name)
                    ))
                })?))
            }
        }

        impl sea_query::Nullable for $name {
            fn null() -> Value {
                Value::Int(None)
            }
        }
    };
}
926
927id_type!(AccessTokenId);
928id_type!(ContactId);
929id_type!(UserId);
930id_type!(RoomId);
931id_type!(RoomParticipantId);
932id_type!(ProjectId);
933id_type!(ProjectCollaboratorId);
934id_type!(WorktreeId);
935
936#[cfg(test)]
937pub use test::*;
938
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sea_orm::ConnectionTrait;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A throwaway database for tests, backed either by in-memory SQLite or by
    /// a freshly created Postgres database.
    ///
    /// Dropping a Postgres-backed `TestDb` terminates the database's other
    /// connections and drops it; failures there are only logged.
    pub struct TestDb {
        pub db: Option<Arc<Database>>,
        /// NOTE(review): always `None` in this module.
        pub connection: Option<sqlx::AnyConnection>,
    }

    /// Builds the current-thread tokio runtime (IO and timers enabled) that a
    /// test database uses to drive its queries.
    fn build_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    impl TestDb {
        /// Creates an in-memory SQLite database initialized from the bundled
        /// `migrations.sqlite` test schema.
        pub fn sqlite(background: Arc<Background>) -> Self {
            let runtime = build_runtime();

            let db = runtime.block_on(async {
                let mut options = ConnectOptions::new("sqlite::memory:".to_string());
                options.max_connections(5);
                let db = Database::new(options).await.unwrap();
                let sql = include_str!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/migrations.sqlite/20221109000000_test_schema.sql"
                ));
                db.pool
                    .execute(sea_orm::Statement::from_string(
                        db.pool.get_database_backend(),
                        sql.into(),
                    ))
                    .await
                    .unwrap();
                db
            });

            Self::from_parts(db, background, runtime)
        }

        /// Creates a randomly named Postgres database on localhost and runs the
        /// real migrations against it.
        pub fn postgres(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Held for the rest of this function, so only one test at a time
            // creates and migrates a Postgres database.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = build_runtime();

            let db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let mut options = ConnectOptions::new(url);
                options
                    .max_connections(5)
                    .idle_timeout(Duration::from_secs(0));
                let db = Database::new(options).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            Self::from_parts(db, background, runtime)
        }

        /// Installs the test executor and runtime on `db` and wraps it in a `TestDb`.
        fn from_parts(
            mut db: Database,
            background: Arc<Background>,
            runtime: tokio::runtime::Runtime,
        ) -> Self {
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                connection: None,
            }
        }

        pub fn db(&self) -> &Arc<Database> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
                db.runtime.as_ref().unwrap().block_on(async {
                    use util::ResultExt;
                    // Kill every other session on this database so it can be dropped.
                    let query = "
                        SELECT pg_terminate_backend(pg_stat_activity.pid)
                        FROM pg_stat_activity
                        WHERE
                            pg_stat_activity.datname = current_database() AND
                            pid <> pg_backend_pid();
                    ";
                    db.pool
                        .execute(sea_orm::Statement::from_string(
                            db.pool.get_database_backend(),
                            query.into(),
                        ))
                        .await
                        .log_err();
                    sqlx::Postgres::drop_database(db.options.get_url())
                        .await
                        .log_err();
                })
            }
        }
    }
}