1mod access_token;
2mod contact;
3mod project;
4mod project_collaborator;
5mod room;
6mod room_participant;
7#[cfg(test)]
8mod tests;
9mod user;
10mod worktree;
11
12use crate::{Error, Result};
13use anyhow::anyhow;
14use collections::HashMap;
15use dashmap::DashMap;
16use futures::StreamExt;
17use rpc::{proto, ConnectionId};
18use sea_orm::{
19 entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
20 TransactionTrait,
21};
22use sea_orm::{
23 ActiveValue, ConnectionTrait, FromQueryResult, IntoActiveModel, JoinType, QueryOrder,
24 QuerySelect,
25};
26use sea_query::{Alias, Expr, OnConflict, Query};
27use serde::{Deserialize, Serialize};
28use sqlx::migrate::{Migrate, Migration, MigrationSource};
29use sqlx::Connection;
30use std::ops::{Deref, DerefMut};
31use std::path::Path;
32use std::time::Duration;
33use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
34use tokio::sync::{Mutex, OwnedMutexGuard};
35
36pub use contact::Contact;
37pub use user::Model as User;
38
/// Handle to the collaboration database.
///
/// Wraps a sea-orm connection pool together with the options used to open it (which
/// `migrate` reuses to open its own sqlx connection) and a per-room lock table that
/// `commit_room_transaction` uses to serialize room commits.
pub struct Database {
    // Options the pool was created with; `migrate` reads the URL from here.
    options: ConnectOptions,
    pool: DatabaseConnection,
    // One async mutex per room, created lazily by `commit_room_transaction`.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    // Test-only: executor used by `transact` to inject a random delay before each call.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    // Test-only: dedicated tokio runtime that `transact` blocks on.
    // NOTE(review): declared last, so it is dropped after `pool`; presumably that order
    // matters — confirm before reordering fields.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
48
49impl Database {
50 pub async fn new(options: ConnectOptions) -> Result<Self> {
51 Ok(Self {
52 options: options.clone(),
53 pool: sea_orm::Database::connect(options).await?,
54 rooms: DashMap::with_capacity(16384),
55 #[cfg(test)]
56 background: None,
57 #[cfg(test)]
58 runtime: None,
59 })
60 }
61
62 pub async fn migrate(
63 &self,
64 migrations_path: &Path,
65 ignore_checksum_mismatch: bool,
66 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
67 let migrations = MigrationSource::resolve(migrations_path)
68 .await
69 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
70
71 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
72
73 connection.ensure_migrations_table().await?;
74 let applied_migrations: HashMap<_, _> = connection
75 .list_applied_migrations()
76 .await?
77 .into_iter()
78 .map(|m| (m.version, m))
79 .collect();
80
81 let mut new_migrations = Vec::new();
82 for migration in migrations {
83 match applied_migrations.get(&migration.version) {
84 Some(applied_migration) => {
85 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
86 {
87 Err(anyhow!(
88 "checksum mismatch for applied migration {}",
89 migration.description
90 ))?;
91 }
92 }
93 None => {
94 let elapsed = connection.apply(&migration).await?;
95 new_migrations.push((migration, elapsed));
96 }
97 }
98 }
99
100 Ok(new_migrations)
101 }
102
103 // users
104
105 pub async fn create_user(
106 &self,
107 email_address: &str,
108 admin: bool,
109 params: NewUserParams,
110 ) -> Result<NewUserResult> {
111 self.transact(|tx| async {
112 let user = user::Entity::insert(user::ActiveModel {
113 email_address: ActiveValue::set(Some(email_address.into())),
114 github_login: ActiveValue::set(params.github_login.clone()),
115 github_user_id: ActiveValue::set(Some(params.github_user_id)),
116 admin: ActiveValue::set(admin),
117 metrics_id: ActiveValue::set(Uuid::new_v4()),
118 ..Default::default()
119 })
120 .on_conflict(
121 OnConflict::column(user::Column::GithubLogin)
122 .update_column(user::Column::GithubLogin)
123 .to_owned(),
124 )
125 .exec_with_returning(&tx)
126 .await?;
127
128 tx.commit().await?;
129
130 Ok(NewUserResult {
131 user_id: user.id,
132 metrics_id: user.metrics_id.to_string(),
133 signup_device_id: None,
134 inviting_user_id: None,
135 })
136 })
137 .await
138 }
139
140 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
141 self.transact(|tx| async {
142 let tx = tx;
143 Ok(user::Entity::find()
144 .filter(user::Column::Id.is_in(ids.iter().copied()))
145 .all(&tx)
146 .await?)
147 })
148 .await
149 }
150
151 pub async fn get_user_by_github_account(
152 &self,
153 github_login: &str,
154 github_user_id: Option<i32>,
155 ) -> Result<Option<User>> {
156 self.transact(|tx| async {
157 let tx = tx;
158 if let Some(github_user_id) = github_user_id {
159 if let Some(user_by_github_user_id) = user::Entity::find()
160 .filter(user::Column::GithubUserId.eq(github_user_id))
161 .one(&tx)
162 .await?
163 {
164 let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
165 user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
166 Ok(Some(user_by_github_user_id.update(&tx).await?))
167 } else if let Some(user_by_github_login) = user::Entity::find()
168 .filter(user::Column::GithubLogin.eq(github_login))
169 .one(&tx)
170 .await?
171 {
172 let mut user_by_github_login = user_by_github_login.into_active_model();
173 user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
174 Ok(Some(user_by_github_login.update(&tx).await?))
175 } else {
176 Ok(None)
177 }
178 } else {
179 Ok(user::Entity::find()
180 .filter(user::Column::GithubLogin.eq(github_login))
181 .one(&tx)
182 .await?)
183 }
184 })
185 .await
186 }
187
188 pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
189 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
190 enum QueryAs {
191 MetricsId,
192 }
193
194 self.transact(|tx| async move {
195 let metrics_id: Uuid = user::Entity::find_by_id(id)
196 .select_only()
197 .column(user::Column::MetricsId)
198 .into_values::<_, QueryAs>()
199 .one(&tx)
200 .await?
201 .ok_or_else(|| anyhow!("could not find user"))?;
202 Ok(metrics_id.to_string())
203 })
204 .await
205 }
206
207 // contacts
208
209 pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
210 #[derive(Debug, FromQueryResult)]
211 struct ContactWithUserBusyStatuses {
212 user_id_a: UserId,
213 user_id_b: UserId,
214 a_to_b: bool,
215 accepted: bool,
216 should_notify: bool,
217 user_a_busy: bool,
218 user_b_busy: bool,
219 }
220
221 self.transact(|tx| async move {
222 let user_a_participant = Alias::new("user_a_participant");
223 let user_b_participant = Alias::new("user_b_participant");
224 let mut db_contacts = contact::Entity::find()
225 .column_as(
226 Expr::tbl(user_a_participant.clone(), room_participant::Column::Id)
227 .is_not_null(),
228 "user_a_busy",
229 )
230 .column_as(
231 Expr::tbl(user_b_participant.clone(), room_participant::Column::Id)
232 .is_not_null(),
233 "user_b_busy",
234 )
235 .filter(
236 contact::Column::UserIdA
237 .eq(user_id)
238 .or(contact::Column::UserIdB.eq(user_id)),
239 )
240 .join_as(
241 JoinType::LeftJoin,
242 contact::Relation::UserARoomParticipant.def(),
243 user_a_participant,
244 )
245 .join_as(
246 JoinType::LeftJoin,
247 contact::Relation::UserBRoomParticipant.def(),
248 user_b_participant,
249 )
250 .into_model::<ContactWithUserBusyStatuses>()
251 .stream(&tx)
252 .await?;
253
254 let mut contacts = Vec::new();
255 while let Some(db_contact) = db_contacts.next().await {
256 let db_contact = db_contact?;
257 if db_contact.user_id_a == user_id {
258 if db_contact.accepted {
259 contacts.push(Contact::Accepted {
260 user_id: db_contact.user_id_b,
261 should_notify: db_contact.should_notify && db_contact.a_to_b,
262 busy: db_contact.user_b_busy,
263 });
264 } else if db_contact.a_to_b {
265 contacts.push(Contact::Outgoing {
266 user_id: db_contact.user_id_b,
267 })
268 } else {
269 contacts.push(Contact::Incoming {
270 user_id: db_contact.user_id_b,
271 should_notify: db_contact.should_notify,
272 });
273 }
274 } else if db_contact.accepted {
275 contacts.push(Contact::Accepted {
276 user_id: db_contact.user_id_a,
277 should_notify: db_contact.should_notify && !db_contact.a_to_b,
278 busy: db_contact.user_a_busy,
279 });
280 } else if db_contact.a_to_b {
281 contacts.push(Contact::Incoming {
282 user_id: db_contact.user_id_a,
283 should_notify: db_contact.should_notify,
284 });
285 } else {
286 contacts.push(Contact::Outgoing {
287 user_id: db_contact.user_id_a,
288 });
289 }
290 }
291
292 contacts.sort_unstable_by_key(|contact| contact.user_id());
293
294 Ok(contacts)
295 })
296 .await
297 }
298
299 pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
300 self.transact(|tx| async move {
301 let (id_a, id_b) = if user_id_1 < user_id_2 {
302 (user_id_1, user_id_2)
303 } else {
304 (user_id_2, user_id_1)
305 };
306
307 Ok(contact::Entity::find()
308 .filter(
309 contact::Column::UserIdA
310 .eq(id_a)
311 .and(contact::Column::UserIdB.eq(id_b))
312 .and(contact::Column::Accepted.eq(true)),
313 )
314 .one(&tx)
315 .await?
316 .is_some())
317 })
318 .await
319 }
320
321 pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
322 self.transact(|mut tx| async move {
323 let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
324 (sender_id, receiver_id, true)
325 } else {
326 (receiver_id, sender_id, false)
327 };
328
329 let rows_affected = contact::Entity::insert(contact::ActiveModel {
330 user_id_a: ActiveValue::set(id_a),
331 user_id_b: ActiveValue::set(id_b),
332 a_to_b: ActiveValue::set(a_to_b),
333 accepted: ActiveValue::set(false),
334 should_notify: ActiveValue::set(true),
335 ..Default::default()
336 })
337 .on_conflict(
338 OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB])
339 .values([
340 (contact::Column::Accepted, true.into()),
341 (contact::Column::ShouldNotify, false.into()),
342 ])
343 .action_and_where(
344 contact::Column::Accepted.eq(false).and(
345 contact::Column::AToB
346 .eq(a_to_b)
347 .and(contact::Column::UserIdA.eq(id_b))
348 .or(contact::Column::AToB
349 .ne(a_to_b)
350 .and(contact::Column::UserIdA.eq(id_a))),
351 ),
352 )
353 .to_owned(),
354 )
355 .exec_without_returning(&tx)
356 .await?;
357
358 if rows_affected == 1 {
359 tx.commit().await?;
360 Ok(())
361 } else {
362 Err(anyhow!("contact already requested"))?
363 }
364 })
365 .await
366 }
367
368 pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
369 self.transact(|tx| async move {
370 let (id_a, id_b) = if responder_id < requester_id {
371 (responder_id, requester_id)
372 } else {
373 (requester_id, responder_id)
374 };
375
376 let result = contact::Entity::delete_many()
377 .filter(
378 contact::Column::UserIdA
379 .eq(id_a)
380 .and(contact::Column::UserIdB.eq(id_b)),
381 )
382 .exec(&tx)
383 .await?;
384
385 if result.rows_affected == 1 {
386 tx.commit().await?;
387 Ok(())
388 } else {
389 Err(anyhow!("no such contact"))?
390 }
391 })
392 .await
393 }
394
395 pub async fn dismiss_contact_notification(
396 &self,
397 user_id: UserId,
398 contact_user_id: UserId,
399 ) -> Result<()> {
400 self.transact(|tx| async move {
401 let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
402 (user_id, contact_user_id, true)
403 } else {
404 (contact_user_id, user_id, false)
405 };
406
407 let result = contact::Entity::update_many()
408 .set(contact::ActiveModel {
409 should_notify: ActiveValue::set(false),
410 ..Default::default()
411 })
412 .filter(
413 contact::Column::UserIdA
414 .eq(id_a)
415 .and(contact::Column::UserIdB.eq(id_b))
416 .and(
417 contact::Column::AToB
418 .eq(a_to_b)
419 .and(contact::Column::Accepted.eq(true))
420 .or(contact::Column::AToB
421 .ne(a_to_b)
422 .and(contact::Column::Accepted.eq(false))),
423 ),
424 )
425 .exec(&tx)
426 .await?;
427 if result.rows_affected == 0 {
428 Err(anyhow!("no such contact request"))?
429 } else {
430 tx.commit().await?;
431 Ok(())
432 }
433 })
434 .await
435 }
436
437 pub async fn respond_to_contact_request(
438 &self,
439 responder_id: UserId,
440 requester_id: UserId,
441 accept: bool,
442 ) -> Result<()> {
443 self.transact(|tx| async move {
444 let (id_a, id_b, a_to_b) = if responder_id < requester_id {
445 (responder_id, requester_id, false)
446 } else {
447 (requester_id, responder_id, true)
448 };
449 let rows_affected = if accept {
450 let result = contact::Entity::update_many()
451 .set(contact::ActiveModel {
452 accepted: ActiveValue::set(true),
453 should_notify: ActiveValue::set(true),
454 ..Default::default()
455 })
456 .filter(
457 contact::Column::UserIdA
458 .eq(id_a)
459 .and(contact::Column::UserIdB.eq(id_b))
460 .and(contact::Column::AToB.eq(a_to_b)),
461 )
462 .exec(&tx)
463 .await?;
464 result.rows_affected
465 } else {
466 let result = contact::Entity::delete_many()
467 .filter(
468 contact::Column::UserIdA
469 .eq(id_a)
470 .and(contact::Column::UserIdB.eq(id_b))
471 .and(contact::Column::AToB.eq(a_to_b))
472 .and(contact::Column::Accepted.eq(false)),
473 )
474 .exec(&tx)
475 .await?;
476
477 result.rows_affected
478 };
479
480 if rows_affected == 1 {
481 tx.commit().await?;
482 Ok(())
483 } else {
484 Err(anyhow!("no such contact request"))?
485 }
486 })
487 .await
488 }
489
490 pub fn fuzzy_like_string(string: &str) -> String {
491 let mut result = String::with_capacity(string.len() * 2 + 1);
492 for c in string.chars() {
493 if c.is_alphanumeric() {
494 result.push('%');
495 result.push(c);
496 }
497 }
498 result.push('%');
499 result
500 }
501
502 // projects
503
504 pub async fn share_project(
505 &self,
506 room_id: RoomId,
507 connection_id: ConnectionId,
508 worktrees: &[proto::WorktreeMetadata],
509 ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
510 self.transact(|tx| async move {
511 let participant = room_participant::Entity::find()
512 .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
513 .one(&tx)
514 .await?
515 .ok_or_else(|| anyhow!("could not find participant"))?;
516 if participant.room_id != room_id {
517 return Err(anyhow!("shared project on unexpected room"))?;
518 }
519
520 let project = project::ActiveModel {
521 room_id: ActiveValue::set(participant.room_id),
522 host_user_id: ActiveValue::set(participant.user_id),
523 host_connection_id: ActiveValue::set(connection_id.0 as i32),
524 ..Default::default()
525 }
526 .insert(&tx)
527 .await?;
528
529 worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel {
530 id: ActiveValue::set(worktree.id as i32),
531 project_id: ActiveValue::set(project.id),
532 abs_path: ActiveValue::set(worktree.abs_path.clone()),
533 root_name: ActiveValue::set(worktree.root_name.clone()),
534 visible: ActiveValue::set(worktree.visible),
535 scan_id: ActiveValue::set(0),
536 is_complete: ActiveValue::set(false),
537 }))
538 .exec(&tx)
539 .await?;
540
541 project_collaborator::ActiveModel {
542 project_id: ActiveValue::set(project.id),
543 connection_id: ActiveValue::set(connection_id.0 as i32),
544 user_id: ActiveValue::set(participant.user_id),
545 replica_id: ActiveValue::set(0),
546 is_host: ActiveValue::set(true),
547 ..Default::default()
548 }
549 .insert(&tx)
550 .await?;
551
552 let room = self.get_room(room_id, &tx).await?;
553 self.commit_room_transaction(room_id, tx, (project.id, room))
554 .await
555 })
556 .await
557 }
558
559 async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
560 let db_room = room::Entity::find_by_id(room_id)
561 .one(tx)
562 .await?
563 .ok_or_else(|| anyhow!("could not find room"))?;
564
565 let mut db_participants = db_room
566 .find_related(room_participant::Entity)
567 .stream(tx)
568 .await?;
569 let mut participants = HashMap::default();
570 let mut pending_participants = Vec::new();
571 while let Some(db_participant) = db_participants.next().await {
572 let db_participant = db_participant?;
573 if let Some(answering_connection_id) = db_participant.answering_connection_id {
574 let location = match (
575 db_participant.location_kind,
576 db_participant.location_project_id,
577 ) {
578 (Some(0), Some(project_id)) => {
579 Some(proto::participant_location::Variant::SharedProject(
580 proto::participant_location::SharedProject {
581 id: project_id.to_proto(),
582 },
583 ))
584 }
585 (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
586 Default::default(),
587 )),
588 _ => Some(proto::participant_location::Variant::External(
589 Default::default(),
590 )),
591 };
592 participants.insert(
593 answering_connection_id,
594 proto::Participant {
595 user_id: db_participant.user_id.to_proto(),
596 peer_id: answering_connection_id as u32,
597 projects: Default::default(),
598 location: Some(proto::ParticipantLocation { variant: location }),
599 },
600 );
601 } else {
602 pending_participants.push(proto::PendingParticipant {
603 user_id: db_participant.user_id.to_proto(),
604 calling_user_id: db_participant.calling_user_id.to_proto(),
605 initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()),
606 });
607 }
608 }
609
610 let mut db_projects = db_room
611 .find_related(project::Entity)
612 .find_with_related(worktree::Entity)
613 .stream(tx)
614 .await?;
615
616 while let Some(row) = db_projects.next().await {
617 let (db_project, db_worktree) = row?;
618 if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
619 let project = if let Some(project) = participant
620 .projects
621 .iter_mut()
622 .find(|project| project.id == db_project.id.to_proto())
623 {
624 project
625 } else {
626 participant.projects.push(proto::ParticipantProject {
627 id: db_project.id.to_proto(),
628 worktree_root_names: Default::default(),
629 });
630 participant.projects.last_mut().unwrap()
631 };
632
633 if let Some(db_worktree) = db_worktree {
634 project.worktree_root_names.push(db_worktree.root_name);
635 }
636 }
637 }
638
639 Ok(proto::Room {
640 id: db_room.id.to_proto(),
641 live_kit_room: db_room.live_kit_room,
642 participants: participants.into_values().collect(),
643 pending_participants,
644 })
645 }
646
647 async fn commit_room_transaction<T>(
648 &self,
649 room_id: RoomId,
650 tx: DatabaseTransaction,
651 data: T,
652 ) -> Result<RoomGuard<T>> {
653 let lock = self.rooms.entry(room_id).or_default().clone();
654 let _guard = lock.lock_owned().await;
655 tx.commit().await?;
656 Ok(RoomGuard {
657 data,
658 _guard,
659 _not_send: PhantomData,
660 })
661 }
662
663 pub async fn create_access_token_hash(
664 &self,
665 user_id: UserId,
666 access_token_hash: &str,
667 max_access_token_count: usize,
668 ) -> Result<()> {
669 self.transact(|tx| async {
670 let tx = tx;
671
672 access_token::ActiveModel {
673 user_id: ActiveValue::set(user_id),
674 hash: ActiveValue::set(access_token_hash.into()),
675 ..Default::default()
676 }
677 .insert(&tx)
678 .await?;
679
680 access_token::Entity::delete_many()
681 .filter(
682 access_token::Column::Id.in_subquery(
683 Query::select()
684 .column(access_token::Column::Id)
685 .from(access_token::Entity)
686 .and_where(access_token::Column::UserId.eq(user_id))
687 .order_by(access_token::Column::Id, sea_orm::Order::Desc)
688 .limit(10000)
689 .offset(max_access_token_count as u64)
690 .to_owned(),
691 ),
692 )
693 .exec(&tx)
694 .await?;
695 tx.commit().await?;
696 Ok(())
697 })
698 .await
699 }
700
701 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
702 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
703 enum QueryAs {
704 Hash,
705 }
706
707 self.transact(|tx| async move {
708 Ok(access_token::Entity::find()
709 .select_only()
710 .column(access_token::Column::Hash)
711 .filter(access_token::Column::UserId.eq(user_id))
712 .order_by_desc(access_token::Column::Id)
713 .into_values::<_, QueryAs>()
714 .all(&tx)
715 .await?)
716 })
717 .await
718 }
719
    /// Runs `f` inside a fresh database transaction, starting over with a new transaction
    /// whenever the attempt fails with SQLSTATE `40001` (serialization failure).
    ///
    /// `f` receives ownership of the transaction and is responsible for committing it;
    /// this function never commits on its own. On Postgres, each transaction is switched
    /// to SERIALIZABLE isolation before `f` runs.
    ///
    /// NOTE(review): the retry loop has no attempt limit or backoff, so under sustained
    /// contention it can retry indefinitely — confirm this is intended.
    ///
    /// In test builds, a random delay from the background executor is awaited first, and
    /// the work is then driven to completion on the dedicated test runtime via `block_on`.
    async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
    where
        F: Send + Fn(DatabaseTransaction) -> Fut,
        Fut: Send + Future<Output = Result<T>>,
    {
        let body = async {
            loop {
                let tx = self.pool.begin().await?;

                // In Postgres, serializable transactions are opt-in
                if let sea_orm::DatabaseBackend::Postgres = self.pool.get_database_backend() {
                    tx.execute(sea_orm::Statement::from_string(
                        sea_orm::DatabaseBackend::Postgres,
                        "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(),
                    ))
                    .await?;
                }

                match f(tx).await {
                    Ok(result) => return Ok(result),
                    Err(error) => match error {
                        // Only sqlx errors surfaced through sea-orm's Exec/Query variants
                        // are inspected for the serialization-failure code; everything
                        // else is returned to the caller.
                        Error::Database2(
                            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
                            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
                        ) if error
                            .as_database_error()
                            .and_then(|error| error.code())
                            .as_deref()
                            == Some("40001") =>
                        {
                            // Retry (don't break the loop)
                        }
                        error @ _ => return Err(error),
                    },
                }
            }
        };

        #[cfg(test)]
        {
            if let Some(background) = self.background.as_ref() {
                background.simulate_random_delay().await;
            }

            // Drive the transaction on the test database's own runtime.
            self.runtime.as_ref().unwrap().block_on(body)
        }

        #[cfg(not(test))]
        {
            body.await
        }
    }
