db2.rs

   1mod access_token;
   2mod contact;
   3mod project;
   4mod project_collaborator;
   5mod room;
   6mod room_participant;
   7mod signup;
   8#[cfg(test)]
   9mod tests;
  10mod user;
  11mod worktree;
  12
  13use crate::{Error, Result};
  14use anyhow::anyhow;
  15use collections::HashMap;
  16use dashmap::DashMap;
  17use futures::StreamExt;
  18use hyper::StatusCode;
  19use rpc::{proto, ConnectionId};
  20use sea_orm::{
  21    entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
  22    TransactionTrait,
  23};
  24use sea_orm::{
  25    ActiveValue, ConnectionTrait, DatabaseBackend, FromQueryResult, IntoActiveModel, JoinType,
  26    QueryOrder, QuerySelect, Statement,
  27};
  28use sea_query::{Alias, Expr, OnConflict, Query};
  29use serde::{Deserialize, Serialize};
  30use sqlx::migrate::{Migrate, Migration, MigrationSource};
  31use sqlx::Connection;
  32use std::ops::{Deref, DerefMut};
  33use std::path::Path;
  34use std::time::Duration;
  35use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
  36use tokio::sync::{Mutex, OwnedMutexGuard};
  37
  38pub use contact::Contact;
  39pub use signup::Invite;
  40pub use user::Model as User;
  41
  42pub struct Database {
  43    options: ConnectOptions,
  44    pool: DatabaseConnection,
  45    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
  46    #[cfg(test)]
  47    background: Option<std::sync::Arc<gpui::executor::Background>>,
  48    #[cfg(test)]
  49    runtime: Option<tokio::runtime::Runtime>,
  50}
  51
  52impl Database {
  53    pub async fn new(options: ConnectOptions) -> Result<Self> {
  54        Ok(Self {
  55            options: options.clone(),
  56            pool: sea_orm::Database::connect(options).await?,
  57            rooms: DashMap::with_capacity(16384),
  58            #[cfg(test)]
  59            background: None,
  60            #[cfg(test)]
  61            runtime: None,
  62        })
  63    }
  64
  65    pub async fn migrate(
  66        &self,
  67        migrations_path: &Path,
  68        ignore_checksum_mismatch: bool,
  69    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
  70        let migrations = MigrationSource::resolve(migrations_path)
  71            .await
  72            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
  73
  74        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
  75
  76        connection.ensure_migrations_table().await?;
  77        let applied_migrations: HashMap<_, _> = connection
  78            .list_applied_migrations()
  79            .await?
  80            .into_iter()
  81            .map(|m| (m.version, m))
  82            .collect();
  83
  84        let mut new_migrations = Vec::new();
  85        for migration in migrations {
  86            match applied_migrations.get(&migration.version) {
  87                Some(applied_migration) => {
  88                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
  89                    {
  90                        Err(anyhow!(
  91                            "checksum mismatch for applied migration {}",
  92                            migration.description
  93                        ))?;
  94                    }
  95                }
  96                None => {
  97                    let elapsed = connection.apply(&migration).await?;
  98                    new_migrations.push((migration, elapsed));
  99                }
 100            }
 101        }
 102
 103        Ok(new_migrations)
 104    }
 105
 106    // users
 107
 108    pub async fn create_user(
 109        &self,
 110        email_address: &str,
 111        admin: bool,
 112        params: NewUserParams,
 113    ) -> Result<NewUserResult> {
 114        self.transact(|tx| async {
 115            let user = user::Entity::insert(user::ActiveModel {
 116                email_address: ActiveValue::set(Some(email_address.into())),
 117                github_login: ActiveValue::set(params.github_login.clone()),
 118                github_user_id: ActiveValue::set(Some(params.github_user_id)),
 119                admin: ActiveValue::set(admin),
 120                metrics_id: ActiveValue::set(Uuid::new_v4()),
 121                ..Default::default()
 122            })
 123            .on_conflict(
 124                OnConflict::column(user::Column::GithubLogin)
 125                    .update_column(user::Column::GithubLogin)
 126                    .to_owned(),
 127            )
 128            .exec_with_returning(&tx)
 129            .await?;
 130
 131            tx.commit().await?;
 132
 133            Ok(NewUserResult {
 134                user_id: user.id,
 135                metrics_id: user.metrics_id.to_string(),
 136                signup_device_id: None,
 137                inviting_user_id: None,
 138            })
 139        })
 140        .await
 141    }
 142
 143    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
 144        self.transact(|tx| async {
 145            let tx = tx;
 146            Ok(user::Entity::find()
 147                .filter(user::Column::Id.is_in(ids.iter().copied()))
 148                .all(&tx)
 149                .await?)
 150        })
 151        .await
 152    }
 153
 154    pub async fn get_user_by_github_account(
 155        &self,
 156        github_login: &str,
 157        github_user_id: Option<i32>,
 158    ) -> Result<Option<User>> {
 159        self.transact(|tx| async {
 160            let tx = tx;
 161            if let Some(github_user_id) = github_user_id {
 162                if let Some(user_by_github_user_id) = user::Entity::find()
 163                    .filter(user::Column::GithubUserId.eq(github_user_id))
 164                    .one(&tx)
 165                    .await?
 166                {
 167                    let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
 168                    user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
 169                    Ok(Some(user_by_github_user_id.update(&tx).await?))
 170                } else if let Some(user_by_github_login) = user::Entity::find()
 171                    .filter(user::Column::GithubLogin.eq(github_login))
 172                    .one(&tx)
 173                    .await?
 174                {
 175                    let mut user_by_github_login = user_by_github_login.into_active_model();
 176                    user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
 177                    Ok(Some(user_by_github_login.update(&tx).await?))
 178                } else {
 179                    Ok(None)
 180                }
 181            } else {
 182                Ok(user::Entity::find()
 183                    .filter(user::Column::GithubLogin.eq(github_login))
 184                    .one(&tx)
 185                    .await?)
 186            }
 187        })
 188        .await
 189    }
 190
 191    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 192        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 193        enum QueryAs {
 194            MetricsId,
 195        }
 196
 197        self.transact(|tx| async move {
 198            let metrics_id: Uuid = user::Entity::find_by_id(id)
 199                .select_only()
 200                .column(user::Column::MetricsId)
 201                .into_values::<_, QueryAs>()
 202                .one(&tx)
 203                .await?
 204                .ok_or_else(|| anyhow!("could not find user"))?;
 205            Ok(metrics_id.to_string())
 206        })
 207        .await
 208    }
 209
 210    // contacts
 211
 212    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
 213        #[derive(Debug, FromQueryResult)]
 214        struct ContactWithUserBusyStatuses {
 215            user_id_a: UserId,
 216            user_id_b: UserId,
 217            a_to_b: bool,
 218            accepted: bool,
 219            should_notify: bool,
 220            user_a_busy: bool,
 221            user_b_busy: bool,
 222        }
 223
 224        self.transact(|tx| async move {
 225            let user_a_participant = Alias::new("user_a_participant");
 226            let user_b_participant = Alias::new("user_b_participant");
 227            let mut db_contacts = contact::Entity::find()
 228                .column_as(
 229                    Expr::tbl(user_a_participant.clone(), room_participant::Column::Id)
 230                        .is_not_null(),
 231                    "user_a_busy",
 232                )
 233                .column_as(
 234                    Expr::tbl(user_b_participant.clone(), room_participant::Column::Id)
 235                        .is_not_null(),
 236                    "user_b_busy",
 237                )
 238                .filter(
 239                    contact::Column::UserIdA
 240                        .eq(user_id)
 241                        .or(contact::Column::UserIdB.eq(user_id)),
 242                )
 243                .join_as(
 244                    JoinType::LeftJoin,
 245                    contact::Relation::UserARoomParticipant.def(),
 246                    user_a_participant,
 247                )
 248                .join_as(
 249                    JoinType::LeftJoin,
 250                    contact::Relation::UserBRoomParticipant.def(),
 251                    user_b_participant,
 252                )
 253                .into_model::<ContactWithUserBusyStatuses>()
 254                .stream(&tx)
 255                .await?;
 256
 257            let mut contacts = Vec::new();
 258            while let Some(db_contact) = db_contacts.next().await {
 259                let db_contact = db_contact?;
 260                if db_contact.user_id_a == user_id {
 261                    if db_contact.accepted {
 262                        contacts.push(Contact::Accepted {
 263                            user_id: db_contact.user_id_b,
 264                            should_notify: db_contact.should_notify && db_contact.a_to_b,
 265                            busy: db_contact.user_b_busy,
 266                        });
 267                    } else if db_contact.a_to_b {
 268                        contacts.push(Contact::Outgoing {
 269                            user_id: db_contact.user_id_b,
 270                        })
 271                    } else {
 272                        contacts.push(Contact::Incoming {
 273                            user_id: db_contact.user_id_b,
 274                            should_notify: db_contact.should_notify,
 275                        });
 276                    }
 277                } else if db_contact.accepted {
 278                    contacts.push(Contact::Accepted {
 279                        user_id: db_contact.user_id_a,
 280                        should_notify: db_contact.should_notify && !db_contact.a_to_b,
 281                        busy: db_contact.user_a_busy,
 282                    });
 283                } else if db_contact.a_to_b {
 284                    contacts.push(Contact::Incoming {
 285                        user_id: db_contact.user_id_a,
 286                        should_notify: db_contact.should_notify,
 287                    });
 288                } else {
 289                    contacts.push(Contact::Outgoing {
 290                        user_id: db_contact.user_id_a,
 291                    });
 292                }
 293            }
 294
 295            contacts.sort_unstable_by_key(|contact| contact.user_id());
 296
 297            Ok(contacts)
 298        })
 299        .await
 300    }
 301
 302    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 303        self.transact(|tx| async move {
 304            let (id_a, id_b) = if user_id_1 < user_id_2 {
 305                (user_id_1, user_id_2)
 306            } else {
 307                (user_id_2, user_id_1)
 308            };
 309
 310            Ok(contact::Entity::find()
 311                .filter(
 312                    contact::Column::UserIdA
 313                        .eq(id_a)
 314                        .and(contact::Column::UserIdB.eq(id_b))
 315                        .and(contact::Column::Accepted.eq(true)),
 316                )
 317                .one(&tx)
 318                .await?
 319                .is_some())
 320        })
 321        .await
 322    }
 323
 324    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
 325        self.transact(|mut tx| async move {
 326            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
 327                (sender_id, receiver_id, true)
 328            } else {
 329                (receiver_id, sender_id, false)
 330            };
 331
 332            let rows_affected = contact::Entity::insert(contact::ActiveModel {
 333                user_id_a: ActiveValue::set(id_a),
 334                user_id_b: ActiveValue::set(id_b),
 335                a_to_b: ActiveValue::set(a_to_b),
 336                accepted: ActiveValue::set(false),
 337                should_notify: ActiveValue::set(true),
 338                ..Default::default()
 339            })
 340            .on_conflict(
 341                OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB])
 342                    .values([
 343                        (contact::Column::Accepted, true.into()),
 344                        (contact::Column::ShouldNotify, false.into()),
 345                    ])
 346                    .action_and_where(
 347                        contact::Column::Accepted.eq(false).and(
 348                            contact::Column::AToB
 349                                .eq(a_to_b)
 350                                .and(contact::Column::UserIdA.eq(id_b))
 351                                .or(contact::Column::AToB
 352                                    .ne(a_to_b)
 353                                    .and(contact::Column::UserIdA.eq(id_a))),
 354                        ),
 355                    )
 356                    .to_owned(),
 357            )
 358            .exec_without_returning(&tx)
 359            .await?;
 360
 361            if rows_affected == 1 {
 362                tx.commit().await?;
 363                Ok(())
 364            } else {
 365                Err(anyhow!("contact already requested"))?
 366            }
 367        })
 368        .await
 369    }
 370
 371    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 372        self.transact(|tx| async move {
 373            let (id_a, id_b) = if responder_id < requester_id {
 374                (responder_id, requester_id)
 375            } else {
 376                (requester_id, responder_id)
 377            };
 378
 379            let result = contact::Entity::delete_many()
 380                .filter(
 381                    contact::Column::UserIdA
 382                        .eq(id_a)
 383                        .and(contact::Column::UserIdB.eq(id_b)),
 384                )
 385                .exec(&tx)
 386                .await?;
 387
 388            if result.rows_affected == 1 {
 389                tx.commit().await?;
 390                Ok(())
 391            } else {
 392                Err(anyhow!("no such contact"))?
 393            }
 394        })
 395        .await
 396    }
 397
 398    pub async fn dismiss_contact_notification(
 399        &self,
 400        user_id: UserId,
 401        contact_user_id: UserId,
 402    ) -> Result<()> {
 403        self.transact(|tx| async move {
 404            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
 405                (user_id, contact_user_id, true)
 406            } else {
 407                (contact_user_id, user_id, false)
 408            };
 409
 410            let result = contact::Entity::update_many()
 411                .set(contact::ActiveModel {
 412                    should_notify: ActiveValue::set(false),
 413                    ..Default::default()
 414                })
 415                .filter(
 416                    contact::Column::UserIdA
 417                        .eq(id_a)
 418                        .and(contact::Column::UserIdB.eq(id_b))
 419                        .and(
 420                            contact::Column::AToB
 421                                .eq(a_to_b)
 422                                .and(contact::Column::Accepted.eq(true))
 423                                .or(contact::Column::AToB
 424                                    .ne(a_to_b)
 425                                    .and(contact::Column::Accepted.eq(false))),
 426                        ),
 427                )
 428                .exec(&tx)
 429                .await?;
 430            if result.rows_affected == 0 {
 431                Err(anyhow!("no such contact request"))?
 432            } else {
 433                tx.commit().await?;
 434                Ok(())
 435            }
 436        })
 437        .await
 438    }
 439
 440    pub async fn respond_to_contact_request(
 441        &self,
 442        responder_id: UserId,
 443        requester_id: UserId,
 444        accept: bool,
 445    ) -> Result<()> {
 446        self.transact(|tx| async move {
 447            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
 448                (responder_id, requester_id, false)
 449            } else {
 450                (requester_id, responder_id, true)
 451            };
 452            let rows_affected = if accept {
 453                let result = contact::Entity::update_many()
 454                    .set(contact::ActiveModel {
 455                        accepted: ActiveValue::set(true),
 456                        should_notify: ActiveValue::set(true),
 457                        ..Default::default()
 458                    })
 459                    .filter(
 460                        contact::Column::UserIdA
 461                            .eq(id_a)
 462                            .and(contact::Column::UserIdB.eq(id_b))
 463                            .and(contact::Column::AToB.eq(a_to_b)),
 464                    )
 465                    .exec(&tx)
 466                    .await?;
 467                result.rows_affected
 468            } else {
 469                let result = contact::Entity::delete_many()
 470                    .filter(
 471                        contact::Column::UserIdA
 472                            .eq(id_a)
 473                            .and(contact::Column::UserIdB.eq(id_b))
 474                            .and(contact::Column::AToB.eq(a_to_b))
 475                            .and(contact::Column::Accepted.eq(false)),
 476                    )
 477                    .exec(&tx)
 478                    .await?;
 479
 480                result.rows_affected
 481            };
 482
 483            if rows_affected == 1 {
 484                tx.commit().await?;
 485                Ok(())
 486            } else {
 487                Err(anyhow!("no such contact request"))?
 488            }
 489        })
 490        .await
 491    }
 492
 493    pub fn fuzzy_like_string(string: &str) -> String {
 494        let mut result = String::with_capacity(string.len() * 2 + 1);
 495        for c in string.chars() {
 496            if c.is_alphanumeric() {
 497                result.push('%');
 498                result.push(c);
 499            }
 500        }
 501        result.push('%');
 502        result
 503    }
 504
 505    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 506        self.transact(|tx| async {
 507            let tx = tx;
 508            let like_string = Self::fuzzy_like_string(name_query);
 509            let query = "
 510                SELECT users.*
 511                FROM users
 512                WHERE github_login ILIKE $1
 513                ORDER BY github_login <-> $2
 514                LIMIT $3
 515            ";
 516
 517            Ok(user::Entity::find()
 518                .from_raw_sql(Statement::from_sql_and_values(
 519                    self.pool.get_database_backend(),
 520                    query.into(),
 521                    vec![like_string.into(), name_query.into(), limit.into()],
 522                ))
 523                .all(&tx)
 524                .await?)
 525        })
 526        .await
 527    }
 528
 529    // invite codes
 530
 531    pub async fn create_invite_from_code(
 532        &self,
 533        code: &str,
 534        email_address: &str,
 535        device_id: Option<&str>,
 536    ) -> Result<Invite> {
 537        self.transact(|tx| async move {
 538            let existing_user = user::Entity::find()
 539                .filter(user::Column::EmailAddress.eq(email_address))
 540                .one(&tx)
 541                .await?;
 542
 543            if existing_user.is_some() {
 544                Err(anyhow!("email address is already in use"))?;
 545            }
 546
 547            let inviter = match user::Entity::find()
 548                .filter(user::Column::InviteCode.eq(code))
 549                .one(&tx)
 550                .await?
 551            {
 552                Some(inviter) => inviter,
 553                None => {
 554                    return Err(Error::Http(
 555                        StatusCode::NOT_FOUND,
 556                        "invite code not found".to_string(),
 557                    ))?
 558                }
 559            };
 560
 561            if inviter.invite_count == 0 {
 562                Err(Error::Http(
 563                    StatusCode::UNAUTHORIZED,
 564                    "no invites remaining".to_string(),
 565                ))?;
 566            }
 567
 568            let signup = signup::Entity::insert(signup::ActiveModel {
 569                email_address: ActiveValue::set(email_address.into()),
 570                email_confirmation_code: ActiveValue::set(random_email_confirmation_code()),
 571                email_confirmation_sent: ActiveValue::set(false),
 572                inviting_user_id: ActiveValue::set(Some(inviter.id)),
 573                platform_linux: ActiveValue::set(false),
 574                platform_mac: ActiveValue::set(false),
 575                platform_windows: ActiveValue::set(false),
 576                platform_unknown: ActiveValue::set(true),
 577                device_id: ActiveValue::set(device_id.map(|device_id| device_id.into())),
 578                ..Default::default()
 579            })
 580            .on_conflict(
 581                OnConflict::column(signup::Column::EmailAddress)
 582                    .update_column(signup::Column::InvitingUserId)
 583                    .to_owned(),
 584            )
 585            .exec_with_returning(&tx)
 586            .await?;
 587            tx.commit().await?;
 588
 589            Ok(Invite {
 590                email_address: signup.email_address,
 591                email_confirmation_code: signup.email_confirmation_code,
 592            })
 593        })
 594        .await
 595    }
 596
 597    pub async fn create_user_from_invite(
 598        &self,
 599        invite: &Invite,
 600        user: NewUserParams,
 601    ) -> Result<Option<NewUserResult>> {
 602        self.transact(|tx| async {
 603            let tx = tx;
 604            let signup = signup::Entity::find()
 605                .filter(
 606                    signup::Column::EmailAddress
 607                        .eq(invite.email_address.as_str())
 608                        .and(
 609                            signup::Column::EmailConfirmationCode
 610                                .eq(invite.email_confirmation_code.as_str()),
 611                        ),
 612                )
 613                .one(&tx)
 614                .await?
 615                .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;
 616
 617            if signup.user_id.is_some() {
 618                return Ok(None);
 619            }
 620
 621            let user = user::Entity::insert(user::ActiveModel {
 622                email_address: ActiveValue::set(Some(invite.email_address.clone())),
 623                github_login: ActiveValue::set(user.github_login.clone()),
 624                github_user_id: ActiveValue::set(Some(user.github_user_id)),
 625                admin: ActiveValue::set(false),
 626                invite_count: ActiveValue::set(user.invite_count),
 627                invite_code: ActiveValue::set(Some(random_invite_code())),
 628                metrics_id: ActiveValue::set(Uuid::new_v4()),
 629                ..Default::default()
 630            })
 631            .on_conflict(
 632                OnConflict::column(user::Column::GithubLogin)
 633                    .update_columns([
 634                        user::Column::EmailAddress,
 635                        user::Column::GithubUserId,
 636                        user::Column::Admin,
 637                    ])
 638                    .to_owned(),
 639            )
 640            .exec_with_returning(&tx)
 641            .await?;
 642
 643            let mut signup = signup.into_active_model();
 644            signup.user_id = ActiveValue::set(Some(user.id));
 645            let signup = signup.update(&tx).await?;
 646
 647            if let Some(inviting_user_id) = signup.inviting_user_id {
 648                let result = user::Entity::update_many()
 649                    .filter(
 650                        user::Column::Id
 651                            .eq(inviting_user_id)
 652                            .and(user::Column::InviteCount.gt(0)),
 653                    )
 654                    .col_expr(
 655                        user::Column::InviteCount,
 656                        Expr::col(user::Column::InviteCount).sub(1),
 657                    )
 658                    .exec(&tx)
 659                    .await?;
 660
 661                if result.rows_affected == 0 {
 662                    Err(Error::Http(
 663                        StatusCode::UNAUTHORIZED,
 664                        "no invites remaining".to_string(),
 665                    ))?;
 666                }
 667
 668                contact::Entity::insert(contact::ActiveModel {
 669                    user_id_a: ActiveValue::set(inviting_user_id),
 670                    user_id_b: ActiveValue::set(user.id),
 671                    a_to_b: ActiveValue::set(true),
 672                    should_notify: ActiveValue::set(true),
 673                    accepted: ActiveValue::set(true),
 674                    ..Default::default()
 675                })
 676                .on_conflict(OnConflict::new().do_nothing().to_owned())
 677                .exec_without_returning(&tx)
 678                .await?;
 679            }
 680
 681            tx.commit().await?;
 682            Ok(Some(NewUserResult {
 683                user_id: user.id,
 684                metrics_id: user.metrics_id.to_string(),
 685                inviting_user_id: signup.inviting_user_id,
 686                signup_device_id: signup.device_id,
 687            }))
 688        })
 689        .await
 690    }
 691
 692    pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> {
 693        self.transact(|tx| async move {
 694            if count > 0 {
 695                user::Entity::update_many()
 696                    .filter(
 697                        user::Column::Id
 698                            .eq(id)
 699                            .and(user::Column::InviteCode.is_null()),
 700                    )
 701                    .col_expr(user::Column::InviteCode, random_invite_code().into())
 702                    .exec(&tx)
 703                    .await?;
 704            }
 705
 706            user::Entity::update_many()
 707                .filter(user::Column::Id.eq(id))
 708                .col_expr(user::Column::InviteCount, count.into())
 709                .exec(&tx)
 710                .await?;
 711            tx.commit().await?;
 712            Ok(())
 713        })
 714        .await
 715    }
 716
 717    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, u32)>> {
 718        self.transact(|tx| async move {
 719            match user::Entity::find_by_id(id).one(&tx).await? {
 720                Some(user) if user.invite_code.is_some() => {
 721                    Ok(Some((user.invite_code.unwrap(), user.invite_count as u32)))
 722                }
 723                _ => Ok(None),
 724            }
 725        })
 726        .await
 727    }
 728
 729    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 730        self.transact(|tx| async move {
 731            user::Entity::find()
 732                .filter(user::Column::InviteCode.eq(code))
 733                .one(&tx)
 734                .await?
 735                .ok_or_else(|| {
 736                    Error::Http(
 737                        StatusCode::NOT_FOUND,
 738                        "that invite code does not exist".to_string(),
 739                    )
 740                })
 741        })
 742        .await
 743    }
 744
 745    // projects
 746
 747    pub async fn share_project(
 748        &self,
 749        room_id: RoomId,
 750        connection_id: ConnectionId,
 751        worktrees: &[proto::WorktreeMetadata],
 752    ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
 753        self.transact(|tx| async move {
 754            let participant = room_participant::Entity::find()
 755                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
 756                .one(&tx)
 757                .await?
 758                .ok_or_else(|| anyhow!("could not find participant"))?;
 759            if participant.room_id != room_id {
 760                return Err(anyhow!("shared project on unexpected room"))?;
 761            }
 762
 763            let project = project::ActiveModel {
 764                room_id: ActiveValue::set(participant.room_id),
 765                host_user_id: ActiveValue::set(participant.user_id),
 766                host_connection_id: ActiveValue::set(connection_id.0 as i32),
 767                ..Default::default()
 768            }
 769            .insert(&tx)
 770            .await?;
 771
 772            worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel {
 773                id: ActiveValue::set(worktree.id as i32),
 774                project_id: ActiveValue::set(project.id),
 775                abs_path: ActiveValue::set(worktree.abs_path.clone()),
 776                root_name: ActiveValue::set(worktree.root_name.clone()),
 777                visible: ActiveValue::set(worktree.visible),
 778                scan_id: ActiveValue::set(0),
 779                is_complete: ActiveValue::set(false),
 780            }))
 781            .exec(&tx)
 782            .await?;
 783
 784            project_collaborator::ActiveModel {
 785                project_id: ActiveValue::set(project.id),
 786                connection_id: ActiveValue::set(connection_id.0 as i32),
 787                user_id: ActiveValue::set(participant.user_id),
 788                replica_id: ActiveValue::set(0),
 789                is_host: ActiveValue::set(true),
 790                ..Default::default()
 791            }
 792            .insert(&tx)
 793            .await?;
 794
 795            let room = self.get_room(room_id, &tx).await?;
 796            self.commit_room_transaction(room_id, tx, (project.id, room))
 797                .await
 798        })
 799        .await
 800    }
 801
 802    async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
 803        let db_room = room::Entity::find_by_id(room_id)
 804            .one(tx)
 805            .await?
 806            .ok_or_else(|| anyhow!("could not find room"))?;
 807
 808        let mut db_participants = db_room
 809            .find_related(room_participant::Entity)
 810            .stream(tx)
 811            .await?;
 812        let mut participants = HashMap::default();
 813        let mut pending_participants = Vec::new();
 814        while let Some(db_participant) = db_participants.next().await {
 815            let db_participant = db_participant?;
 816            if let Some(answering_connection_id) = db_participant.answering_connection_id {
 817                let location = match (
 818                    db_participant.location_kind,
 819                    db_participant.location_project_id,
 820                ) {
 821                    (Some(0), Some(project_id)) => {
 822                        Some(proto::participant_location::Variant::SharedProject(
 823                            proto::participant_location::SharedProject {
 824                                id: project_id.to_proto(),
 825                            },
 826                        ))
 827                    }
 828                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
 829                        Default::default(),
 830                    )),
 831                    _ => Some(proto::participant_location::Variant::External(
 832                        Default::default(),
 833                    )),
 834                };
 835                participants.insert(
 836                    answering_connection_id,
 837                    proto::Participant {
 838                        user_id: db_participant.user_id.to_proto(),
 839                        peer_id: answering_connection_id as u32,
 840                        projects: Default::default(),
 841                        location: Some(proto::ParticipantLocation { variant: location }),
 842                    },
 843                );
 844            } else {
 845                pending_participants.push(proto::PendingParticipant {
 846                    user_id: db_participant.user_id.to_proto(),
 847                    calling_user_id: db_participant.calling_user_id.to_proto(),
 848                    initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()),
 849                });
 850            }
 851        }
 852
 853        let mut db_projects = db_room
 854            .find_related(project::Entity)
 855            .find_with_related(worktree::Entity)
 856            .stream(tx)
 857            .await?;
 858
 859        while let Some(row) = db_projects.next().await {
 860            let (db_project, db_worktree) = row?;
 861            if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
 862                let project = if let Some(project) = participant
 863                    .projects
 864                    .iter_mut()
 865                    .find(|project| project.id == db_project.id.to_proto())
 866                {
 867                    project
 868                } else {
 869                    participant.projects.push(proto::ParticipantProject {
 870                        id: db_project.id.to_proto(),
 871                        worktree_root_names: Default::default(),
 872                    });
 873                    participant.projects.last_mut().unwrap()
 874                };
 875
 876                if let Some(db_worktree) = db_worktree {
 877                    project.worktree_root_names.push(db_worktree.root_name);
 878                }
 879            }
 880        }
 881
 882        Ok(proto::Room {
 883            id: db_room.id.to_proto(),
 884            live_kit_room: db_room.live_kit_room,
 885            participants: participants.into_values().collect(),
 886            pending_participants,
 887        })
 888    }
 889
 890    async fn commit_room_transaction<T>(
 891        &self,
 892        room_id: RoomId,
 893        tx: DatabaseTransaction,
 894        data: T,
 895    ) -> Result<RoomGuard<T>> {
 896        let lock = self.rooms.entry(room_id).or_default().clone();
 897        let _guard = lock.lock_owned().await;
 898        tx.commit().await?;
 899        Ok(RoomGuard {
 900            data,
 901            _guard,
 902            _not_send: PhantomData,
 903        })
 904    }
 905
 906    pub async fn create_access_token_hash(
 907        &self,
 908        user_id: UserId,
 909        access_token_hash: &str,
 910        max_access_token_count: usize,
 911    ) -> Result<()> {
 912        self.transact(|tx| async {
 913            let tx = tx;
 914
 915            access_token::ActiveModel {
 916                user_id: ActiveValue::set(user_id),
 917                hash: ActiveValue::set(access_token_hash.into()),
 918                ..Default::default()
 919            }
 920            .insert(&tx)
 921            .await?;
 922
 923            access_token::Entity::delete_many()
 924                .filter(
 925                    access_token::Column::Id.in_subquery(
 926                        Query::select()
 927                            .column(access_token::Column::Id)
 928                            .from(access_token::Entity)
 929                            .and_where(access_token::Column::UserId.eq(user_id))
 930                            .order_by(access_token::Column::Id, sea_orm::Order::Desc)
 931                            .limit(10000)
 932                            .offset(max_access_token_count as u64)
 933                            .to_owned(),
 934                    ),
 935                )
 936                .exec(&tx)
 937                .await?;
 938            tx.commit().await?;
 939            Ok(())
 940        })
 941        .await
 942    }
 943
 944    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
 945        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 946        enum QueryAs {
 947            Hash,
 948        }
 949
 950        self.transact(|tx| async move {
 951            Ok(access_token::Entity::find()
 952                .select_only()
 953                .column(access_token::Column::Hash)
 954                .filter(access_token::Column::UserId.eq(user_id))
 955                .order_by_desc(access_token::Column::Id)
 956                .into_values::<_, QueryAs>()
 957                .all(&tx)
 958                .await?)
 959        })
 960        .await
 961    }
 962
 963    async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
 964    where
 965        F: Send + Fn(DatabaseTransaction) -> Fut,
 966        Fut: Send + Future<Output = Result<T>>,
 967    {
 968        let body = async {
 969            loop {
 970                let tx = self.pool.begin().await?;
 971
 972                // In Postgres, serializable transactions are opt-in
 973                if let DatabaseBackend::Postgres = self.pool.get_database_backend() {
 974                    tx.execute(Statement::from_string(
 975                        DatabaseBackend::Postgres,
 976                        "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(),
 977                    ))
 978                    .await?;
 979                }
 980
 981                match f(tx).await {
 982                    Ok(result) => return Ok(result),
 983                    Err(error) => match error {
 984                        Error::Database2(
 985                            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
 986                            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
 987                        ) if error
 988                            .as_database_error()
 989                            .and_then(|error| error.code())
 990                            .as_deref()
 991                            == Some("40001") =>
 992                        {
 993                            // Retry (don't break the loop)
 994                        }
 995                        error @ _ => return Err(error),
 996                    },
 997                }
 998            }
 999        };
1000
1001        #[cfg(test)]
1002        {
1003            if let Some(background) = self.background.as_ref() {
1004                background.simulate_random_delay().await;
1005            }
1006
1007            self.runtime.as_ref().unwrap().block_on(body)
1008        }
1009
1010        #[cfg(not(test))]
1011        {
1012            body.await
1013        }
1014    }
1015}
1016
/// A value produced by a room transaction, bundled with the per-room lock that
/// `commit_room_transaction` acquired to commit it.
///
/// The lock stays held for as long as this guard is alive. Field order matters:
/// `data` is dropped before `_guard` releases the lock.
pub struct RoomGuard<T> {
    data: T,
    // Owned guard of the room's `Mutex<()>`; the lock is released when this is dropped.
    _guard: OwnedMutexGuard<()>,
    // `Rc` is `!Send`, so this marker makes `RoomGuard` `!Send`.
    // NOTE(review): presumably meant to stop the held room lock from being moved
    // across threads or held inside `Send` futures. Confirm intent.
    _not_send: PhantomData<Rc<()>>,
}
1022
1023impl<T> Deref for RoomGuard<T> {
1024    type Target = T;
1025
1026    fn deref(&self) -> &T {
1027        &self.data
1028    }
1029}
1030
1031impl<T> DerefMut for RoomGuard<T> {
1032    fn deref_mut(&mut self) -> &mut T {
1033        &mut self.data
1034    }
1035}
1036
/// Input data describing a user to be created.
///
/// It is serde-(de)serializable. NOTE(review): presumably it is received as a
/// request payload by the user-creation code outside this chunk; confirm.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    // NOTE(review): meaning (invites granted to the new user?) not visible here; confirm.
    pub invite_count: i32,
}
1043
/// Outcome of creating a user.
///
/// NOTE(review): presumably returned by the user-creation code outside this
/// chunk; confirm.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    // `None` when no inviting user is associated with the new account.
    pub inviting_user_id: Option<UserId>,
    // `None` when no signup device id is associated with the new account.
    pub signup_device_id: Option<String>,
}
1051
1052fn random_invite_code() -> String {
1053    nanoid::nanoid!(16)
1054}
1055
1056fn random_email_confirmation_code() -> String {
1057    nanoid::nanoid!(64)
1058}
1059
/// Defines `$name` as a `Copy` newtype over `i32` used as a strongly-typed
/// database id.
///
/// The generated type comes with:
/// - sqlx and serde support (both transparent over the inner `i32`);
/// - `u64` proto conversions;
/// - `Display`;
/// - the sea-query and sea-orm trait impls needed to use it as a column or
///   primary-key type.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        // `#[allow(unused)]`: not every generated id type uses every helper.
        impl $name {
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Converts a proto id.
            ///
            /// Values outside the `i32` range are truncated by the `as` cast
            /// rather than rejected.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts to a proto id.
            ///
            /// A negative inner value sign-extends (wraps) under the `as` cast.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }

        impl From<$name> for sea_query::Value {
            fn from(value: $name) -> Self {
                sea_query::Value::Int(Some(value.0))
            }
        }

        impl sea_orm::TryGetable for $name {
            fn try_get(
                res: &sea_orm::QueryResult,
                pre: &str,
                col: &str,
            ) -> Result<Self, sea_orm::TryGetError> {
                Ok(Self(i32::try_get(res, pre, col)?))
            }
        }

        // Accepts any non-null integer `Value` whose value fits in an `i32`.
        impl sea_query::ValueType for $name {
            fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
                match v {
                    Value::TinyInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Int(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::TinyUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Unsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    _ => Err(sea_query::ValueTypeErr),
                }
            }

            fn type_name() -> String {
                stringify!($name).into()
            }

            fn array_type() -> sea_query::ArrayType {
                sea_query::ArrayType::Int
            }

            fn column_type() -> sea_query::ColumnType {
                sea_query::ColumnType::Integer(None)
            }
        }

        // Fails when the `u64` (e.g. a last-insert id) does not fit in an `i32`.
        impl sea_orm::TryFromU64 for $name {
            fn try_from_u64(n: u64) -> Result<Self, DbErr> {
                Ok(Self(n.try_into().map_err(|_| {
                    DbErr::ConvertFromU64(concat!(
                        "error converting u64 to ",
                        stringify!($name)
                    ))
                })?))
            }
        }

        impl sea_query::Nullable for $name {
            fn null() -> Value {
                Value::Int(None)
            }
        }
    };
}
1180
1181id_type!(AccessTokenId);
1182id_type!(ContactId);
1183id_type!(UserId);
1184id_type!(RoomId);
1185id_type!(RoomParticipantId);
1186id_type!(ProjectId);
1187id_type!(ProjectCollaboratorId);
1188id_type!(SignupId);
1189id_type!(WorktreeId);
1190
// In test builds, re-export the public test helpers (e.g. `TestDb`) so they are
// reachable directly through this module's path.
#[cfg(test)]
pub use test::*;
1193
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sea_orm::ConnectionTrait;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A throwaway `Database` for tests.
    ///
    /// It is backed either by an in-memory SQLite database or by a freshly created
    /// Postgres database, which is terminated and dropped again on `Drop`. Queries
    /// run on a dedicated single-threaded tokio runtime owned by the `Database`.
    pub struct TestDb {
        pub db: Option<Arc<Database>>,
        pub connection: Option<sqlx::AnyConnection>,
    }

    impl TestDb {
        /// Creates an in-memory SQLite database and loads the crate's SQLite test
        /// schema into it.
        pub fn sqlite(background: Arc<Background>) -> Self {
            let url = "sqlite::memory:".to_string();
            let runtime = Self::new_runtime();

            let db = runtime.block_on(async {
                let mut options = ConnectOptions::new(url);
                options.max_connections(5);
                let db = Database::new(options).await.unwrap();
                let sql = include_str!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/migrations.sqlite/20221109000000_test_schema.sql"
                ));
                db.pool
                    .execute(sea_orm::Statement::from_string(
                        db.pool.get_database_backend(),
                        sql.into(),
                    ))
                    .await
                    .unwrap();
                db
            });

            Self::from_database(db, background, runtime)
        }

        /// Creates a uniquely named Postgres database on `localhost` and runs the
        /// crate's `migrations` directory against it.
        ///
        /// Creation is serialized across tests by a process-wide lock.
        pub fn postgres(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = Self::new_runtime();

            let db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let mut options = ConnectOptions::new(url);
                options
                    .max_connections(5)
                    .idle_timeout(Duration::from_secs(0));
                let db = Database::new(options).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            Self::from_database(db, background, runtime)
        }

        /// Returns the wrapped database.
        ///
        /// # Panics
        ///
        /// Panics if called after the database has been taken (only `Drop` does that).
        pub fn db(&self) -> &Arc<Database> {
            self.db.as_ref().unwrap()
        }

        /// Builds the single-threaded tokio runtime (IO + timers enabled) that
        /// drives a test database's queries.
        fn new_runtime() -> tokio::runtime::Runtime {
            tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap()
        }

        /// Attaches the simulated executor and the runtime that created `db`, then
        /// wraps it in a `TestDb`.
        fn from_database(
            mut db: Database,
            background: Arc<Background>,
            runtime: tokio::runtime::Runtime,
        ) -> Self {
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                connection: None,
            }
        }
    }

    impl Drop for TestDb {
        /// For Postgres, terminates every other backend connected to the test
        /// database and then drops the database. Failures are only logged.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            if let DatabaseBackend::Postgres = db.pool.get_database_backend() {
                db.runtime.as_ref().unwrap().block_on(async {
                    use util::ResultExt;
                    let query = "
                        SELECT pg_terminate_backend(pg_stat_activity.pid)
                        FROM pg_stat_activity
                        WHERE
                            pg_stat_activity.datname = current_database() AND
                            pid <> pg_backend_pid();
                    ";
                    db.pool
                        .execute(sea_orm::Statement::from_string(
                            db.pool.get_database_backend(),
                            query.into(),
                        ))
                        .await
                        .log_err();
                    sqlx::Postgres::drop_database(db.options.get_url())
                        .await
                        .log_err();
                })
            }
        }
    }
}