db2.rs

   1mod access_token;
   2mod contact;
   3mod project;
   4mod project_collaborator;
   5mod room;
   6mod room_participant;
   7#[cfg(test)]
   8mod tests;
   9mod user;
  10mod worktree;
  11
  12use crate::{Error, Result};
  13use anyhow::anyhow;
  14use collections::HashMap;
  15use dashmap::DashMap;
  16use futures::StreamExt;
  17use rpc::{proto, ConnectionId};
  18use sea_orm::{
  19    entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
  20    TransactionTrait,
  21};
  22use sea_orm::{
  23    ActiveValue, ConnectionTrait, FromQueryResult, IntoActiveModel, JoinType, QueryOrder,
  24    QuerySelect,
  25};
  26use sea_query::{Alias, Expr, OnConflict, Query};
  27use serde::{Deserialize, Serialize};
  28use sqlx::migrate::{Migrate, Migration, MigrationSource};
  29use sqlx::Connection;
  30use std::ops::{Deref, DerefMut};
  31use std::path::Path;
  32use std::time::Duration;
  33use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
  34use tokio::sync::{Mutex, OwnedMutexGuard};
  35
  36pub use contact::Contact;
  37pub use user::Model as User;
  38
  39pub struct Database {
  40    options: ConnectOptions,
  41    pool: DatabaseConnection,
  42    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
  43    #[cfg(test)]
  44    background: Option<std::sync::Arc<gpui::executor::Background>>,
  45    #[cfg(test)]
  46    runtime: Option<tokio::runtime::Runtime>,
  47}
  48
  49impl Database {
  50    pub async fn new(options: ConnectOptions) -> Result<Self> {
  51        Ok(Self {
  52            options: options.clone(),
  53            pool: sea_orm::Database::connect(options).await?,
  54            rooms: DashMap::with_capacity(16384),
  55            #[cfg(test)]
  56            background: None,
  57            #[cfg(test)]
  58            runtime: None,
  59        })
  60    }
  61
  62    pub async fn migrate(
  63        &self,
  64        migrations_path: &Path,
  65        ignore_checksum_mismatch: bool,
  66    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
  67        let migrations = MigrationSource::resolve(migrations_path)
  68            .await
  69            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
  70
  71        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
  72
  73        connection.ensure_migrations_table().await?;
  74        let applied_migrations: HashMap<_, _> = connection
  75            .list_applied_migrations()
  76            .await?
  77            .into_iter()
  78            .map(|m| (m.version, m))
  79            .collect();
  80
  81        let mut new_migrations = Vec::new();
  82        for migration in migrations {
  83            match applied_migrations.get(&migration.version) {
  84                Some(applied_migration) => {
  85                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
  86                    {
  87                        Err(anyhow!(
  88                            "checksum mismatch for applied migration {}",
  89                            migration.description
  90                        ))?;
  91                    }
  92                }
  93                None => {
  94                    let elapsed = connection.apply(&migration).await?;
  95                    new_migrations.push((migration, elapsed));
  96                }
  97            }
  98        }
  99
 100        Ok(new_migrations)
 101    }
 102
 103    // users
 104
 105    pub async fn create_user(
 106        &self,
 107        email_address: &str,
 108        admin: bool,
 109        params: NewUserParams,
 110    ) -> Result<NewUserResult> {
 111        self.transact(|tx| async {
 112            let user = user::Entity::insert(user::ActiveModel {
 113                email_address: ActiveValue::set(Some(email_address.into())),
 114                github_login: ActiveValue::set(params.github_login.clone()),
 115                github_user_id: ActiveValue::set(Some(params.github_user_id)),
 116                admin: ActiveValue::set(admin),
 117                metrics_id: ActiveValue::set(Uuid::new_v4()),
 118                ..Default::default()
 119            })
 120            .on_conflict(
 121                OnConflict::column(user::Column::GithubLogin)
 122                    .update_column(user::Column::GithubLogin)
 123                    .to_owned(),
 124            )
 125            .exec_with_returning(&tx)
 126            .await?;
 127
 128            tx.commit().await?;
 129
 130            Ok(NewUserResult {
 131                user_id: user.id,
 132                metrics_id: user.metrics_id.to_string(),
 133                signup_device_id: None,
 134                inviting_user_id: None,
 135            })
 136        })
 137        .await
 138    }
 139
 140    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
 141        self.transact(|tx| async {
 142            let tx = tx;
 143            Ok(user::Entity::find()
 144                .filter(user::Column::Id.is_in(ids.iter().copied()))
 145                .all(&tx)
 146                .await?)
 147        })
 148        .await
 149    }
 150
 151    pub async fn get_user_by_github_account(
 152        &self,
 153        github_login: &str,
 154        github_user_id: Option<i32>,
 155    ) -> Result<Option<User>> {
 156        self.transact(|tx| async {
 157            let tx = tx;
 158            if let Some(github_user_id) = github_user_id {
 159                if let Some(user_by_github_user_id) = user::Entity::find()
 160                    .filter(user::Column::GithubUserId.eq(github_user_id))
 161                    .one(&tx)
 162                    .await?
 163                {
 164                    let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
 165                    user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
 166                    Ok(Some(user_by_github_user_id.update(&tx).await?))
 167                } else if let Some(user_by_github_login) = user::Entity::find()
 168                    .filter(user::Column::GithubLogin.eq(github_login))
 169                    .one(&tx)
 170                    .await?
 171                {
 172                    let mut user_by_github_login = user_by_github_login.into_active_model();
 173                    user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
 174                    Ok(Some(user_by_github_login.update(&tx).await?))
 175                } else {
 176                    Ok(None)
 177                }
 178            } else {
 179                Ok(user::Entity::find()
 180                    .filter(user::Column::GithubLogin.eq(github_login))
 181                    .one(&tx)
 182                    .await?)
 183            }
 184        })
 185        .await
 186    }
 187
 188    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 189        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 190        enum QueryAs {
 191            MetricsId,
 192        }
 193
 194        self.transact(|tx| async move {
 195            let metrics_id: Uuid = user::Entity::find_by_id(id)
 196                .select_only()
 197                .column(user::Column::MetricsId)
 198                .into_values::<_, QueryAs>()
 199                .one(&tx)
 200                .await?
 201                .ok_or_else(|| anyhow!("could not find user"))?;
 202            Ok(metrics_id.to_string())
 203        })
 204        .await
 205    }
 206
 207    // contacts
 208
 209    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 210        #[derive(Debug, FromQueryResult)]
 211        struct ContactWithUserBusyStatuses {
 212            user_id_a: UserId,
 213            user_id_b: UserId,
 214            a_to_b: bool,
 215            accepted: bool,
 216            should_notify: bool,
 217            user_a_busy: bool,
 218            user_b_busy: bool,
 219        }
 220
 221        self.transact(|tx| async move {
 222            let user_a_participant = Alias::new("user_a_participant");
 223            let user_b_participant = Alias::new("user_b_participant");
 224            let mut db_contacts = contact::Entity::find()
 225                .column_as(
 226                    Expr::tbl(user_a_participant.clone(), room_participant::Column::Id)
 227                        .is_not_null(),
 228                    "user_a_busy",
 229                )
 230                .column_as(
 231                    Expr::tbl(user_b_participant.clone(), room_participant::Column::Id)
 232                        .is_not_null(),
 233                    "user_b_busy",
 234                )
 235                .filter(
 236                    contact::Column::UserIdA
 237                        .eq(user_id)
 238                        .or(contact::Column::UserIdB.eq(user_id)),
 239                )
 240                .join_as(
 241                    JoinType::LeftJoin,
 242                    contact::Relation::UserARoomParticipant.def(),
 243                    user_a_participant,
 244                )
 245                .join_as(
 246                    JoinType::LeftJoin,
 247                    contact::Relation::UserBRoomParticipant.def(),
 248                    user_b_participant,
 249                )
 250                .into_model::<ContactWithUserBusyStatuses>()
 251                .stream(&tx)
 252                .await?;
 253
 254            let mut contacts = Vec::new();
 255            while let Some(db_contact) = db_contacts.next().await {
 256                let db_contact = db_contact?;
 257                if db_contact.user_id_a == user_id {
 258                    if db_contact.accepted {
 259                        contacts.push(Contact::Accepted {
 260                            user_id: db_contact.user_id_b,
 261                            should_notify: db_contact.should_notify && db_contact.a_to_b,
 262                            busy: db_contact.user_b_busy,
 263                        });
 264                    } else if db_contact.a_to_b {
 265                        contacts.push(Contact::Outgoing {
 266                            user_id: db_contact.user_id_b,
 267                        })
 268                    } else {
 269                        contacts.push(Contact::Incoming {
 270                            user_id: db_contact.user_id_b,
 271                            should_notify: db_contact.should_notify,
 272                        });
 273                    }
 274                } else if db_contact.accepted {
 275                    contacts.push(Contact::Accepted {
 276                        user_id: db_contact.user_id_a,
 277                        should_notify: db_contact.should_notify && !db_contact.a_to_b,
 278                        busy: db_contact.user_a_busy,
 279                    });
 280                } else if db_contact.a_to_b {
 281                    contacts.push(Contact::Incoming {
 282                        user_id: db_contact.user_id_a,
 283                        should_notify: db_contact.should_notify,
 284                    });
 285                } else {
 286                    contacts.push(Contact::Outgoing {
 287                        user_id: db_contact.user_id_a,
 288                    });
 289                }
 290            }
 291
 292            contacts.sort_unstable_by_key(|contact| contact.user_id());
 293
 294            Ok(contacts)
 295        })
 296        .await
 297    }
 298
 299    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 300        self.transact(|tx| async move {
 301            let (id_a, id_b) = if user_id_1 < user_id_2 {
 302                (user_id_1, user_id_2)
 303            } else {
 304                (user_id_2, user_id_1)
 305            };
 306
 307            Ok(contact::Entity::find()
 308                .filter(
 309                    contact::Column::UserIdA
 310                        .eq(id_a)
 311                        .and(contact::Column::UserIdB.eq(id_b))
 312                        .and(contact::Column::Accepted.eq(true)),
 313                )
 314                .one(&tx)
 315                .await?
 316                .is_some())
 317        })
 318        .await
 319    }
 320
 321    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 322        self.transact(|mut tx| async move {
 323            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 324                (sender_id, receiver_id, true)
 325            } else {
 326                (receiver_id, sender_id, false)
 327            };
 328
 329            let rows_affected = contact::Entity::insert(contact::ActiveModel {
 330                user_id_a: ActiveValue::set(id_a),
 331                user_id_b: ActiveValue::set(id_b),
 332                a_to_b: ActiveValue::set(a_to_b),
 333                accepted: ActiveValue::set(false),
 334                should_notify: ActiveValue::set(true),
 335                ..Default::default()
 336            })
 337            .on_conflict(
 338                OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB])
 339                    .values([
 340                        (contact::Column::Accepted, true.into()),
 341                        (contact::Column::ShouldNotify, false.into()),
 342                    ])
 343                    .action_and_where(
 344                        contact::Column::Accepted.eq(false).and(
 345                            contact::Column::AToB
 346                                .eq(a_to_b)
 347                                .and(contact::Column::UserIdA.eq(id_b))
 348                                .or(contact::Column::AToB
 349                                    .ne(a_to_b)
 350                                    .and(contact::Column::UserIdA.eq(id_a))),
 351                        ),
 352                    )
 353                    .to_owned(),
 354            )
 355            .exec_without_returning(&tx)
 356            .await?;
 357
 358            if rows_affected == 1 {
 359                tx.commit().await?;
 360                Ok(())
 361            } else {
 362                Err(anyhow!("contact already requested"))?
 363            }
 364        })
 365        .await
 366    }
 367
 368    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 369        self.transact(|tx| async move {
 370            let (id_a, id_b) = if responder_id < requester_id {
 371                (responder_id, requester_id)
 372            } else {
 373                (requester_id, responder_id)
 374            };
 375
 376            let result = contact::Entity::delete_many()
 377                .filter(
 378                    contact::Column::UserIdA
 379                        .eq(id_a)
 380                        .and(contact::Column::UserIdB.eq(id_b)),
 381                )
 382                .exec(&tx)
 383                .await?;
 384
 385            if result.rows_affected == 1 {
 386                tx.commit().await?;
 387                Ok(())
 388            } else {
 389                Err(anyhow!("no such contact"))?
 390            }
 391        })
 392        .await
 393    }
 394
 395    pub async fn dismiss_contact_notification(
 396        &self,
 397        user_id: UserId,
 398        contact_user_id: UserId,
 399    ) -> Result<()> {
 400        self.transact(|tx| async move {
 401            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 402                (user_id, contact_user_id, true)
 403            } else {
 404                (contact_user_id, user_id, false)
 405            };
 406
 407            let result = contact::Entity::update_many()
 408                .set(contact::ActiveModel {
 409                    should_notify: ActiveValue::set(false),
 410                    ..Default::default()
 411                })
 412                .filter(
 413                    contact::Column::UserIdA
 414                        .eq(id_a)
 415                        .and(contact::Column::UserIdB.eq(id_b))
 416                        .and(
 417                            contact::Column::AToB
 418                                .eq(a_to_b)
 419                                .and(contact::Column::Accepted.eq(true))
 420                                .or(contact::Column::AToB
 421                                    .ne(a_to_b)
 422                                    .and(contact::Column::Accepted.eq(false))),
 423                        ),
 424                )
 425                .exec(&tx)
 426                .await?;
 427            if result.rows_affected == 0 {
 428                Err(anyhow!("no such contact request"))?
 429            } else {
 430                tx.commit().await?;
 431                Ok(())
 432            }
 433        })
 434        .await
 435    }
 436
 437    pub async fn respond_to_contact_request(
 438        &self,
 439        responder_id: UserId,
 440        requester_id: UserId,
 441        accept: bool,
 442    ) -> Result<()> {
 443        self.transact(|tx| async move {
 444            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 445                (responder_id, requester_id, false)
 446            } else {
 447                (requester_id, responder_id, true)
 448            };
 449            let rows_affected = if accept {
 450                let result = contact::Entity::update_many()
 451                    .set(contact::ActiveModel {
 452                        accepted: ActiveValue::set(true),
 453                        should_notify: ActiveValue::set(true),
 454                        ..Default::default()
 455                    })
 456                    .filter(
 457                        contact::Column::UserIdA
 458                            .eq(id_a)
 459                            .and(contact::Column::UserIdB.eq(id_b))
 460                            .and(contact::Column::AToB.eq(a_to_b)),
 461                    )
 462                    .exec(&tx)
 463                    .await?;
 464                result.rows_affected
 465            } else {
 466                let result = contact::Entity::delete_many()
 467                    .filter(
 468                        contact::Column::UserIdA
 469                            .eq(id_a)
 470                            .and(contact::Column::UserIdB.eq(id_b))
 471                            .and(contact::Column::AToB.eq(a_to_b))
 472                            .and(contact::Column::Accepted.eq(false)),
 473                    )
 474                    .exec(&tx)
 475                    .await?;
 476
 477                result.rows_affected
 478            };
 479
 480            if rows_affected == 1 {
 481                tx.commit().await?;
 482                Ok(())
 483            } else {
 484                Err(anyhow!("no such contact request"))?
 485            }
 486        })
 487        .await
 488    }
 489
 490    pub fn fuzzy_like_string(string: &str) -> String {
 491        let mut result = String::with_capacity(string.len() * 2 + 1);
 492        for c in string.chars() {
 493            if c.is_alphanumeric() {
 494                result.push('%');
 495                result.push(c);
 496            }
 497        }
 498        result.push('%');
 499        result
 500    }
 501
 502    // projects
 503
 504    pub async fn share_project(
 505        &self,
 506        room_id: RoomId,
 507        connection_id: ConnectionId,
 508        worktrees: &[proto::WorktreeMetadata],
 509    ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
 510        self.transact(|tx| async move {
 511            let participant = room_participant::Entity::find()
 512                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
 513                .one(&tx)
 514                .await?
 515                .ok_or_else(|| anyhow!("could not find participant"))?;
 516            if participant.room_id != room_id {
 517                return Err(anyhow!("shared project on unexpected room"))?;
 518            }
 519
 520            let project = project::ActiveModel {
 521                room_id: ActiveValue::set(participant.room_id),
 522                host_user_id: ActiveValue::set(participant.user_id),
 523                host_connection_id: ActiveValue::set(connection_id.0 as i32),
 524                ..Default::default()
 525            }
 526            .insert(&tx)
 527            .await?;
 528
 529            worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel {
 530                id: ActiveValue::set(worktree.id as i32),
 531                project_id: ActiveValue::set(project.id),
 532                abs_path: ActiveValue::set(worktree.abs_path.clone()),
 533                root_name: ActiveValue::set(worktree.root_name.clone()),
 534                visible: ActiveValue::set(worktree.visible),
 535                scan_id: ActiveValue::set(0),
 536                is_complete: ActiveValue::set(false),
 537            }))
 538            .exec(&tx)
 539            .await?;
 540
 541            project_collaborator::ActiveModel {
 542                project_id: ActiveValue::set(project.id),
 543                connection_id: ActiveValue::set(connection_id.0 as i32),
 544                user_id: ActiveValue::set(participant.user_id),
 545                replica_id: ActiveValue::set(0),
 546                is_host: ActiveValue::set(true),
 547                ..Default::default()
 548            }
 549            .insert(&tx)
 550            .await?;
 551
 552            let room = self.get_room(room_id, &tx).await?;
 553            self.commit_room_transaction(room_id, tx, (project.id, room))
 554                .await
 555        })
 556        .await
 557    }
 558
 559    async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
 560        let db_room = room::Entity::find_by_id(room_id)
 561            .one(tx)
 562            .await?
 563            .ok_or_else(|| anyhow!("could not find room"))?;
 564
 565        let mut db_participants = db_room
 566            .find_related(room_participant::Entity)
 567            .stream(tx)
 568            .await?;
 569        let mut participants = HashMap::default();
 570        let mut pending_participants = Vec::new();
 571        while let Some(db_participant) = db_participants.next().await {
 572            let db_participant = db_participant?;
 573            if let Some(answering_connection_id) = db_participant.answering_connection_id {
 574                let location = match (
 575                    db_participant.location_kind,
 576                    db_participant.location_project_id,
 577                ) {
 578                    (Some(0), Some(project_id)) => {
 579                        Some(proto::participant_location::Variant::SharedProject(
 580                            proto::participant_location::SharedProject {
 581                                id: project_id.to_proto(),
 582                            },
 583                        ))
 584                    }
 585                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
 586                        Default::default(),
 587                    )),
 588                    _ => Some(proto::participant_location::Variant::External(
 589                        Default::default(),
 590                    )),
 591                };
 592                participants.insert(
 593                    answering_connection_id,
 594                    proto::Participant {
 595                        user_id: db_participant.user_id.to_proto(),
 596                        peer_id: answering_connection_id as u32,
 597                        projects: Default::default(),
 598                        location: Some(proto::ParticipantLocation { variant: location }),
 599                    },
 600                );
 601            } else {
 602                pending_participants.push(proto::PendingParticipant {
 603                    user_id: db_participant.user_id.to_proto(),
 604                    calling_user_id: db_participant.calling_user_id.to_proto(),
 605                    initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()),
 606                });
 607            }
 608        }
 609
 610        let mut db_projects = db_room
 611            .find_related(project::Entity)
 612            .find_with_related(worktree::Entity)
 613            .stream(tx)
 614            .await?;
 615
 616        while let Some(row) = db_projects.next().await {
 617            let (db_project, db_worktree) = row?;
 618            if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
 619                let project = if let Some(project) = participant
 620                    .projects
 621                    .iter_mut()
 622                    .find(|project| project.id == db_project.id.to_proto())
 623                {
 624                    project
 625                } else {
 626                    participant.projects.push(proto::ParticipantProject {
 627                        id: db_project.id.to_proto(),
 628                        worktree_root_names: Default::default(),
 629                    });
 630                    participant.projects.last_mut().unwrap()
 631                };
 632
 633                if let Some(db_worktree) = db_worktree {
 634                    project.worktree_root_names.push(db_worktree.root_name);
 635                }
 636            }
 637        }
 638
 639        Ok(proto::Room {
 640            id: db_room.id.to_proto(),
 641            live_kit_room: db_room.live_kit_room,
 642            participants: participants.into_values().collect(),
 643            pending_participants,
 644        })
 645    }
 646
 647    async fn commit_room_transaction<T>(
 648        &self,
 649        room_id: RoomId,
 650        tx: DatabaseTransaction,
 651        data: T,
 652    ) -> Result<RoomGuard<T>> {
 653        let lock = self.rooms.entry(room_id).or_default().clone();
 654        let _guard = lock.lock_owned().await;
 655        tx.commit().await?;
 656        Ok(RoomGuard {
 657            data,
 658            _guard,
 659            _not_send: PhantomData,
 660        })
 661    }
 662
 663    pub async fn create_access_token_hash(
 664        &self,
 665        user_id: UserId,
 666        access_token_hash: &str,
 667        max_access_token_count: usize,
 668    ) -> Result<()> {
 669        self.transact(|tx| async {
 670            let tx = tx;
 671
 672            access_token::ActiveModel {
 673                user_id: ActiveValue::set(user_id),
 674                hash: ActiveValue::set(access_token_hash.into()),
 675                ..Default::default()
 676            }
 677            .insert(&tx)
 678            .await?;
 679
 680            access_token::Entity::delete_many()
 681                .filter(
 682                    access_token::Column::Id.in_subquery(
 683                        Query::select()
 684                            .column(access_token::Column::Id)
 685                            .from(access_token::Entity)
 686                            .and_where(access_token::Column::UserId.eq(user_id))
 687                            .order_by(access_token::Column::Id, sea_orm::Order::Desc)
 688                            .limit(10000)
 689                            .offset(max_access_token_count as u64)
 690                            .to_owned(),
 691                    ),
 692                )
 693                .exec(&tx)
 694                .await?;
 695            tx.commit().await?;
 696            Ok(())
 697        })
 698        .await
 699    }
 700
 701    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 702        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 703        enum QueryAs {
 704            Hash,
 705        }
 706
 707        self.transact(|tx| async move {
 708            Ok(access_token::Entity::find()
 709                .select_only()
 710                .column(access_token::Column::Hash)
 711                .filter(access_token::Column::UserId.eq(user_id))
 712                .order_by_desc(access_token::Column::Id)
 713                .into_values::<_, QueryAs>()
 714                .all(&tx)
 715                .await?)
 716        })
 717        .await
 718    }
 719
 720    async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
 721    where
 722        F: Send + Fn(DatabaseTransaction) -> Fut,
 723        Fut: Send + Future<Output = Result<T>>,
 724    {
 725        let body = async {
 726            loop {
 727                let tx = self.pool.begin().await?;
 728
 729                // In Postgres, serializable transactions are opt-in
 730                if let sea_orm::DatabaseBackend::Postgres = self.pool.get_database_backend() {
 731                    tx.execute(sea_orm::Statement::from_string(
 732                        sea_orm::DatabaseBackend::Postgres,
 733                        "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(),
 734                    ))
 735                    .await?;
 736                }
 737
 738                match f(tx).await {
 739                    Ok(result) => return Ok(result),
 740                    Err(error) => match error {
 741                        Error::Database2(
 742                            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
 743                            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
 744                        ) if error
 745                            .as_database_error()
 746                            .and_then(|error| error.code())
 747                            .as_deref()
 748                            == Some("40001") =>
 749                        {
 750                            // Retry (don't break the loop)
 751                        }
 752                        error @ _ => return Err(error),
 753                    },
 754                }
 755            }
 756        };
 757
 758        #[cfg(test)]
 759        {
 760            if let Some(background) = self.background.as_ref() {
 761                background.simulate_random_delay().await;
 762            }
 763
 764            self.runtime.as_ref().unwrap().block_on(body)
 765        }
 766
 767        #[cfg(not(test))]
 768        {
 769            body.await
 770        }
 771    }
 772}
 773
 774pub struct RoomGuard<T> {
 775    data: T,
 776    _guard: OwnedMutexGuard<()>,
 777    _not_send: PhantomData<Rc<()>>,
 778}
 779
 780impl<T> Deref for RoomGuard<T> {
 781    type Target = T;
 782
 783    fn deref(&self) -> &T {
 784        &self.data
 785    }
 786}
 787
 788impl<T> DerefMut for RoomGuard<T> {
 789    fn deref_mut(&mut self) -> &mut T {
 790        &mut self.data
 791    }
 792}
 793
 794#[derive(Debug, Serialize, Deserialize)]
 795pub struct NewUserParams {
 796    pub github_login: String,
 797    pub github_user_id: i32,
 798    pub invite_count: i32,
 799}
 800
 801#[derive(Debug)]
 802pub struct NewUserResult {
 803    pub user_id: UserId,
 804    pub metrics_id: String,
 805    pub inviting_user_id: Option<UserId>,
 806    pub signup_device_id: Option<String>,
 807}
 808
 809fn random_invite_code() -> String {
 810    nanoid::nanoid!(16)
 811}
 812
 813fn random_email_confirmation_code() -> String {
 814    nanoid::nanoid!(64)
 815}
 816
 817macro_rules! id_type {
 818    ($name:ident) => {
 819        #[derive(
 820            Clone,
 821            Copy,
 822            Debug,
 823            Default,
 824            PartialEq,
 825            Eq,
 826            PartialOrd,
 827            Ord,
 828            Hash,
 829            sqlx::Type,
 830            Serialize,
 831            Deserialize,
 832        )]
 833        #[sqlx(transparent)]
 834        #[serde(transparent)]
 835        pub struct $name(pub i32);
 836
 837        impl $name {
 838            #[allow(unused)]
 839            pub const MAX: Self = Self(i32::MAX);
 840
 841            #[allow(unused)]
 842            pub fn from_proto(value: u64) -> Self {
 843                Self(value as i32)
 844            }
 845
 846            #[allow(unused)]
 847            pub fn to_proto(self) -> u64 {
 848                self.0 as u64
 849            }
 850        }
 851
 852        impl std::fmt::Display for $name {
 853            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 854                self.0.fmt(f)
 855            }
 856        }
 857
 858        impl From<$name> for sea_query::Value {
 859            fn from(value: $name) -> Self {
 860                sea_query::Value::Int(Some(value.0))
 861            }
 862        }
 863
 864        impl sea_orm::TryGetable for $name {
 865            fn try_get(
 866                res: &sea_orm::QueryResult,
 867                pre: &str,
 868                col: &str,
 869            ) -> Result<Self, sea_orm::TryGetError> {
 870                Ok(Self(i32::try_get(res, pre, col)?))
 871            }
 872        }
 873
 874        impl sea_query::ValueType for $name {
 875            fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
 876                match v {
 877                    Value::TinyInt(Some(int)) => {
 878                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 879                    }
 880                    Value::SmallInt(Some(int)) => {
 881                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 882                    }
 883                    Value::Int(Some(int)) => {
 884                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 885                    }
 886                    Value::BigInt(Some(int)) => {
 887                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 888                    }
 889                    Value::TinyUnsigned(Some(int)) => {
 890                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 891                    }
 892                    Value::SmallUnsigned(Some(int)) => {
 893                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 894                    }
 895                    Value::Unsigned(Some(int)) => {
 896                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 897                    }
 898                    Value::BigUnsigned(Some(int)) => {
 899                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 900                    }
 901                    _ => Err(sea_query::ValueTypeErr),
 902                }
 903            }
 904
 905            fn type_name() -> String {
 906                stringify!($name).into()
 907            }
 908
 909            fn array_type() -> sea_query::ArrayType {
 910                sea_query::ArrayType::Int
 911            }
 912
 913            fn column_type() -> sea_query::ColumnType {
 914                sea_query::ColumnType::Integer(None)
 915            }
 916        }
 917
 918        impl sea_orm::TryFromU64 for $name {
 919            fn try_from_u64(n: u64) -> Result<Self, DbErr> {
 920                Ok(Self(n.try_into().map_err(|_| {
 921                    DbErr::ConvertFromU64(concat!(
 922                        "error converting ",
 923                        stringify!($name),
 924                        " to u64"
 925                    ))
 926                })?))
 927            }
 928        }
 929
 930        impl sea_query::Nullable for $name {
 931            fn null() -> Value {
 932                Value::Int(None)
 933            }
 934        }
 935    };
 936}
 937
 938id_type!(AccessTokenId);
 939id_type!(ContactId);
 940id_type!(UserId);
 941id_type!(RoomId);
 942id_type!(RoomParticipantId);
 943id_type!(ProjectId);
 944id_type!(ProjectCollaboratorId);
 945id_type!(WorktreeId);
 946
 947#[cfg(test)]
 948pub use test::*;
 949
 950#[cfg(test)]
 951mod test {
 952    use super::*;
 953    use gpui::executor::Background;
 954    use lazy_static::lazy_static;
 955    use parking_lot::Mutex;
 956    use rand::prelude::*;
 957    use sea_orm::ConnectionTrait;
 958    use sqlx::migrate::MigrateDatabase;
 959    use std::sync::Arc;
 960
 961    pub struct TestDb {
 962        pub db: Option<Arc<Database>>,
 963        pub connection: Option<sqlx::AnyConnection>,
 964    }
 965
 966    impl TestDb {
 967        pub fn sqlite(background: Arc<Background>) -> Self {
 968            let url = format!("sqlite::memory:");
 969            let runtime = tokio::runtime::Builder::new_current_thread()
 970                .enable_io()
 971                .enable_time()
 972                .build()
 973                .unwrap();
 974
 975            let mut db = runtime.block_on(async {
 976                let mut options = ConnectOptions::new(url);
 977                options.max_connections(5);
 978                let db = Database::new(options).await.unwrap();
 979                let sql = include_str!(concat!(
 980                    env!("CARGO_MANIFEST_DIR"),
 981                    "/migrations.sqlite/20221109000000_test_schema.sql"
 982                ));
 983                db.pool
 984                    .execute(sea_orm::Statement::from_string(
 985                        db.pool.get_database_backend(),
 986                        sql.into(),
 987                    ))
 988                    .await
 989                    .unwrap();
 990                db
 991            });
 992
 993            db.background = Some(background);
 994            db.runtime = Some(runtime);
 995
 996            Self {
 997                db: Some(Arc::new(db)),
 998                connection: None,
 999            }
1000        }
1001
1002        pub fn postgres(background: Arc<Background>) -> Self {
1003            lazy_static! {
1004                static ref LOCK: Mutex<()> = Mutex::new(());
1005            }
1006
1007            let _guard = LOCK.lock();
1008            let mut rng = StdRng::from_entropy();
1009            let url = format!(
1010                "postgres://postgres@localhost/zed-test-{}",
1011                rng.gen::<u128>()
1012            );
1013            let runtime = tokio::runtime::Builder::new_current_thread()
1014                .enable_io()
1015                .enable_time()
1016                .build()
1017                .unwrap();
1018
1019            let mut db = runtime.block_on(async {
1020                sqlx::Postgres::create_database(&url)
1021                    .await
1022                    .expect("failed to create test db");
1023                let mut options = ConnectOptions::new(url);
1024                options
1025                    .max_connections(5)
1026                    .idle_timeout(Duration::from_secs(0));
1027                let db = Database::new(options).await.unwrap();
1028                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
1029                db.migrate(Path::new(migrations_path), false).await.unwrap();
1030                db
1031            });
1032
1033            db.background = Some(background);
1034            db.runtime = Some(runtime);
1035
1036            Self {
1037                db: Some(Arc::new(db)),
1038                connection: None,
1039            }
1040        }
1041
1042        pub fn db(&self) -> &Arc<Database> {
1043            self.db.as_ref().unwrap()
1044        }
1045    }
1046
1047    impl Drop for TestDb {
1048        fn drop(&mut self) {
1049            let db = self.db.take().unwrap();
1050            if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
1051                db.runtime.as_ref().unwrap().block_on(async {
1052                    use util::ResultExt;
1053                    let query = "
1054                        SELECT pg_terminate_backend(pg_stat_activity.pid)
1055                        FROM pg_stat_activity
1056                        WHERE
1057                            pg_stat_activity.datname = current_database() AND
1058                            pid <> pg_backend_pid();
1059                    ";
1060                    db.pool
1061                        .execute(sea_orm::Statement::from_string(
1062                            db.pool.get_database_backend(),
1063                            query.into(),
1064                        ))
1065                        .await
1066                        .log_err();
1067                    sqlx::Postgres::drop_database(db.options.get_url())
1068                        .await
1069                        .log_err();
1070                })
1071            }
1072        }
1073    }
1074}