772}
773
/// Data produced by a room transaction, handed back while that room's lock is still held.
///
/// The room lock is released when the guard is dropped. The `PhantomData<Rc<()>>` marker
/// makes the guard `!Send`.
pub struct RoomGuard<T> {
    data: T,
    // Owned guard on the per-room mutex; dropping it releases the room lock.
    _guard: OwnedMutexGuard<()>,
    _not_send: PhantomData<Rc<()>>,
}
779
780impl<T> Deref for RoomGuard<T> {
781 type Target = T;
782
783 fn deref(&self) -> &T {
784 &self.data
785 }
786}
787
788impl<T> DerefMut for RoomGuard<T> {
789 fn deref_mut(&mut self) -> &mut T {
790 &mut self.data
791 }
792}
793
/// GitHub-derived input for creating a user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    // Not read by `Database::create_user` in this file; presumably consumed elsewhere — confirm.
    pub invite_count: i32,
}
800
/// Outcome of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    // The user's metrics UUID, stringified.
    pub metrics_id: String,
    // `Database::create_user` in this file always leaves this `None`.
    pub inviting_user_id: Option<UserId>,
    // `Database::create_user` in this file always leaves this `None`.
    pub signup_device_id: Option<String>,
}
808
/// Generates a random 16-character invite code with `nanoid!`'s default alphabet.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}

