db2.rs

   1mod access_token;
   2mod contact;
   3mod project;
   4mod project_collaborator;
   5mod room;
   6mod room_participant;
   7#[cfg(test)]
   8mod tests;
   9mod user;
  10mod worktree;
  11
  12use crate::{Error, Result};
  13use anyhow::anyhow;
  14use collections::HashMap;
  15use dashmap::DashMap;
  16use futures::StreamExt;
  17use rpc::{proto, ConnectionId};
  18use sea_orm::{
  19    entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
  20    TransactionTrait,
  21};
  22use sea_orm::{
  23    ActiveValue, ConnectionTrait, FromQueryResult, IntoActiveModel, JoinType, QueryOrder,
  24    QuerySelect,
  25};
  26use sea_query::{Alias, Expr, OnConflict, Query};
  27use serde::{Deserialize, Serialize};
  28use sqlx::migrate::{Migrate, Migration, MigrationSource};
  29use sqlx::Connection;
  30use std::ops::{Deref, DerefMut};
  31use std::path::Path;
  32use std::time::Duration;
  33use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
  34use tokio::sync::{Mutex, OwnedMutexGuard};
  35
  36pub use contact::Contact;
  37pub use user::Model as User;
  38
  39pub struct Database {
  40    options: ConnectOptions,
  41    pool: DatabaseConnection,
  42    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
  43    #[cfg(test)]
  44    background: Option<std::sync::Arc<gpui::executor::Background>>,
  45    #[cfg(test)]
  46    runtime: Option<tokio::runtime::Runtime>,
  47}
  48
  49impl Database {
  50    pub async fn new(options: ConnectOptions) -> Result<Self> {
  51        Ok(Self {
  52            options: options.clone(),
  53            pool: sea_orm::Database::connect(options).await?,
  54            rooms: DashMap::with_capacity(16384),
  55            #[cfg(test)]
  56            background: None,
  57            #[cfg(test)]
  58            runtime: None,
  59        })
  60    }
  61
  62    pub async fn migrate(
  63        &self,
  64        migrations_path: &Path,
  65        ignore_checksum_mismatch: bool,
  66    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
  67        let migrations = MigrationSource::resolve(migrations_path)
  68            .await
  69            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
  70
  71        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
  72
  73        connection.ensure_migrations_table().await?;
  74        let applied_migrations: HashMap<_, _> = connection
  75            .list_applied_migrations()
  76            .await?
  77            .into_iter()
  78            .map(|m| (m.version, m))
  79            .collect();
  80
  81        let mut new_migrations = Vec::new();
  82        for migration in migrations {
  83            match applied_migrations.get(&migration.version) {
  84                Some(applied_migration) => {
  85                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
  86                    {
  87                        Err(anyhow!(
  88                            "checksum mismatch for applied migration {}",
  89                            migration.description
  90                        ))?;
  91                    }
  92                }
  93                None => {
  94                    let elapsed = connection.apply(&migration).await?;
  95                    new_migrations.push((migration, elapsed));
  96                }
  97            }
  98        }
  99
 100        Ok(new_migrations)
 101    }
 102
 103    // users
 104
 105    pub async fn create_user(
 106        &self,
 107        email_address: &str,
 108        admin: bool,
 109        params: NewUserParams,
 110    ) -> Result<NewUserResult> {
 111        self.transact(|tx| async {
 112            let user = user::Entity::insert(user::ActiveModel {
 113                email_address: ActiveValue::set(Some(email_address.into())),
 114                github_login: ActiveValue::set(params.github_login.clone()),
 115                github_user_id: ActiveValue::set(Some(params.github_user_id)),
 116                admin: ActiveValue::set(admin),
 117                metrics_id: ActiveValue::set(Uuid::new_v4()),
 118                ..Default::default()
 119            })
 120            .on_conflict(
 121                OnConflict::column(user::Column::GithubLogin)
 122                    .update_column(user::Column::GithubLogin)
 123                    .to_owned(),
 124            )
 125            .exec_with_returning(&tx)
 126            .await?;
 127
 128            tx.commit().await?;
 129
 130            Ok(NewUserResult {
 131                user_id: user.id,
 132                metrics_id: user.metrics_id.to_string(),
 133                signup_device_id: None,
 134                inviting_user_id: None,
 135            })
 136        })
 137        .await
 138    }
 139
 140    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
 141        self.transact(|tx| async {
 142            let tx = tx;
 143            Ok(user::Entity::find()
 144                .filter(user::Column::Id.is_in(ids.iter().copied()))
 145                .all(&tx)
 146                .await?)
 147        })
 148        .await
 149    }
 150
 151    pub async fn get_user_by_github_account(
 152        &self,
 153        github_login: &str,
 154        github_user_id: Option<i32>,
 155    ) -> Result<Option<User>> {
 156        self.transact(|tx| async {
 157            let tx = tx;
 158            if let Some(github_user_id) = github_user_id {
 159                if let Some(user_by_github_user_id) = user::Entity::find()
 160                    .filter(user::Column::GithubUserId.eq(github_user_id))
 161                    .one(&tx)
 162                    .await?
 163                {
 164                    let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
 165                    user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
 166                    Ok(Some(user_by_github_user_id.update(&tx).await?))
 167                } else if let Some(user_by_github_login) = user::Entity::find()
 168                    .filter(user::Column::GithubLogin.eq(github_login))
 169                    .one(&tx)
 170                    .await?
 171                {
 172                    let mut user_by_github_login = user_by_github_login.into_active_model();
 173                    user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
 174                    Ok(Some(user_by_github_login.update(&tx).await?))
 175                } else {
 176                    Ok(None)
 177                }
 178            } else {
 179                Ok(user::Entity::find()
 180                    .filter(user::Column::GithubLogin.eq(github_login))
 181                    .one(&tx)
 182                    .await?)
 183            }
 184        })
 185        .await
 186    }
 187
 188    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 189        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 190        enum QueryAs {
 191            MetricsId,
 192        }
 193
 194        self.transact(|tx| async move {
 195            let metrics_id: Uuid = user::Entity::find_by_id(id)
 196                .select_only()
 197                .column(user::Column::MetricsId)
 198                .into_values::<_, QueryAs>()
 199                .one(&tx)
 200                .await?
 201                .ok_or_else(|| anyhow!("could not find user"))?;
 202            Ok(metrics_id.to_string())
 203        })
 204        .await
 205    }
 206
 207    // contacts
 208
 209    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 210        #[derive(Debug, FromQueryResult)]
 211        struct ContactWithUserBusyStatuses {
 212            user_id_a: UserId,
 213            user_id_b: UserId,
 214            a_to_b: bool,
 215            accepted: bool,
 216            should_notify: bool,
 217            user_a_busy: bool,
 218            user_b_busy: bool,
 219        }
 220
 221        self.transact(|tx| async move {
 222            let user_a_participant = Alias::new("user_a_participant");
 223            let user_b_participant = Alias::new("user_b_participant");
 224            let mut db_contacts = contact::Entity::find()
 225                .column_as(
 226                    Expr::tbl(user_a_participant.clone(), room_participant::Column::Id)
 227                        .is_not_null(),
 228                    "user_a_busy",
 229                )
 230                .column_as(
 231                    Expr::tbl(user_b_participant.clone(), room_participant::Column::Id)
 232                        .is_not_null(),
 233                    "user_b_busy",
 234                )
 235                .filter(
 236                    contact::Column::UserIdA
 237                        .eq(user_id)
 238                        .or(contact::Column::UserIdB.eq(user_id)),
 239                )
 240                .join_as(
 241                    JoinType::LeftJoin,
 242                    contact::Relation::UserARoomParticipant.def(),
 243                    user_a_participant,
 244                )
 245                .join_as(
 246                    JoinType::LeftJoin,
 247                    contact::Relation::UserBRoomParticipant.def(),
 248                    user_b_participant,
 249                )
 250                .into_model::<ContactWithUserBusyStatuses>()
 251                .stream(&tx)
 252                .await?;
 253
 254            let mut contacts = Vec::new();
 255            while let Some(db_contact) = db_contacts.next().await {
 256                let db_contact = db_contact?;
 257                if db_contact.user_id_a == user_id {
 258                    if db_contact.accepted {
 259                        contacts.push(Contact::Accepted {
 260                            user_id: db_contact.user_id_b,
 261                            should_notify: db_contact.should_notify && db_contact.a_to_b,
 262                            busy: db_contact.user_b_busy,
 263                        });
 264                    } else if db_contact.a_to_b {
 265                        contacts.push(Contact::Outgoing {
 266                            user_id: db_contact.user_id_b,
 267                        })
 268                    } else {
 269                        contacts.push(Contact::Incoming {
 270                            user_id: db_contact.user_id_b,
 271                            should_notify: db_contact.should_notify,
 272                        });
 273                    }
 274                } else if db_contact.accepted {
 275                    contacts.push(Contact::Accepted {
 276                        user_id: db_contact.user_id_a,
 277                        should_notify: db_contact.should_notify && !db_contact.a_to_b,
 278                        busy: db_contact.user_a_busy,
 279                    });
 280                } else if db_contact.a_to_b {
 281                    contacts.push(Contact::Incoming {
 282                        user_id: db_contact.user_id_a,
 283                        should_notify: db_contact.should_notify,
 284                    });
 285                } else {
 286                    contacts.push(Contact::Outgoing {
 287                        user_id: db_contact.user_id_a,
 288                    });
 289                }
 290            }
 291
 292            contacts.sort_unstable_by_key(|contact| contact.user_id());
 293
 294            Ok(contacts)
 295        })
 296        .await
 297    }
 298
 299    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 300        self.transact(|tx| async move {
 301            let (id_a, id_b) = if user_id_1 < user_id_2 {
 302                (user_id_1, user_id_2)
 303            } else {
 304                (user_id_2, user_id_1)
 305            };
 306
 307            Ok(contact::Entity::find()
 308                .filter(
 309                    contact::Column::UserIdA
 310                        .eq(id_a)
 311                        .and(contact::Column::UserIdB.eq(id_b))
 312                        .and(contact::Column::Accepted.eq(true)),
 313                )
 314                .one(&tx)
 315                .await?
 316                .is_some())
 317        })
 318        .await
 319    }
 320
 321    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 322        self.transact(|mut tx| async move {
 323            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 324                (sender_id, receiver_id, true)
 325            } else {
 326                (receiver_id, sender_id, false)
 327            };
 328
 329            let rows_affected = contact::Entity::insert(contact::ActiveModel {
 330                user_id_a: ActiveValue::set(id_a),
 331                user_id_b: ActiveValue::set(id_b),
 332                a_to_b: ActiveValue::set(a_to_b),
 333                accepted: ActiveValue::set(false),
 334                should_notify: ActiveValue::set(true),
 335                ..Default::default()
 336            })
 337            .on_conflict(
 338                OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB])
 339                    .values([
 340                        (contact::Column::Accepted, true.into()),
 341                        (contact::Column::ShouldNotify, false.into()),
 342                    ])
 343                    .action_and_where(
 344                        contact::Column::Accepted.eq(false).and(
 345                            contact::Column::AToB
 346                                .eq(a_to_b)
 347                                .and(contact::Column::UserIdA.eq(id_b))
 348                                .or(contact::Column::AToB
 349                                    .ne(a_to_b)
 350                                    .and(contact::Column::UserIdA.eq(id_a))),
 351                        ),
 352                    )
 353                    .to_owned(),
 354            )
 355            .exec_without_returning(&tx)
 356            .await?;
 357
 358            if rows_affected == 1 {
 359                tx.commit().await?;
 360                Ok(())
 361            } else {
 362                Err(anyhow!("contact already requested"))?
 363            }
 364        })
 365        .await
 366    }
 367
 368    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 369        self.transact(|mut tx| async move {
 370            // let (id_a, id_b) = if responder_id < requester_id {
 371            //     (responder_id, requester_id)
 372            // } else {
 373            //     (requester_id, responder_id)
 374            // };
 375            // let query = "
 376            //     DELETE FROM contacts
 377            //     WHERE user_id_a = $1 AND user_id_b = $2;
 378            // ";
 379            // let result = sqlx::query(query)
 380            //     .bind(id_a.0)
 381            //     .bind(id_b.0)
 382            //     .execute(&mut tx)
 383            //     .await?;
 384
 385            // if result.rows_affected() == 1 {
 386            //     tx.commit().await?;
 387            //     Ok(())
 388            // } else {
 389            //     Err(anyhow!("no such contact"))?
 390            // }
 391            todo!()
 392        })
 393        .await
 394    }
 395
 396    pub async fn dismiss_contact_notification(
 397        &self,
 398        user_id: UserId,
 399        contact_user_id: UserId,
 400    ) -> Result<()> {
 401        self.transact(|tx| async move {
 402            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 403                (user_id, contact_user_id, true)
 404            } else {
 405                (contact_user_id, user_id, false)
 406            };
 407
 408            let result = contact::Entity::update_many()
 409                .set(contact::ActiveModel {
 410                    should_notify: ActiveValue::set(false),
 411                    ..Default::default()
 412                })
 413                .filter(
 414                    contact::Column::UserIdA
 415                        .eq(id_a)
 416                        .and(contact::Column::UserIdB.eq(id_b))
 417                        .and(
 418                            contact::Column::AToB
 419                                .eq(a_to_b)
 420                                .and(contact::Column::Accepted.eq(true))
 421                                .or(contact::Column::AToB
 422                                    .ne(a_to_b)
 423                                    .and(contact::Column::Accepted.eq(false))),
 424                        ),
 425                )
 426                .exec(&tx)
 427                .await?;
 428            if result.rows_affected == 0 {
 429                Err(anyhow!("no such contact request"))?
 430            } else {
 431                tx.commit().await?;
 432                Ok(())
 433            }
 434        })
 435        .await
 436    }
 437
 438    pub async fn respond_to_contact_request(
 439        &self,
 440        responder_id: UserId,
 441        requester_id: UserId,
 442        accept: bool,
 443    ) -> Result<()> {
 444        self.transact(|tx| async move {
 445            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 446                (responder_id, requester_id, false)
 447            } else {
 448                (requester_id, responder_id, true)
 449            };
 450            let rows_affected = if accept {
 451                let result = contact::Entity::update_many()
 452                    .set(contact::ActiveModel {
 453                        accepted: ActiveValue::set(true),
 454                        should_notify: ActiveValue::set(true),
 455                        ..Default::default()
 456                    })
 457                    .filter(
 458                        contact::Column::UserIdA
 459                            .eq(id_a)
 460                            .and(contact::Column::UserIdB.eq(id_b))
 461                            .and(contact::Column::AToB.eq(a_to_b)),
 462                    )
 463                    .exec(&tx)
 464                    .await?;
 465                result.rows_affected
 466            } else {
 467                let result = contact::Entity::delete_many()
 468                    .filter(
 469                        contact::Column::UserIdA
 470                            .eq(id_a)
 471                            .and(contact::Column::UserIdB.eq(id_b))
 472                            .and(contact::Column::AToB.eq(a_to_b))
 473                            .and(contact::Column::Accepted.eq(false)),
 474                    )
 475                    .exec(&tx)
 476                    .await?;
 477
 478                result.rows_affected
 479            };
 480
 481            if rows_affected == 1 {
 482                tx.commit().await?;
 483                Ok(())
 484            } else {
 485                Err(anyhow!("no such contact request"))?
 486            }
 487        })
 488        .await
 489    }
 490
 491    // projects
 492
 493    pub async fn share_project(
 494        &self,
 495        room_id: RoomId,
 496        connection_id: ConnectionId,
 497        worktrees: &[proto::WorktreeMetadata],
 498    ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
 499        self.transact(|tx| async move {
 500            let participant = room_participant::Entity::find()
 501                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
 502                .one(&tx)
 503                .await?
 504                .ok_or_else(|| anyhow!("could not find participant"))?;
 505            if participant.room_id != room_id {
 506                return Err(anyhow!("shared project on unexpected room"))?;
 507            }
 508
 509            let project = project::ActiveModel {
 510                room_id: ActiveValue::set(participant.room_id),
 511                host_user_id: ActiveValue::set(participant.user_id),
 512                host_connection_id: ActiveValue::set(connection_id.0 as i32),
 513                ..Default::default()
 514            }
 515            .insert(&tx)
 516            .await?;
 517
 518            worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel {
 519                id: ActiveValue::set(worktree.id as i32),
 520                project_id: ActiveValue::set(project.id),
 521                abs_path: ActiveValue::set(worktree.abs_path.clone()),
 522                root_name: ActiveValue::set(worktree.root_name.clone()),
 523                visible: ActiveValue::set(worktree.visible),
 524                scan_id: ActiveValue::set(0),
 525                is_complete: ActiveValue::set(false),
 526            }))
 527            .exec(&tx)
 528            .await?;
 529
 530            project_collaborator::ActiveModel {
 531                project_id: ActiveValue::set(project.id),
 532                connection_id: ActiveValue::set(connection_id.0 as i32),
 533                user_id: ActiveValue::set(participant.user_id),
 534                replica_id: ActiveValue::set(0),
 535                is_host: ActiveValue::set(true),
 536                ..Default::default()
 537            }
 538            .insert(&tx)
 539            .await?;
 540
 541            let room = self.get_room(room_id, &tx).await?;
 542            self.commit_room_transaction(room_id, tx, (project.id, room))
 543                .await
 544        })
 545        .await
 546    }
 547
 548    async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
 549        let db_room = room::Entity::find_by_id(room_id)
 550            .one(tx)
 551            .await?
 552            .ok_or_else(|| anyhow!("could not find room"))?;
 553
 554        let mut db_participants = db_room
 555            .find_related(room_participant::Entity)
 556            .stream(tx)
 557            .await?;
 558        let mut participants = HashMap::default();
 559        let mut pending_participants = Vec::new();
 560        while let Some(db_participant) = db_participants.next().await {
 561            let db_participant = db_participant?;
 562            if let Some(answering_connection_id) = db_participant.answering_connection_id {
 563                let location = match (
 564                    db_participant.location_kind,
 565                    db_participant.location_project_id,
 566                ) {
 567                    (Some(0), Some(project_id)) => {
 568                        Some(proto::participant_location::Variant::SharedProject(
 569                            proto::participant_location::SharedProject {
 570                                id: project_id.to_proto(),
 571                            },
 572                        ))
 573                    }
 574                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
 575                        Default::default(),
 576                    )),
 577                    _ => Some(proto::participant_location::Variant::External(
 578                        Default::default(),
 579                    )),
 580                };
 581                participants.insert(
 582                    answering_connection_id,
 583                    proto::Participant {
 584                        user_id: db_participant.user_id.to_proto(),
 585                        peer_id: answering_connection_id as u32,
 586                        projects: Default::default(),
 587                        location: Some(proto::ParticipantLocation { variant: location }),
 588                    },
 589                );
 590            } else {
 591                pending_participants.push(proto::PendingParticipant {
 592                    user_id: db_participant.user_id.to_proto(),
 593                    calling_user_id: db_participant.calling_user_id.to_proto(),
 594                    initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()),
 595                });
 596            }
 597        }
 598
 599        let mut db_projects = db_room
 600            .find_related(project::Entity)
 601            .find_with_related(worktree::Entity)
 602            .stream(tx)
 603            .await?;
 604
 605        while let Some(row) = db_projects.next().await {
 606            let (db_project, db_worktree) = row?;
 607            if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
 608                let project = if let Some(project) = participant
 609                    .projects
 610                    .iter_mut()
 611                    .find(|project| project.id == db_project.id.to_proto())
 612                {
 613                    project
 614                } else {
 615                    participant.projects.push(proto::ParticipantProject {
 616                        id: db_project.id.to_proto(),
 617                        worktree_root_names: Default::default(),
 618                    });
 619                    participant.projects.last_mut().unwrap()
 620                };
 621
 622                if let Some(db_worktree) = db_worktree {
 623                    project.worktree_root_names.push(db_worktree.root_name);
 624                }
 625            }
 626        }
 627
 628        Ok(proto::Room {
 629            id: db_room.id.to_proto(),
 630            live_kit_room: db_room.live_kit_room,
 631            participants: participants.into_values().collect(),
 632            pending_participants,
 633        })
 634    }
 635
 636    async fn commit_room_transaction<T>(
 637        &self,
 638        room_id: RoomId,
 639        tx: DatabaseTransaction,
 640        data: T,
 641    ) -> Result<RoomGuard<T>> {
 642        let lock = self.rooms.entry(room_id).or_default().clone();
 643        let _guard = lock.lock_owned().await;
 644        tx.commit().await?;
 645        Ok(RoomGuard {
 646            data,
 647            _guard,
 648            _not_send: PhantomData,
 649        })
 650    }
 651
 652    pub async fn create_access_token_hash(
 653        &self,
 654        user_id: UserId,
 655        access_token_hash: &str,
 656        max_access_token_count: usize,
 657    ) -> Result<()> {
 658        self.transact(|tx| async {
 659            let tx = tx;
 660
 661            access_token::ActiveModel {
 662                user_id: ActiveValue::set(user_id),
 663                hash: ActiveValue::set(access_token_hash.into()),
 664                ..Default::default()
 665            }
 666            .insert(&tx)
 667            .await?;
 668
 669            access_token::Entity::delete_many()
 670                .filter(
 671                    access_token::Column::Id.in_subquery(
 672                        Query::select()
 673                            .column(access_token::Column::Id)
 674                            .from(access_token::Entity)
 675                            .and_where(access_token::Column::UserId.eq(user_id))
 676                            .order_by(access_token::Column::Id, sea_orm::Order::Desc)
 677                            .limit(10000)
 678                            .offset(max_access_token_count as u64)
 679                            .to_owned(),
 680                    ),
 681                )
 682                .exec(&tx)
 683                .await?;
 684            tx.commit().await?;
 685            Ok(())
 686        })
 687        .await
 688    }
 689
 690    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 691        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 692        enum QueryAs {
 693            Hash,
 694        }
 695
 696        self.transact(|tx| async move {
 697            Ok(access_token::Entity::find()
 698                .select_only()
 699                .column(access_token::Column::Hash)
 700                .filter(access_token::Column::UserId.eq(user_id))
 701                .order_by_desc(access_token::Column::Id)
 702                .into_values::<_, QueryAs>()
 703                .all(&tx)
 704                .await?)
 705        })
 706        .await
 707    }
 708
 709    async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
 710    where
 711        F: Send + Fn(DatabaseTransaction) -> Fut,
 712        Fut: Send + Future<Output = Result<T>>,
 713    {
 714        let body = async {
 715            loop {
 716                let tx = self.pool.begin().await?;
 717
 718                // In Postgres, serializable transactions are opt-in
 719                if let sea_orm::DatabaseBackend::Postgres = self.pool.get_database_backend() {
 720                    tx.execute(sea_orm::Statement::from_string(
 721                        sea_orm::DatabaseBackend::Postgres,
 722                        "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(),
 723                    ))
 724                    .await?;
 725                }
 726
 727                match f(tx).await {
 728                    Ok(result) => return Ok(result),
 729                    Err(error) => match error {
 730                        Error::Database2(
 731                            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
 732                            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
 733                        ) if error
 734                            .as_database_error()
 735                            .and_then(|error| error.code())
 736                            .as_deref()
 737                            == Some("40001") =>
 738                        {
 739                            // Retry (don't break the loop)
 740                        }
 741                        error @ _ => return Err(error),
 742                    },
 743                }
 744            }
 745        };
 746
 747        #[cfg(test)]
 748        {
 749            if let Some(background) = self.background.as_ref() {
 750                background.simulate_random_delay().await;
 751            }
 752
 753            self.runtime.as_ref().unwrap().block_on(body)
 754        }
 755
 756        #[cfg(not(test))]
 757        {
 758            body.await
 759        }
 760    }
 761}
 762
 763pub struct RoomGuard<T> {
 764    data: T,
 765    _guard: OwnedMutexGuard<()>,
 766    _not_send: PhantomData<Rc<()>>,
 767}
 768
 769impl<T> Deref for RoomGuard<T> {
 770    type Target = T;
 771
 772    fn deref(&self) -> &T {
 773        &self.data
 774    }
 775}
 776
 777impl<T> DerefMut for RoomGuard<T> {
 778    fn deref_mut(&mut self) -> &mut T {
 779        &mut self.data
 780    }
 781}
 782
 783#[derive(Debug, Serialize, Deserialize)]
 784pub struct NewUserParams {
 785    pub github_login: String,
 786    pub github_user_id: i32,
 787    pub invite_count: i32,
 788}
 789
 790#[derive(Debug)]
 791pub struct NewUserResult {
 792    pub user_id: UserId,
 793    pub metrics_id: String,
 794    pub inviting_user_id: Option<UserId>,
 795    pub signup_device_id: Option<String>,
 796}
 797
 798fn random_invite_code() -> String {
 799    nanoid::nanoid!(16)
 800}
 801
 802fn random_email_confirmation_code() -> String {
 803    nanoid::nanoid!(64)
 804}
 805
 806macro_rules! id_type {
 807    ($name:ident) => {
 808        #[derive(
 809            Clone,
 810            Copy,
 811            Debug,
 812            Default,
 813            PartialEq,
 814            Eq,
 815            PartialOrd,
 816            Ord,
 817            Hash,
 818            sqlx::Type,
 819            Serialize,
 820            Deserialize,
 821        )]
 822        #[sqlx(transparent)]
 823        #[serde(transparent)]
 824        pub struct $name(pub i32);
 825
 826        impl $name {
 827            #[allow(unused)]
 828            pub const MAX: Self = Self(i32::MAX);
 829
 830            #[allow(unused)]
 831            pub fn from_proto(value: u64) -> Self {
 832                Self(value as i32)
 833            }
 834
 835            #[allow(unused)]
 836            pub fn to_proto(self) -> u64 {
 837                self.0 as u64
 838            }
 839        }
 840
 841        impl std::fmt::Display for $name {
 842            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 843                self.0.fmt(f)
 844            }
 845        }
 846
 847        impl From<$name> for sea_query::Value {
 848            fn from(value: $name) -> Self {
 849                sea_query::Value::Int(Some(value.0))
 850            }
 851        }
 852
 853        impl sea_orm::TryGetable for $name {
 854            fn try_get(
 855                res: &sea_orm::QueryResult,
 856                pre: &str,
 857                col: &str,
 858            ) -> Result<Self, sea_orm::TryGetError> {
 859                Ok(Self(i32::try_get(res, pre, col)?))
 860            }
 861        }
 862
 863        impl sea_query::ValueType for $name {
 864            fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
 865                match v {
 866                    Value::TinyInt(Some(int)) => {
 867                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 868                    }
 869                    Value::SmallInt(Some(int)) => {
 870                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 871                    }
 872                    Value::Int(Some(int)) => {
 873                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 874                    }
 875                    Value::BigInt(Some(int)) => {
 876                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 877                    }
 878                    Value::TinyUnsigned(Some(int)) => {
 879                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 880                    }
 881                    Value::SmallUnsigned(Some(int)) => {
 882                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 883                    }
 884                    Value::Unsigned(Some(int)) => {
 885                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 886                    }
 887                    Value::BigUnsigned(Some(int)) => {
 888                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
 889                    }
 890                    _ => Err(sea_query::ValueTypeErr),
 891                }
 892            }
 893
 894            fn type_name() -> String {
 895                stringify!($name).into()
 896            }
 897
 898            fn array_type() -> sea_query::ArrayType {
 899                sea_query::ArrayType::Int
 900            }
 901
 902            fn column_type() -> sea_query::ColumnType {
 903                sea_query::ColumnType::Integer(None)
 904            }
 905        }
 906
 907        impl sea_orm::TryFromU64 for $name {
 908            fn try_from_u64(n: u64) -> Result<Self, DbErr> {
 909                Ok(Self(n.try_into().map_err(|_| {
 910                    DbErr::ConvertFromU64(concat!(
 911                        "error converting ",
 912                        stringify!($name),
 913                        " to u64"
 914                    ))
 915                })?))
 916            }
 917        }
 918
 919        impl sea_query::Nullable for $name {
 920            fn null() -> Value {
 921                Value::Int(None)
 922            }
 923        }
 924    };
 925}
 926
 927id_type!(AccessTokenId);
 928id_type!(ContactId);
 929id_type!(UserId);
 930id_type!(RoomId);
 931id_type!(RoomParticipantId);
 932id_type!(ProjectId);
 933id_type!(ProjectCollaboratorId);
 934id_type!(WorktreeId);
 935
 936#[cfg(test)]
 937pub use test::*;
 938
 939#[cfg(test)]
 940mod test {
 941    use super::*;
 942    use gpui::executor::Background;
 943    use lazy_static::lazy_static;
 944    use parking_lot::Mutex;
 945    use rand::prelude::*;
 946    use sea_orm::ConnectionTrait;
 947    use sqlx::migrate::MigrateDatabase;
 948    use std::sync::Arc;
 949
 950    pub struct TestDb {
 951        pub db: Option<Arc<Database>>,
 952        pub connection: Option<sqlx::AnyConnection>,
 953    }
 954
 955    impl TestDb {
 956        pub fn sqlite(background: Arc<Background>) -> Self {
 957            let url = format!("sqlite::memory:");
 958            let runtime = tokio::runtime::Builder::new_current_thread()
 959                .enable_io()
 960                .enable_time()
 961                .build()
 962                .unwrap();
 963
 964            let mut db = runtime.block_on(async {
 965                let mut options = ConnectOptions::new(url);
 966                options.max_connections(5);
 967                let db = Database::new(options).await.unwrap();
 968                let sql = include_str!(concat!(
 969                    env!("CARGO_MANIFEST_DIR"),
 970                    "/migrations.sqlite/20221109000000_test_schema.sql"
 971                ));
 972                db.pool
 973                    .execute(sea_orm::Statement::from_string(
 974                        db.pool.get_database_backend(),
 975                        sql.into(),
 976                    ))
 977                    .await
 978                    .unwrap();
 979                db
 980            });
 981
 982            db.background = Some(background);
 983            db.runtime = Some(runtime);
 984
 985            Self {
 986                db: Some(Arc::new(db)),
 987                connection: None,
 988            }
 989        }
 990
 991        pub fn postgres(background: Arc<Background>) -> Self {
 992            lazy_static! {
 993                static ref LOCK: Mutex<()> = Mutex::new(());
 994            }
 995
 996            let _guard = LOCK.lock();
 997            let mut rng = StdRng::from_entropy();
 998            let url = format!(
 999                "postgres://postgres@localhost/zed-test-{}",
1000                rng.gen::<u128>()
1001            );
1002            let runtime = tokio::runtime::Builder::new_current_thread()
1003                .enable_io()
1004                .enable_time()
1005                .build()
1006                .unwrap();
1007
1008            let mut db = runtime.block_on(async {
1009                sqlx::Postgres::create_database(&url)
1010                    .await
1011                    .expect("failed to create test db");
1012                let mut options = ConnectOptions::new(url);
1013                options
1014                    .max_connections(5)
1015                    .idle_timeout(Duration::from_secs(0));
1016                let db = Database::new(options).await.unwrap();
1017                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
1018                db.migrate(Path::new(migrations_path), false).await.unwrap();
1019                db
1020            });
1021
1022            db.background = Some(background);
1023            db.runtime = Some(runtime);
1024
1025            Self {
1026                db: Some(Arc::new(db)),
1027                connection: None,
1028            }
1029        }
1030
1031        pub fn db(&self) -> &Arc<Database> {
1032            self.db.as_ref().unwrap()
1033        }
1034    }
1035
1036    impl Drop for TestDb {
1037        fn drop(&mut self) {
1038            let db = self.db.take().unwrap();
1039            if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
1040                db.runtime.as_ref().unwrap().block_on(async {
1041                    use util::ResultExt;
1042                    let query = "
1043                        SELECT pg_terminate_backend(pg_stat_activity.pid)
1044                        FROM pg_stat_activity
1045                        WHERE
1046                            pg_stat_activity.datname = current_database() AND
1047                            pid <> pg_backend_pid();
1048                    ";
1049                    db.pool
1050                        .execute(sea_orm::Statement::from_string(
1051                            db.pool.get_database_backend(),
1052                            query.into(),
1053                        ))
1054                        .await
1055                        .log_err();
1056                    sqlx::Postgres::drop_database(db.options.get_url())
1057                        .await
1058                        .log_err();
1059                })
1060            }
1061        }
1062    }
1063}