/// Generates a random 64-character email confirmation code with `nanoid!`'s default alphabet.
fn random_email_confirmation_code() -> String {
    nanoid::nanoid!(64)
}
816
/// Declares a newtype id wrapping an `i32` primary key. The generated type gets
/// conversions for sqlx, serde, sea-query and sea-orm, plus `u64` protobuf conversions.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // NOTE(review): `as` silently truncates values above `i32::MAX`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // NOTE(review): negative ids would wrap to large `u64` values.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }

        impl From<$name> for sea_query::Value {
            fn from(value: $name) -> Self {
                sea_query::Value::Int(Some(value.0))
            }
        }

        impl sea_orm::TryGetable for $name {
            fn try_get(
                res: &sea_orm::QueryResult,
                pre: &str,
                col: &str,
            ) -> Result<Self, sea_orm::TryGetError> {
                Ok(Self(i32::try_get(res, pre, col)?))
            }
        }

        // Accepts any non-null integer value that fits in an `i32`.
        impl sea_query::ValueType for $name {
            fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
                match v {
                    Value::TinyInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Int(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::TinyUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Unsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    _ => Err(sea_query::ValueTypeErr),
                }
            }

            fn type_name() -> String {
                stringify!($name).into()
            }

            fn array_type() -> sea_query::ArrayType {
                sea_query::ArrayType::Int
            }

            fn column_type() -> sea_query::ColumnType {
                sea_query::ColumnType::Integer(None)
            }
        }

        // The conversion goes from `u64` into the id type; the error message says so.
        impl sea_orm::TryFromU64 for $name {
            fn try_from_u64(n: u64) -> Result<Self, DbErr> {
                Ok(Self(n.try_into().map_err(|_| {
                    DbErr::ConvertFromU64(concat!(
                        "error converting u64 to ",
                        stringify!($name)
                    ))
                })?))
            }
        }

        impl sea_query::Nullable for $name {
            fn null() -> Value {
                Value::Int(None)
            }
        }
    };
}
937
938id_type!(AccessTokenId);
939id_type!(ContactId);
940id_type!(UserId);
941id_type!(RoomId);
942id_type!(RoomParticipantId);
943id_type!(ProjectId);
944id_type!(ProjectCollaboratorId);
945id_type!(WorktreeId);
946
947#[cfg(test)]
948pub use test::*;
949
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sea_orm::ConnectionTrait;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A throwaway `Database` for tests. It is backed either by in-memory SQLite or by a
    /// freshly created Postgres database, which is dropped again when the `TestDb` is dropped.
    pub struct TestDb {
        pub db: Option<Arc<Database>>,
        // NOTE(review): only ever set to `None` in this module — confirm it is still needed.
        pub connection: Option<sqlx::AnyConnection>,
    }

    impl TestDb {
        /// Opens an in-memory SQLite database and loads the SQLite test schema into it.
        /// It then attaches `background` and a dedicated current-thread tokio runtime.
        pub fn sqlite(background: Arc<Background>) -> Self {
            let url = "sqlite::memory:".to_string();
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let mut db = runtime.block_on(async {
                let mut options = ConnectOptions::new(url);
                options.max_connections(5);
                let db = Database::new(options).await.unwrap();
                let sql = include_str!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/migrations.sqlite/20221109000000_test_schema.sql"
                ));
                db.pool
                    .execute(sea_orm::Statement::from_string(
                        db.pool.get_database_backend(),
                        sql.into(),
                    ))
                    .await
                    .unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                connection: None,
            }
        }

        /// Creates a randomly named Postgres database on localhost and runs the
        /// `migrations/` directory against it. It then attaches `background` and a
        /// dedicated current-thread tokio runtime.
        ///
        /// A process-wide lock serializes creation across concurrently running tests.
        pub fn postgres(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap();

            let mut db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let mut options = ConnectOptions::new(url);
                options
                    .max_connections(5)
                    .idle_timeout(Duration::from_secs(0));
                let db = Database::new(options).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                connection: None,
            }
        }

        /// Returns the wrapped database.
        ///
        /// # Panics
        /// Panics if `db` has already been taken, as `Drop` does.
        pub fn db(&self) -> &Arc<Database> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        /// On Postgres, terminates the other backends connected to the test database and
        /// then drops it. Errors are logged rather than propagated. For SQLite, nothing
        /// happens beyond dropping the handle.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
                db.runtime.as_ref().unwrap().block_on(async {
                    use util::ResultExt;
                    let query = "
                        SELECT pg_terminate_backend(pg_stat_activity.pid)
                        FROM pg_stat_activity
                        WHERE
                            pg_stat_activity.datname = current_database() AND
                            pid <> pg_backend_pid();
                    ";
                    db.pool
                        .execute(sea_orm::Statement::from_string(
                            db.pool.get_database_backend(),
                            query.into(),
                        ))
                        .await
                        .log_err();
                    sqlx::Postgres::drop_database(db.options.get_url())
                        .await
                        .log_err();
                })
            }
        }
    }
}