db.rs

   1mod access_token;
   2mod contact;
   3mod language_server;
   4mod project;
   5mod project_collaborator;
   6mod room;
   7mod room_participant;
   8mod signup;
   9#[cfg(test)]
  10mod tests;
  11mod user;
  12mod worktree;
  13mod worktree_diagnostic_summary;
  14mod worktree_entry;
  15
  16use crate::{Error, Result};
  17use anyhow::anyhow;
  18use collections::{BTreeMap, HashMap, HashSet};
  19pub use contact::Contact;
  20use dashmap::DashMap;
  21use futures::StreamExt;
  22use hyper::StatusCode;
  23use rpc::{proto, ConnectionId};
  24pub use sea_orm::ConnectOptions;
  25use sea_orm::{
  26    entity::prelude::*, ActiveValue, ConnectionTrait, DatabaseBackend, DatabaseConnection,
  27    DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, JoinType, QueryOrder,
  28    QuerySelect, Statement, TransactionTrait,
  29};
  30use sea_query::{Alias, Expr, OnConflict, Query};
  31use serde::{Deserialize, Serialize};
  32pub use signup::{Invite, NewSignup, WaitlistSummary};
  33use sqlx::migrate::{Migrate, Migration, MigrationSource};
  34use sqlx::Connection;
  35use std::ops::{Deref, DerefMut};
  36use std::path::Path;
  37use std::time::Duration;
  38use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
  39use tokio::sync::{Mutex, OwnedMutexGuard};
  40pub use user::Model as User;
  41
/// Handle to the collaboration database: a sea-orm connection plus a little
/// in-process state used by the query methods on this type.
pub struct Database {
    /// Options the connection was opened with; `migrate` reuses their URL to
    /// open a separate sqlx connection.
    options: ConnectOptions,
    /// sea-orm connection opened in `Database::new`.
    pool: DatabaseConnection,
    /// One `Mutex<()>` per room id.
    /// NOTE(review): presumably used to serialize work on a single room; the
    /// code that takes these locks is not visible here — confirm.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// Test-only; `Database::new` leaves it as `None`.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only; `Database::new` leaves it as `None`.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
    /// Random id generated in `Database::new`; `clear_stale_data` deletes rows
    /// whose connection-epoch columns differ from it.
    epoch: Uuid,
}
  52
  53impl Database {
  54    pub async fn new(options: ConnectOptions) -> Result<Self> {
  55        Ok(Self {
  56            options: options.clone(),
  57            pool: sea_orm::Database::connect(options).await?,
  58            rooms: DashMap::with_capacity(16384),
  59            #[cfg(test)]
  60            background: None,
  61            #[cfg(test)]
  62            runtime: None,
  63            epoch: Uuid::new_v4(),
  64        })
  65    }
  66
  67    pub async fn migrate(
  68        &self,
  69        migrations_path: &Path,
  70        ignore_checksum_mismatch: bool,
  71    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
  72        let migrations = MigrationSource::resolve(migrations_path)
  73            .await
  74            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
  75
  76        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
  77
  78        connection.ensure_migrations_table().await?;
  79        let applied_migrations: HashMap<_, _> = connection
  80            .list_applied_migrations()
  81            .await?
  82            .into_iter()
  83            .map(|m| (m.version, m))
  84            .collect();
  85
  86        let mut new_migrations = Vec::new();
  87        for migration in migrations {
  88            match applied_migrations.get(&migration.version) {
  89                Some(applied_migration) => {
  90                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
  91                    {
  92                        Err(anyhow!(
  93                            "checksum mismatch for applied migration {}",
  94                            migration.description
  95                        ))?;
  96                    }
  97                }
  98                None => {
  99                    let elapsed = connection.apply(&migration).await?;
 100                    new_migrations.push((migration, elapsed));
 101                }
 102            }
 103        }
 104
 105        Ok(new_migrations)
 106    }
 107
    /// Deletes rows created by connections from a different server epoch.
    ///
    /// `self.epoch` is generated fresh in `Database::new`. This removes project
    /// collaborator, room participant, and project rows whose connection-epoch
    /// columns do not equal it. All three deletes run in one transaction, which
    /// is committed at the end.
    pub async fn clear_stale_data(&self) -> Result<()> {
        self.transact(|tx| async {
            project_collaborator::Entity::delete_many()
                .filter(project_collaborator::Column::ConnectionEpoch.ne(self.epoch))
                .exec(&tx)
                .await?;
            // A participant is removed when either its answering or its calling
            // connection epoch differs from the current one.
            // NOTE(review): if `answering_connection_epoch` can be NULL (call not
            // yet answered), `<>` against NULL is not true in SQL, so such rows
            // are removed only through the calling-epoch clause — confirm.
            room_participant::Entity::delete_many()
                .filter(
                    room_participant::Column::AnsweringConnectionEpoch
                        .ne(self.epoch)
                        .or(room_participant::Column::CallingConnectionEpoch.ne(self.epoch)),
                )
                .exec(&tx)
                .await?;
            project::Entity::delete_many()
                .filter(project::Column::HostConnectionEpoch.ne(self.epoch))
                .exec(&tx)
                .await?;
            tx.commit().await?;
            Ok(())
        })
        .await
    }
 131
 132    // users
 133
 134    pub async fn create_user(
 135        &self,
 136        email_address: &str,
 137        admin: bool,
 138        params: NewUserParams,
 139    ) -> Result<NewUserResult> {
 140        self.transact(|tx| async {
 141            let user = user::Entity::insert(user::ActiveModel {
 142                email_address: ActiveValue::set(Some(email_address.into())),
 143                github_login: ActiveValue::set(params.github_login.clone()),
 144                github_user_id: ActiveValue::set(Some(params.github_user_id)),
 145                admin: ActiveValue::set(admin),
 146                metrics_id: ActiveValue::set(Uuid::new_v4()),
 147                ..Default::default()
 148            })
 149            .on_conflict(
 150                OnConflict::column(user::Column::GithubLogin)
 151                    .update_column(user::Column::GithubLogin)
 152                    .to_owned(),
 153            )
 154            .exec_with_returning(&tx)
 155            .await?;
 156
 157            tx.commit().await?;
 158
 159            Ok(NewUserResult {
 160                user_id: user.id,
 161                metrics_id: user.metrics_id.to_string(),
 162                signup_device_id: None,
 163                inviting_user_id: None,
 164            })
 165        })
 166        .await
 167    }
 168
 169    pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<user::Model>> {
 170        self.transact(|tx| async move { Ok(user::Entity::find_by_id(id).one(&tx).await?) })
 171            .await
 172    }
 173
 174    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
 175        self.transact(|tx| async {
 176            let tx = tx;
 177            Ok(user::Entity::find()
 178                .filter(user::Column::Id.is_in(ids.iter().copied()))
 179                .all(&tx)
 180                .await?)
 181        })
 182        .await
 183    }
 184
 185    pub async fn get_user_by_github_account(
 186        &self,
 187        github_login: &str,
 188        github_user_id: Option<i32>,
 189    ) -> Result<Option<User>> {
 190        self.transact(|tx| async {
 191            let tx = tx;
 192            if let Some(github_user_id) = github_user_id {
 193                if let Some(user_by_github_user_id) = user::Entity::find()
 194                    .filter(user::Column::GithubUserId.eq(github_user_id))
 195                    .one(&tx)
 196                    .await?
 197                {
 198                    let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
 199                    user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
 200                    Ok(Some(user_by_github_user_id.update(&tx).await?))
 201                } else if let Some(user_by_github_login) = user::Entity::find()
 202                    .filter(user::Column::GithubLogin.eq(github_login))
 203                    .one(&tx)
 204                    .await?
 205                {
 206                    let mut user_by_github_login = user_by_github_login.into_active_model();
 207                    user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
 208                    Ok(Some(user_by_github_login.update(&tx).await?))
 209                } else {
 210                    Ok(None)
 211                }
 212            } else {
 213                Ok(user::Entity::find()
 214                    .filter(user::Column::GithubLogin.eq(github_login))
 215                    .one(&tx)
 216                    .await?)
 217            }
 218        })
 219        .await
 220    }
 221
 222    pub async fn get_all_users(&self, page: u32, limit: u32) -> Result<Vec<User>> {
 223        self.transact(|tx| async move {
 224            Ok(user::Entity::find()
 225                .order_by_asc(user::Column::GithubLogin)
 226                .limit(limit as u64)
 227                .offset(page as u64 * limit as u64)
 228                .all(&tx)
 229                .await?)
 230        })
 231        .await
 232    }
 233
 234    pub async fn get_users_with_no_invites(
 235        &self,
 236        invited_by_another_user: bool,
 237    ) -> Result<Vec<User>> {
 238        self.transact(|tx| async move {
 239            Ok(user::Entity::find()
 240                .filter(
 241                    user::Column::InviteCount
 242                        .eq(0)
 243                        .and(if invited_by_another_user {
 244                            user::Column::InviterId.is_not_null()
 245                        } else {
 246                            user::Column::InviterId.is_null()
 247                        }),
 248                )
 249                .all(&tx)
 250                .await?)
 251        })
 252        .await
 253    }
 254
 255    pub async fn get_user_metrics_id(&self, id: UserId) -> Result<String> {
 256        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
 257        enum QueryAs {
 258            MetricsId,
 259        }
 260
 261        self.transact(|tx| async move {
 262            let metrics_id: Uuid = user::Entity::find_by_id(id)
 263                .select_only()
 264                .column(user::Column::MetricsId)
 265                .into_values::<_, QueryAs>()
 266                .one(&tx)
 267                .await?
 268                .ok_or_else(|| anyhow!("could not find user"))?;
 269            Ok(metrics_id.to_string())
 270        })
 271        .await
 272    }
 273
 274    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
 275        self.transact(|tx| async move {
 276            user::Entity::update_many()
 277                .filter(user::Column::Id.eq(id))
 278                .set(user::ActiveModel {
 279                    admin: ActiveValue::set(is_admin),
 280                    ..Default::default()
 281                })
 282                .exec(&tx)
 283                .await?;
 284            tx.commit().await?;
 285            Ok(())
 286        })
 287        .await
 288    }
 289
 290    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
 291        self.transact(|tx| async move {
 292            user::Entity::update_many()
 293                .filter(user::Column::Id.eq(id))
 294                .set(user::ActiveModel {
 295                    connected_once: ActiveValue::set(connected_once),
 296                    ..Default::default()
 297                })
 298                .exec(&tx)
 299                .await?;
 300            tx.commit().await?;
 301            Ok(())
 302        })
 303        .await
 304    }
 305
 306    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
 307        self.transact(|tx| async move {
 308            access_token::Entity::delete_many()
 309                .filter(access_token::Column::UserId.eq(id))
 310                .exec(&tx)
 311                .await?;
 312            user::Entity::delete_by_id(id).exec(&tx).await?;
 313            tx.commit().await?;
 314            Ok(())
 315        })
 316        .await
 317    }
 318
 319    // contacts
 320
    /// Returns every contact relationship involving `user_id`, sorted by
    /// `Contact::user_id()`.
    ///
    /// Each `contacts` row stores an unordered pair (`user_id_a`, `user_id_b`),
    /// where `a_to_b` records who sent the request, plus `accepted` and
    /// `should_notify` flags. Each row is converted into a `Contact` seen from
    /// `user_id`'s side:
    /// - accepted rows become `Contact::Accepted`, including the other user's busy status;
    /// - otherwise, the request direction decides between `Outgoing` and `Incoming`.
    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
        /// One `contacts` row plus whether each side has a matching room participant row.
        #[derive(Debug, FromQueryResult)]
        struct ContactWithUserBusyStatuses {
            user_id_a: UserId,
            user_id_b: UserId,
            a_to_b: bool,
            accepted: bool,
            should_notify: bool,
            user_a_busy: bool,
            user_b_busy: bool,
        }

        self.transact(|tx| async move {
            // Both sides join `room_participant`, so each join needs its own alias.
            let user_a_participant = Alias::new("user_a_participant");
            let user_b_participant = Alias::new("user_b_participant");
            // A non-NULL participant id after the LEFT JOIN means that user is busy.
            let mut db_contacts = contact::Entity::find()
                .column_as(
                    Expr::tbl(user_a_participant.clone(), room_participant::Column::Id)
                        .is_not_null(),
                    "user_a_busy",
                )
                .column_as(
                    Expr::tbl(user_b_participant.clone(), room_participant::Column::Id)
                        .is_not_null(),
                    "user_b_busy",
                )
                .filter(
                    contact::Column::UserIdA
                        .eq(user_id)
                        .or(contact::Column::UserIdB.eq(user_id)),
                )
                .join_as(
                    JoinType::LeftJoin,
                    contact::Relation::UserARoomParticipant.def(),
                    user_a_participant,
                )
                .join_as(
                    JoinType::LeftJoin,
                    contact::Relation::UserBRoomParticipant.def(),
                    user_b_participant,
                )
                .into_model::<ContactWithUserBusyStatuses>()
                .stream(&tx)
                .await?;

            let mut contacts = Vec::new();
            while let Some(db_contact) = db_contacts.next().await {
                let db_contact = db_contact?;
                if db_contact.user_id_a == user_id {
                    // `user_id` is side A; the other user is B.
                    if db_contact.accepted {
                        // For an accepted contact, only the original requester is
                        // notified; A is the requester when `a_to_b`.
                        contacts.push(Contact::Accepted {
                            user_id: db_contact.user_id_b,
                            should_notify: db_contact.should_notify && db_contact.a_to_b,
                            busy: db_contact.user_b_busy,
                        });
                    } else if db_contact.a_to_b {
                        contacts.push(Contact::Outgoing {
                            user_id: db_contact.user_id_b,
                        })
                    } else {
                        contacts.push(Contact::Incoming {
                            user_id: db_contact.user_id_b,
                            should_notify: db_contact.should_notify,
                        });
                    }
                } else if db_contact.accepted {
                    // `user_id` is side B; B is the requester when `!a_to_b`.
                    contacts.push(Contact::Accepted {
                        user_id: db_contact.user_id_a,
                        should_notify: db_contact.should_notify && !db_contact.a_to_b,
                        busy: db_contact.user_a_busy,
                    });
                } else if db_contact.a_to_b {
                    contacts.push(Contact::Incoming {
                        user_id: db_contact.user_id_a,
                        should_notify: db_contact.should_notify,
                    });
                } else {
                    contacts.push(Contact::Outgoing {
                        user_id: db_contact.user_id_a,
                    });
                }
            }

            contacts.sort_unstable_by_key(|contact| contact.user_id());

            Ok(contacts)
        })
        .await
    }
 410
 411    pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
 412        self.transact(|tx| async move {
 413            let participant = room_participant::Entity::find()
 414                .filter(room_participant::Column::UserId.eq(user_id))
 415                .one(&tx)
 416                .await?;
 417            Ok(participant.is_some())
 418        })
 419        .await
 420    }
 421
 422    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
 423        self.transact(|tx| async move {
 424            let (id_a, id_b) = if user_id_1 < user_id_2 {
 425                (user_id_1, user_id_2)
 426            } else {
 427                (user_id_2, user_id_1)
 428            };
 429
 430            Ok(contact::Entity::find()
 431                .filter(
 432                    contact::Column::UserIdA
 433                        .eq(id_a)
 434                        .and(contact::Column::UserIdB.eq(id_b))
 435                        .and(contact::Column::Accepted.eq(true)),
 436                )
 437                .one(&tx)
 438                .await?
 439                .is_some())
 440        })
 441        .await
 442    }
 443
    /// Records a contact request from `sender_id` to `receiver_id`.
    ///
    /// The pair is stored with the smaller id in `user_id_a`. `a_to_b` is `true`
    /// when the sender is that smaller id.
    ///
    /// - If no row exists for the pair, a pending request is inserted with
    ///   `should_notify = true`.
    /// - If a row already exists, the conflict clause updates it to
    ///   `accepted = true, should_notify = false`, but only when the WHERE
    ///   predicate below holds: the existing row is not yet accepted and was sent
    ///   in the opposite direction.
    /// - Otherwise nothing is written.
    ///
    /// # Errors
    /// Returns "contact already requested" unless exactly one row was affected.
    /// In that case the transaction is not committed.
    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
        self.transact(|tx| async move {
            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
                (sender_id, receiver_id, true)
            } else {
                (receiver_id, sender_id, false)
            };

            let rows_affected = contact::Entity::insert(contact::ActiveModel {
                user_id_a: ActiveValue::set(id_a),
                user_id_b: ActiveValue::set(id_b),
                a_to_b: ActiveValue::set(a_to_b),
                accepted: ActiveValue::set(false),
                should_notify: ActiveValue::set(true),
                ..Default::default()
            })
            .on_conflict(
                OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB])
                    .values([
                        (contact::Column::Accepted, true.into()),
                        (contact::Column::ShouldNotify, false.into()),
                    ])
                    // NOTE(review): the conflicting row always has
                    // `user_id_a = id_a`, and `id_a != id_b` for distinct users,
                    // so the first disjunct (`a_to_b == a_to_b AND user_id_a ==
                    // id_b`) appears never to match. Only the opposite-direction
                    // disjunct can trigger the update.
                    .action_and_where(
                        contact::Column::Accepted.eq(false).and(
                            contact::Column::AToB
                                .eq(a_to_b)
                                .and(contact::Column::UserIdA.eq(id_b))
                                .or(contact::Column::AToB
                                    .ne(a_to_b)
                                    .and(contact::Column::UserIdA.eq(id_a))),
                        ),
                    )
                    .to_owned(),
            )
            .exec_without_returning(&tx)
            .await?;

            if rows_affected == 1 {
                tx.commit().await?;
                Ok(())
            } else {
                Err(anyhow!("contact already requested"))?
            }
        })
        .await
    }
 490
 491    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
 492        self.transact(|tx| async move {
 493            let (id_a, id_b) = if responder_id < requester_id {
 494                (responder_id, requester_id)
 495            } else {
 496                (requester_id, responder_id)
 497            };
 498
 499            let result = contact::Entity::delete_many()
 500                .filter(
 501                    contact::Column::UserIdA
 502                        .eq(id_a)
 503                        .and(contact::Column::UserIdB.eq(id_b)),
 504                )
 505                .exec(&tx)
 506                .await?;
 507
 508            if result.rows_affected == 1 {
 509                tx.commit().await?;
 510                Ok(())
 511            } else {
 512                Err(anyhow!("no such contact"))?
 513            }
 514        })
 515        .await
 516    }
 517
    /// Clears `should_notify` on the contact row between `user_id` and
    /// `contact_user_id`, but only for a row that is a notification *for
    /// `user_id`*.
    ///
    /// That means either:
    /// - a request `user_id` sent that has been accepted, or
    /// - a still-pending request sent to `user_id`.
    ///
    /// # Errors
    /// Returns "no such contact request" if no row matched. In that case the
    /// transaction is not committed.
    pub async fn dismiss_contact_notification(
        &self,
        user_id: UserId,
        contact_user_id: UserId,
    ) -> Result<()> {
        self.transact(|tx| async move {
            // `a_to_b` here means "`user_id` is side A", i.e. the direction of a
            // request sent by `user_id`.
            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
                (user_id, contact_user_id, true)
            } else {
                (contact_user_id, user_id, false)
            };

            let result = contact::Entity::update_many()
                .set(contact::ActiveModel {
                    should_notify: ActiveValue::set(false),
                    ..Default::default()
                })
                .filter(
                    contact::Column::UserIdA
                        .eq(id_a)
                        .and(contact::Column::UserIdB.eq(id_b))
                        .and(
                            // Sent by `user_id` and accepted, or sent to
                            // `user_id` and still pending.
                            contact::Column::AToB
                                .eq(a_to_b)
                                .and(contact::Column::Accepted.eq(true))
                                .or(contact::Column::AToB
                                    .ne(a_to_b)
                                    .and(contact::Column::Accepted.eq(false))),
                        ),
                )
                .exec(&tx)
                .await?;
            if result.rows_affected == 0 {
                Err(anyhow!("no such contact request"))?
            } else {
                tx.commit().await?;
                Ok(())
            }
        })
        .await
    }
 559
    /// Accepts or rejects the request that `requester_id` sent to `responder_id`.
    ///
    /// - Accepting marks the row `accepted` and sets `should_notify`, so the
    ///   requester is notified.
    /// - Rejecting deletes the row, but only while it is still pending.
    ///
    /// # Errors
    /// Returns "no such contact request" unless exactly one row was affected.
    /// In that case the transaction is not committed.
    pub async fn respond_to_contact_request(
        &self,
        responder_id: UserId,
        requester_id: UserId,
        accept: bool,
    ) -> Result<()> {
        self.transact(|tx| async move {
            // `a_to_b` is the direction requester -> responder in canonical
            // (smaller id first) order.
            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
                (responder_id, requester_id, false)
            } else {
                (requester_id, responder_id, true)
            };
            let rows_affected = if accept {
                // NOTE(review): unlike the reject branch, this does not require
                // `accepted = false`. Accepting an already accepted contact
                // succeeds again and re-sets `should_notify` — confirm intended.
                let result = contact::Entity::update_many()
                    .set(contact::ActiveModel {
                        accepted: ActiveValue::set(true),
                        should_notify: ActiveValue::set(true),
                        ..Default::default()
                    })
                    .filter(
                        contact::Column::UserIdA
                            .eq(id_a)
                            .and(contact::Column::UserIdB.eq(id_b))
                            .and(contact::Column::AToB.eq(a_to_b)),
                    )
                    .exec(&tx)
                    .await?;
                result.rows_affected
            } else {
                let result = contact::Entity::delete_many()
                    .filter(
                        contact::Column::UserIdA
                            .eq(id_a)
                            .and(contact::Column::UserIdB.eq(id_b))
                            .and(contact::Column::AToB.eq(a_to_b))
                            .and(contact::Column::Accepted.eq(false)),
                    )
                    .exec(&tx)
                    .await?;

                result.rows_affected
            };

            if rows_affected == 1 {
                tx.commit().await?;
                Ok(())
            } else {
                Err(anyhow!("no such contact request"))?
            }
        })
        .await
    }
 612
 613    pub fn fuzzy_like_string(string: &str) -> String {
 614        let mut result = String::with_capacity(string.len() * 2 + 1);
 615        for c in string.chars() {
 616            if c.is_alphanumeric() {
 617                result.push('%');
 618                result.push(c);
 619            }
 620        }
 621        result.push('%');
 622        result
 623    }
 624
 625    pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
 626        self.transact(|tx| async {
 627            let tx = tx;
 628            let like_string = Self::fuzzy_like_string(name_query);
 629            let query = "
 630                SELECT users.*
 631                FROM users
 632                WHERE github_login ILIKE $1
 633                ORDER BY github_login <-> $2
 634                LIMIT $3
 635            ";
 636
 637            Ok(user::Entity::find()
 638                .from_raw_sql(Statement::from_sql_and_values(
 639                    self.pool.get_database_backend(),
 640                    query.into(),
 641                    vec![like_string.into(), name_query.into(), limit.into()],
 642                ))
 643                .all(&tx)
 644                .await?)
 645        })
 646        .await
 647    }
 648
 649    // signups
 650
 651    pub async fn create_signup(&self, signup: NewSignup) -> Result<()> {
 652        self.transact(|tx| async {
 653            signup::ActiveModel {
 654                email_address: ActiveValue::set(signup.email_address.clone()),
 655                email_confirmation_code: ActiveValue::set(random_email_confirmation_code()),
 656                email_confirmation_sent: ActiveValue::set(false),
 657                platform_mac: ActiveValue::set(signup.platform_mac),
 658                platform_windows: ActiveValue::set(signup.platform_windows),
 659                platform_linux: ActiveValue::set(signup.platform_linux),
 660                platform_unknown: ActiveValue::set(false),
 661                editor_features: ActiveValue::set(Some(signup.editor_features.clone())),
 662                programming_languages: ActiveValue::set(Some(signup.programming_languages.clone())),
 663                device_id: ActiveValue::set(signup.device_id.clone()),
 664                ..Default::default()
 665            }
 666            .insert(&tx)
 667            .await?;
 668            tx.commit().await?;
 669            Ok(())
 670        })
 671        .await
 672    }
 673
 674    pub async fn get_waitlist_summary(&self) -> Result<WaitlistSummary> {
 675        self.transact(|tx| async move {
 676            let query = "
 677                SELECT
 678                    COUNT(*) as count,
 679                    COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count,
 680                    COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count,
 681                    COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count,
 682                    COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count
 683                FROM (
 684                    SELECT *
 685                    FROM signups
 686                    WHERE
 687                        NOT email_confirmation_sent
 688                ) AS unsent
 689            ";
 690            Ok(
 691                WaitlistSummary::find_by_statement(Statement::from_sql_and_values(
 692                    self.pool.get_database_backend(),
 693                    query.into(),
 694                    vec![],
 695                ))
 696                .one(&tx)
 697                .await?
 698                .ok_or_else(|| anyhow!("invalid result"))?,
 699            )
 700        })
 701        .await
 702    }
 703
 704    pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> {
 705        let emails = invites
 706            .iter()
 707            .map(|s| s.email_address.as_str())
 708            .collect::<Vec<_>>();
 709        self.transact(|tx| async {
 710            signup::Entity::update_many()
 711                .filter(signup::Column::EmailAddress.is_in(emails.iter().copied()))
 712                .set(signup::ActiveModel {
 713                    email_confirmation_sent: ActiveValue::set(true),
 714                    ..Default::default()
 715                })
 716                .exec(&tx)
 717                .await?;
 718            tx.commit().await?;
 719            Ok(())
 720        })
 721        .await
 722    }
 723
 724    pub async fn get_unsent_invites(&self, count: usize) -> Result<Vec<Invite>> {
 725        self.transact(|tx| async move {
 726            Ok(signup::Entity::find()
 727                .select_only()
 728                .column(signup::Column::EmailAddress)
 729                .column(signup::Column::EmailConfirmationCode)
 730                .filter(
 731                    signup::Column::EmailConfirmationSent.eq(false).and(
 732                        signup::Column::PlatformMac
 733                            .eq(true)
 734                            .or(signup::Column::PlatformUnknown.eq(true)),
 735                    ),
 736                )
 737                .limit(count as u64)
 738                .into_model()
 739                .all(&tx)
 740                .await?)
 741        })
 742        .await
 743    }
 744
 745    // invite codes
 746
 747    pub async fn create_invite_from_code(
 748        &self,
 749        code: &str,
 750        email_address: &str,
 751        device_id: Option<&str>,
 752    ) -> Result<Invite> {
 753        self.transact(|tx| async move {
 754            let existing_user = user::Entity::find()
 755                .filter(user::Column::EmailAddress.eq(email_address))
 756                .one(&tx)
 757                .await?;
 758
 759            if existing_user.is_some() {
 760                Err(anyhow!("email address is already in use"))?;
 761            }
 762
 763            let inviter = match user::Entity::find()
 764                .filter(user::Column::InviteCode.eq(code))
 765                .one(&tx)
 766                .await?
 767            {
 768                Some(inviter) => inviter,
 769                None => {
 770                    return Err(Error::Http(
 771                        StatusCode::NOT_FOUND,
 772                        "invite code not found".to_string(),
 773                    ))?
 774                }
 775            };
 776
 777            if inviter.invite_count == 0 {
 778                Err(Error::Http(
 779                    StatusCode::UNAUTHORIZED,
 780                    "no invites remaining".to_string(),
 781                ))?;
 782            }
 783
 784            let signup = signup::Entity::insert(signup::ActiveModel {
 785                email_address: ActiveValue::set(email_address.into()),
 786                email_confirmation_code: ActiveValue::set(random_email_confirmation_code()),
 787                email_confirmation_sent: ActiveValue::set(false),
 788                inviting_user_id: ActiveValue::set(Some(inviter.id)),
 789                platform_linux: ActiveValue::set(false),
 790                platform_mac: ActiveValue::set(false),
 791                platform_windows: ActiveValue::set(false),
 792                platform_unknown: ActiveValue::set(true),
 793                device_id: ActiveValue::set(device_id.map(|device_id| device_id.into())),
 794                ..Default::default()
 795            })
 796            .on_conflict(
 797                OnConflict::column(signup::Column::EmailAddress)
 798                    .update_column(signup::Column::InvitingUserId)
 799                    .to_owned(),
 800            )
 801            .exec_with_returning(&tx)
 802            .await?;
 803            tx.commit().await?;
 804
 805            Ok(Invite {
 806                email_address: signup.email_address,
 807                email_confirmation_code: signup.email_confirmation_code,
 808            })
 809        })
 810        .await
 811    }
 812
    /// Creates a user from a previously issued invite, upserting on GitHub login.
    ///
    /// Looks up the signup row matching both the invite's email address and its
    /// confirmation code, inserts the user, links the signup to that user, and — if the
    /// signup carries an `inviting_user_id` — decrements the inviter's invite count and
    /// records an accepted contact between inviter and invitee.
    ///
    /// Returns `Ok(None)` when the signup has already been redeemed (its `user_id` is set).
    ///
    /// # Errors
    ///
    /// - `Error::Http(NOT_FOUND)` if no signup matches the invite.
    /// - `Error::Http(UNAUTHORIZED)` if the inviter has no invites left; this is returned
    ///   before `tx.commit()` is reached, so none of the writes above are committed.
    pub async fn create_user_from_invite(
        &self,
        invite: &Invite,
        user: NewUserParams,
    ) -> Result<Option<NewUserResult>> {
        self.transact(|tx| async {
            // Move the transaction into this non-`move` async block; `user` stays borrowed.
            let tx = tx;
            let signup = signup::Entity::find()
                .filter(
                    signup::Column::EmailAddress
                        .eq(invite.email_address.as_str())
                        .and(
                            signup::Column::EmailConfirmationCode
                                .eq(invite.email_confirmation_code.as_str()),
                        ),
                )
                .one(&tx)
                .await?
                .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?;

            // Already redeemed: nothing to do.
            if signup.user_id.is_some() {
                return Ok(None);
            }

            // On a GitHub-login conflict only email, GitHub id and admin are overwritten;
            // the existing invite count, invite code and metrics id are kept, and the
            // returned row reflects them.
            // NOTE(review): `Admin` is in the update list and the inserted value is
            // `false`, so an existing admin redeeming an invite is demoted — confirm intent.
            let user = user::Entity::insert(user::ActiveModel {
                email_address: ActiveValue::set(Some(invite.email_address.clone())),
                github_login: ActiveValue::set(user.github_login.clone()),
                github_user_id: ActiveValue::set(Some(user.github_user_id)),
                admin: ActiveValue::set(false),
                invite_count: ActiveValue::set(user.invite_count),
                invite_code: ActiveValue::set(Some(random_invite_code())),
                metrics_id: ActiveValue::set(Uuid::new_v4()),
                ..Default::default()
            })
            .on_conflict(
                OnConflict::column(user::Column::GithubLogin)
                    .update_columns([
                        user::Column::EmailAddress,
                        user::Column::GithubUserId,
                        user::Column::Admin,
                    ])
                    .to_owned(),
            )
            .exec_with_returning(&tx)
            .await?;

            // Mark the signup as redeemed by this user.
            let mut signup = signup.into_active_model();
            signup.user_id = ActiveValue::set(Some(user.id));
            let signup = signup.update(&tx).await?;

            if let Some(inviting_user_id) = signup.inviting_user_id {
                // Conditional decrement: the `> 0` guard keeps the count from going
                // negative; zero affected rows means the inviter is out of invites.
                let result = user::Entity::update_many()
                    .filter(
                        user::Column::Id
                            .eq(inviting_user_id)
                            .and(user::Column::InviteCount.gt(0)),
                    )
                    .col_expr(
                        user::Column::InviteCount,
                        Expr::col(user::Column::InviteCount).sub(1),
                    )
                    .exec(&tx)
                    .await?;

                if result.rows_affected == 0 {
                    Err(Error::Http(
                        StatusCode::UNAUTHORIZED,
                        "no invites remaining".to_string(),
                    ))?;
                }

                // Make inviter and invitee contacts; any conflict is ignored.
                contact::Entity::insert(contact::ActiveModel {
                    user_id_a: ActiveValue::set(inviting_user_id),
                    user_id_b: ActiveValue::set(user.id),
                    a_to_b: ActiveValue::set(true),
                    should_notify: ActiveValue::set(true),
                    accepted: ActiveValue::set(true),
                    ..Default::default()
                })
                .on_conflict(OnConflict::new().do_nothing().to_owned())
                .exec_without_returning(&tx)
                .await?;
            }

            tx.commit().await?;
            Ok(Some(NewUserResult {
                user_id: user.id,
                metrics_id: user.metrics_id.to_string(),
                inviting_user_id: signup.inviting_user_id,
                signup_device_id: signup.device_id,
            }))
        })
        .await
    }
 907
 908    pub async fn set_invite_count_for_user(&self, id: UserId, count: i32) -> Result<()> {
 909        self.transact(|tx| async move {
 910            if count > 0 {
 911                user::Entity::update_many()
 912                    .filter(
 913                        user::Column::Id
 914                            .eq(id)
 915                            .and(user::Column::InviteCode.is_null()),
 916                    )
 917                    .set(user::ActiveModel {
 918                        invite_code: ActiveValue::set(Some(random_invite_code())),
 919                        ..Default::default()
 920                    })
 921                    .exec(&tx)
 922                    .await?;
 923            }
 924
 925            user::Entity::update_many()
 926                .filter(user::Column::Id.eq(id))
 927                .set(user::ActiveModel {
 928                    invite_count: ActiveValue::set(count),
 929                    ..Default::default()
 930                })
 931                .exec(&tx)
 932                .await?;
 933            tx.commit().await?;
 934            Ok(())
 935        })
 936        .await
 937    }
 938
 939    pub async fn get_invite_code_for_user(&self, id: UserId) -> Result<Option<(String, i32)>> {
 940        self.transact(|tx| async move {
 941            match user::Entity::find_by_id(id).one(&tx).await? {
 942                Some(user) if user.invite_code.is_some() => {
 943                    Ok(Some((user.invite_code.unwrap(), user.invite_count)))
 944                }
 945                _ => Ok(None),
 946            }
 947        })
 948        .await
 949    }
 950
 951    pub async fn get_user_for_invite_code(&self, code: &str) -> Result<User> {
 952        self.transact(|tx| async move {
 953            user::Entity::find()
 954                .filter(user::Column::InviteCode.eq(code))
 955                .one(&tx)
 956                .await?
 957                .ok_or_else(|| {
 958                    Error::Http(
 959                        StatusCode::NOT_FOUND,
 960                        "that invite code does not exist".to_string(),
 961                    )
 962                })
 963        })
 964        .await
 965    }
 966
 967    // rooms
 968
 969    pub async fn incoming_call_for_user(
 970        &self,
 971        user_id: UserId,
 972    ) -> Result<Option<proto::IncomingCall>> {
 973        self.transact(|tx| async move {
 974            let pending_participant = room_participant::Entity::find()
 975                .filter(
 976                    room_participant::Column::UserId
 977                        .eq(user_id)
 978                        .and(room_participant::Column::AnsweringConnectionId.is_null()),
 979                )
 980                .one(&tx)
 981                .await?;
 982
 983            if let Some(pending_participant) = pending_participant {
 984                let room = self.get_room(pending_participant.room_id, &tx).await?;
 985                Ok(Self::build_incoming_call(&room, user_id))
 986            } else {
 987                Ok(None)
 988            }
 989        })
 990        .await
 991    }
 992
 993    pub async fn create_room(
 994        &self,
 995        user_id: UserId,
 996        connection_id: ConnectionId,
 997        live_kit_room: &str,
 998    ) -> Result<RoomGuard<proto::Room>> {
 999        self.transact(|tx| async move {
1000            let room = room::ActiveModel {
1001                live_kit_room: ActiveValue::set(live_kit_room.into()),
1002                ..Default::default()
1003            }
1004            .insert(&tx)
1005            .await?;
1006            let room_id = room.id;
1007
1008            room_participant::ActiveModel {
1009                room_id: ActiveValue::set(room_id),
1010                user_id: ActiveValue::set(user_id),
1011                answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)),
1012                answering_connection_epoch: ActiveValue::set(Some(self.epoch)),
1013                calling_user_id: ActiveValue::set(user_id),
1014                calling_connection_id: ActiveValue::set(connection_id.0 as i32),
1015                calling_connection_epoch: ActiveValue::set(self.epoch),
1016                ..Default::default()
1017            }
1018            .insert(&tx)
1019            .await?;
1020
1021            let room = self.get_room(room_id, &tx).await?;
1022            self.commit_room_transaction(room_id, tx, room).await
1023        })
1024        .await
1025    }
1026
1027    pub async fn call(
1028        &self,
1029        room_id: RoomId,
1030        calling_user_id: UserId,
1031        calling_connection_id: ConnectionId,
1032        called_user_id: UserId,
1033        initial_project_id: Option<ProjectId>,
1034    ) -> Result<RoomGuard<(proto::Room, proto::IncomingCall)>> {
1035        self.transact(|tx| async move {
1036            room_participant::ActiveModel {
1037                room_id: ActiveValue::set(room_id),
1038                user_id: ActiveValue::set(called_user_id),
1039                calling_user_id: ActiveValue::set(calling_user_id),
1040                calling_connection_id: ActiveValue::set(calling_connection_id.0 as i32),
1041                calling_connection_epoch: ActiveValue::set(self.epoch),
1042                initial_project_id: ActiveValue::set(initial_project_id),
1043                ..Default::default()
1044            }
1045            .insert(&tx)
1046            .await?;
1047
1048            let room = self.get_room(room_id, &tx).await?;
1049            let incoming_call = Self::build_incoming_call(&room, called_user_id)
1050                .ok_or_else(|| anyhow!("failed to build incoming call"))?;
1051            self.commit_room_transaction(room_id, tx, (room, incoming_call))
1052                .await
1053        })
1054        .await
1055    }
1056
1057    pub async fn call_failed(
1058        &self,
1059        room_id: RoomId,
1060        called_user_id: UserId,
1061    ) -> Result<RoomGuard<proto::Room>> {
1062        self.transact(|tx| async move {
1063            room_participant::Entity::delete_many()
1064                .filter(
1065                    room_participant::Column::RoomId
1066                        .eq(room_id)
1067                        .and(room_participant::Column::UserId.eq(called_user_id)),
1068                )
1069                .exec(&tx)
1070                .await?;
1071            let room = self.get_room(room_id, &tx).await?;
1072            self.commit_room_transaction(room_id, tx, room).await
1073        })
1074        .await
1075    }
1076
1077    pub async fn decline_call(
1078        &self,
1079        expected_room_id: Option<RoomId>,
1080        user_id: UserId,
1081    ) -> Result<RoomGuard<proto::Room>> {
1082        self.transact(|tx| async move {
1083            let participant = room_participant::Entity::find()
1084                .filter(
1085                    room_participant::Column::UserId
1086                        .eq(user_id)
1087                        .and(room_participant::Column::AnsweringConnectionId.is_null()),
1088                )
1089                .one(&tx)
1090                .await?
1091                .ok_or_else(|| anyhow!("could not decline call"))?;
1092            let room_id = participant.room_id;
1093
1094            if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1095                return Err(anyhow!("declining call on unexpected room"))?;
1096            }
1097
1098            room_participant::Entity::delete(participant.into_active_model())
1099                .exec(&tx)
1100                .await?;
1101
1102            let room = self.get_room(room_id, &tx).await?;
1103            self.commit_room_transaction(room_id, tx, room).await
1104        })
1105        .await
1106    }
1107
1108    pub async fn cancel_call(
1109        &self,
1110        expected_room_id: Option<RoomId>,
1111        calling_connection_id: ConnectionId,
1112        called_user_id: UserId,
1113    ) -> Result<RoomGuard<proto::Room>> {
1114        self.transact(|tx| async move {
1115            let participant = room_participant::Entity::find()
1116                .filter(
1117                    room_participant::Column::UserId
1118                        .eq(called_user_id)
1119                        .and(
1120                            room_participant::Column::CallingConnectionId
1121                                .eq(calling_connection_id.0 as i32),
1122                        )
1123                        .and(room_participant::Column::AnsweringConnectionId.is_null()),
1124                )
1125                .one(&tx)
1126                .await?
1127                .ok_or_else(|| anyhow!("could not cancel call"))?;
1128            let room_id = participant.room_id;
1129            if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) {
1130                return Err(anyhow!("canceling call on unexpected room"))?;
1131            }
1132
1133            room_participant::Entity::delete(participant.into_active_model())
1134                .exec(&tx)
1135                .await?;
1136
1137            let room = self.get_room(room_id, &tx).await?;
1138            self.commit_room_transaction(room_id, tx, room).await
1139        })
1140        .await
1141    }
1142
1143    pub async fn join_room(
1144        &self,
1145        room_id: RoomId,
1146        user_id: UserId,
1147        connection_id: ConnectionId,
1148    ) -> Result<RoomGuard<proto::Room>> {
1149        self.transact(|tx| async move {
1150            let result = room_participant::Entity::update_many()
1151                .filter(
1152                    room_participant::Column::RoomId
1153                        .eq(room_id)
1154                        .and(room_participant::Column::UserId.eq(user_id))
1155                        .and(room_participant::Column::AnsweringConnectionId.is_null()),
1156                )
1157                .set(room_participant::ActiveModel {
1158                    answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)),
1159                    answering_connection_epoch: ActiveValue::set(Some(self.epoch)),
1160                    ..Default::default()
1161                })
1162                .exec(&tx)
1163                .await?;
1164            if result.rows_affected == 0 {
1165                Err(anyhow!("room does not exist or was already joined"))?
1166            } else {
1167                let room = self.get_room(room_id, &tx).await?;
1168                self.commit_room_transaction(room_id, tx, room).await
1169            }
1170        })
1171        .await
1172    }
1173
    /// Removes the participant answering on `connection_id` from its room.
    ///
    /// In one transaction this:
    /// - deletes the leaving participant's row,
    /// - cancels (deletes) any still-unanswered calls placed from `connection_id`,
    /// - gathers the other collaborators and the host of every project the connection
    ///   was collaborating on,
    /// - removes the connection from all project collaborator lists, and
    /// - deletes the projects it hosted in this room.
    ///
    /// Returns `Ok(None)` when no participant is answering on `connection_id`. Otherwise
    /// it returns the updated room, the projects that were left, and the user ids whose
    /// pending calls were canceled.
    pub async fn leave_room(
        &self,
        connection_id: ConnectionId,
    ) -> Result<Option<RoomGuard<LeftRoom>>> {
        self.transact(|tx| async move {
            // NOTE(review): connection ids are compared as `u32` here, while inserts in
            // this file store `connection_id.0 as i32`; confirm both encodings agree.
            let leaving_participant = room_participant::Entity::find()
                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
                .one(&tx)
                .await?;

            if let Some(leaving_participant) = leaving_participant {
                // Leave room.
                let room_id = leaving_participant.room_id;
                room_participant::Entity::delete_by_id(leaving_participant.id)
                    .exec(&tx)
                    .await?;

                // Cancel pending calls initiated by the leaving user. These are matched by
                // calling connection only, so this is not restricted to `room_id`.
                let called_participants = room_participant::Entity::find()
                    .filter(
                        room_participant::Column::CallingConnectionId
                            .eq(connection_id.0)
                            .and(room_participant::Column::AnsweringConnectionId.is_null()),
                    )
                    .all(&tx)
                    .await?;
                room_participant::Entity::delete_many()
                    .filter(
                        room_participant::Column::Id
                            .is_in(called_participants.iter().map(|participant| participant.id)),
                    )
                    .exec(&tx)
                    .await?;
                // Returned to the caller via `LeftRoom`.
                let canceled_calls_to_user_ids = called_participants
                    .into_iter()
                    .map(|participant| participant.user_id)
                    .collect();

                // Detect left projects.
                #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
                enum QueryProjectIds {
                    ProjectId,
                }
                let project_ids: Vec<ProjectId> = project_collaborator::Entity::find()
                    .select_only()
                    .column_as(
                        project_collaborator::Column::ProjectId,
                        QueryProjectIds::ProjectId,
                    )
                    .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0))
                    .into_values::<_, QueryProjectIds>()
                    .all(&tx)
                    .await?;
                // Visit every collaborator of those projects. The leaving connection is
                // excluded from `connection_ids`, and the host row supplies the host ids.
                let mut left_projects = HashMap::default();
                let mut collaborators = project_collaborator::Entity::find()
                    .filter(project_collaborator::Column::ProjectId.is_in(project_ids))
                    .stream(&tx)
                    .await?;
                while let Some(collaborator) = collaborators.next().await {
                    let collaborator = collaborator?;
                    let left_project =
                        left_projects
                            .entry(collaborator.project_id)
                            .or_insert(LeftProject {
                                id: collaborator.project_id,
                                host_user_id: Default::default(),
                                connection_ids: Default::default(),
                                host_connection_id: Default::default(),
                            });

                    let collaborator_connection_id =
                        ConnectionId(collaborator.connection_id as u32);
                    if collaborator_connection_id != connection_id {
                        left_project.connection_ids.push(collaborator_connection_id);
                    }

                    if collaborator.is_host {
                        left_project.host_user_id = collaborator.user_id;
                        left_project.host_connection_id =
                            ConnectionId(collaborator.connection_id as u32);
                    }
                }
                // End the stream (which borrows `tx`) explicitly before issuing more
                // statements on the same transaction.
                // NOTE(review): presumably the stream holds the transaction's connection
                // while alive.
                drop(collaborators);

                // Leave projects.
                project_collaborator::Entity::delete_many()
                    .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0))
                    .exec(&tx)
                    .await?;

                // Unshare projects: delete those this connection hosted in this room.
                project::Entity::delete_many()
                    .filter(
                        project::Column::RoomId
                            .eq(room_id)
                            .and(project::Column::HostConnectionId.eq(connection_id.0)),
                    )
                    .exec(&tx)
                    .await?;

                let room = self.get_room(room_id, &tx).await?;
                Ok(Some(
                    self.commit_room_transaction(
                        room_id,
                        tx,
                        LeftRoom {
                            room,
                            left_projects,
                            canceled_calls_to_user_ids,
                        },
                    )
                    .await?,
                ))
            } else {
                Ok(None)
            }
        })
        .await
    }
1293
1294    pub async fn update_room_participant_location(
1295        &self,
1296        room_id: RoomId,
1297        connection_id: ConnectionId,
1298        location: proto::ParticipantLocation,
1299    ) -> Result<RoomGuard<proto::Room>> {
1300        self.transact(|tx| async {
1301            let mut tx = tx;
1302            let location_kind;
1303            let location_project_id;
1304            match location
1305                .variant
1306                .as_ref()
1307                .ok_or_else(|| anyhow!("invalid location"))?
1308            {
1309                proto::participant_location::Variant::SharedProject(project) => {
1310                    location_kind = 0;
1311                    location_project_id = Some(ProjectId::from_proto(project.id));
1312                }
1313                proto::participant_location::Variant::UnsharedProject(_) => {
1314                    location_kind = 1;
1315                    location_project_id = None;
1316                }
1317                proto::participant_location::Variant::External(_) => {
1318                    location_kind = 2;
1319                    location_project_id = None;
1320                }
1321            }
1322
1323            let result = room_participant::Entity::update_many()
1324                .filter(
1325                    room_participant::Column::RoomId
1326                        .eq(room_id)
1327                        .and(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)),
1328                )
1329                .set(room_participant::ActiveModel {
1330                    location_kind: ActiveValue::set(Some(location_kind)),
1331                    location_project_id: ActiveValue::set(location_project_id),
1332                    ..Default::default()
1333                })
1334                .exec(&tx)
1335                .await?;
1336
1337            if result.rows_affected == 1 {
1338                let room = self.get_room(room_id, &mut tx).await?;
1339                self.commit_room_transaction(room_id, tx, room).await
1340            } else {
1341                Err(anyhow!("could not update room participant location"))?
1342            }
1343        })
1344        .await
1345    }
1346
1347    fn build_incoming_call(
1348        room: &proto::Room,
1349        called_user_id: UserId,
1350    ) -> Option<proto::IncomingCall> {
1351        let pending_participant = room
1352            .pending_participants
1353            .iter()
1354            .find(|participant| participant.user_id == called_user_id.to_proto())?;
1355
1356        Some(proto::IncomingCall {
1357            room_id: room.id,
1358            calling_user_id: pending_participant.calling_user_id,
1359            participant_user_ids: room
1360                .participants
1361                .iter()
1362                .map(|participant| participant.user_id)
1363                .collect(),
1364            initial_project: room.participants.iter().find_map(|participant| {
1365                let initial_project_id = pending_participant.initial_project_id?;
1366                participant
1367                    .projects
1368                    .iter()
1369                    .find(|project| project.id == initial_project_id)
1370                    .cloned()
1371            }),
1372        })
1373    }
1374
    /// Loads room `room_id` within `tx` and assembles its `proto::Room` representation.
    ///
    /// Participant rows with an answering connection become `participants`. They are
    /// keyed by that connection id, which is also reported as `peer_id`. The remaining
    /// rows become `pending_participants`. Each project in the room is attached to the
    /// joined participant whose answering connection hosts it, together with its worktree
    /// root names. Projects whose host connection is not a joined participant are skipped.
    ///
    /// # Errors
    ///
    /// Fails if the room does not exist or any query fails.
    async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
        let db_room = room::Entity::find_by_id(room_id)
            .one(tx)
            .await?
            .ok_or_else(|| anyhow!("could not find room"))?;

        let mut db_participants = db_room
            .find_related(room_participant::Entity)
            .stream(tx)
            .await?;
        let mut participants = HashMap::default();
        let mut pending_participants = Vec::new();
        while let Some(db_participant) = db_participants.next().await {
            let db_participant = db_participant?;
            if let Some(answering_connection_id) = db_participant.answering_connection_id {
                // Decode the stored location: kind 0 with a project id is a shared project,
                // kind 1 is an unshared project, and anything else (including kind 0
                // without a project id) is reported as External.
                let location = match (
                    db_participant.location_kind,
                    db_participant.location_project_id,
                ) {
                    (Some(0), Some(project_id)) => {
                        Some(proto::participant_location::Variant::SharedProject(
                            proto::participant_location::SharedProject {
                                id: project_id.to_proto(),
                            },
                        ))
                    }
                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
                        Default::default(),
                    )),
                    _ => Some(proto::participant_location::Variant::External(
                        Default::default(),
                    )),
                };
                participants.insert(
                    answering_connection_id,
                    proto::Participant {
                        user_id: db_participant.user_id.to_proto(),
                        peer_id: answering_connection_id as u32,
                        projects: Default::default(),
                        location: Some(proto::ParticipantLocation { variant: location }),
                    },
                );
            } else {
                // No answering connection yet: the participant is still being called.
                pending_participants.push(proto::PendingParticipant {
                    user_id: db_participant.user_id.to_proto(),
                    calling_user_id: db_participant.calling_user_id.to_proto(),
                    initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()),
                });
            }
        }
        // End the participant stream before opening the next one on the same `tx`.
        drop(db_participants);

        let mut db_projects = db_room
            .find_related(project::Entity)
            .find_with_related(worktree::Entity)
            .stream(tx)
            .await?;

        // Rows may repeat a project, once per worktree, or carry no worktree at all.
        // Reuse the participant's existing entry for the project when present.
        while let Some(row) = db_projects.next().await {
            let (db_project, db_worktree) = row?;
            if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
                let project = if let Some(project) = participant
                    .projects
                    .iter_mut()
                    .find(|project| project.id == db_project.id.to_proto())
                {
                    project
                } else {
                    participant.projects.push(proto::ParticipantProject {
                        id: db_project.id.to_proto(),
                        worktree_root_names: Default::default(),
                    });
                    participant.projects.last_mut().unwrap()
                };

                if let Some(db_worktree) = db_worktree {
                    project.worktree_root_names.push(db_worktree.root_name);
                }
            }
        }

        // NOTE(review): participants come out in map iteration order, so their order is
        // presumably not stable across calls.
        Ok(proto::Room {
            id: db_room.id.to_proto(),
            live_kit_room: db_room.live_kit_room,
            participants: participants.into_values().collect(),
            pending_participants,
        })
    }
1463
1464    async fn commit_room_transaction<T>(
1465        &self,
1466        room_id: RoomId,
1467        tx: DatabaseTransaction,
1468        data: T,
1469    ) -> Result<RoomGuard<T>> {
1470        let lock = self.rooms.entry(room_id).or_default().clone();
1471        let _guard = lock.lock_owned().await;
1472        tx.commit().await?;
1473        Ok(RoomGuard {
1474            data,
1475            _guard,
1476            _not_send: PhantomData,
1477        })
1478    }
1479
1480    // projects
1481
1482    pub async fn project_count_excluding_admins(&self) -> Result<usize> {
1483        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
1484        enum QueryAs {
1485            Count,
1486        }
1487
1488        self.transact(|tx| async move {
1489            Ok(project::Entity::find()
1490                .select_only()
1491                .column_as(project::Column::Id.count(), QueryAs::Count)
1492                .inner_join(user::Entity)
1493                .filter(user::Column::Admin.eq(false))
1494                .into_values::<_, QueryAs>()
1495                .one(&tx)
1496                .await?
1497                .unwrap_or(0) as usize)
1498        })
1499        .await
1500    }
1501
1502    pub async fn share_project(
1503        &self,
1504        room_id: RoomId,
1505        connection_id: ConnectionId,
1506        worktrees: &[proto::WorktreeMetadata],
1507    ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
1508        self.transact(|tx| async move {
1509            let participant = room_participant::Entity::find()
1510                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
1511                .one(&tx)
1512                .await?
1513                .ok_or_else(|| anyhow!("could not find participant"))?;
1514            if participant.room_id != room_id {
1515                return Err(anyhow!("shared project on unexpected room"))?;
1516            }
1517
1518            let project = project::ActiveModel {
1519                room_id: ActiveValue::set(participant.room_id),
1520                host_user_id: ActiveValue::set(participant.user_id),
1521                host_connection_id: ActiveValue::set(connection_id.0 as i32),
1522                host_connection_epoch: ActiveValue::set(self.epoch),
1523                ..Default::default()
1524            }
1525            .insert(&tx)
1526            .await?;
1527
1528            if !worktrees.is_empty() {
1529                worktree::Entity::insert_many(worktrees.iter().map(|worktree| {
1530                    worktree::ActiveModel {
1531                        id: ActiveValue::set(worktree.id as i64),
1532                        project_id: ActiveValue::set(project.id),
1533                        abs_path: ActiveValue::set(worktree.abs_path.clone()),
1534                        root_name: ActiveValue::set(worktree.root_name.clone()),
1535                        visible: ActiveValue::set(worktree.visible),
1536                        scan_id: ActiveValue::set(0),
1537                        is_complete: ActiveValue::set(false),
1538                    }
1539                }))
1540                .exec(&tx)
1541                .await?;
1542            }
1543
1544            project_collaborator::ActiveModel {
1545                project_id: ActiveValue::set(project.id),
1546                connection_id: ActiveValue::set(connection_id.0 as i32),
1547                connection_epoch: ActiveValue::set(self.epoch),
1548                user_id: ActiveValue::set(participant.user_id),
1549                replica_id: ActiveValue::set(ReplicaId(0)),
1550                is_host: ActiveValue::set(true),
1551                ..Default::default()
1552            }
1553            .insert(&tx)
1554            .await?;
1555
1556            let room = self.get_room(room_id, &tx).await?;
1557            self.commit_room_transaction(room_id, tx, (project.id, room))
1558                .await
1559        })
1560        .await
1561    }
1562
1563    pub async fn unshare_project(
1564        &self,
1565        project_id: ProjectId,
1566        connection_id: ConnectionId,
1567    ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
1568        self.transact(|tx| async move {
1569            let guest_connection_ids = self.project_guest_connection_ids(project_id, &tx).await?;
1570
1571            let project = project::Entity::find_by_id(project_id)
1572                .one(&tx)
1573                .await?
1574                .ok_or_else(|| anyhow!("project not found"))?;
1575            if project.host_connection_id == connection_id.0 as i32 {
1576                let room_id = project.room_id;
1577                project::Entity::delete(project.into_active_model())
1578                    .exec(&tx)
1579                    .await?;
1580                let room = self.get_room(room_id, &tx).await?;
1581                self.commit_room_transaction(room_id, tx, (room, guest_connection_ids))
1582                    .await
1583            } else {
1584                Err(anyhow!("cannot unshare a project hosted by another user"))?
1585            }
1586        })
1587        .await
1588    }
1589
1590    pub async fn update_project(
1591        &self,
1592        project_id: ProjectId,
1593        connection_id: ConnectionId,
1594        worktrees: &[proto::WorktreeMetadata],
1595    ) -> Result<RoomGuard<(proto::Room, Vec<ConnectionId>)>> {
1596        self.transact(|tx| async move {
1597            let project = project::Entity::find_by_id(project_id)
1598                .filter(project::Column::HostConnectionId.eq(connection_id.0))
1599                .one(&tx)
1600                .await?
1601                .ok_or_else(|| anyhow!("no such project"))?;
1602
1603            if !worktrees.is_empty() {
1604                worktree::Entity::insert_many(worktrees.iter().map(|worktree| {
1605                    worktree::ActiveModel {
1606                        id: ActiveValue::set(worktree.id as i64),
1607                        project_id: ActiveValue::set(project.id),
1608                        abs_path: ActiveValue::set(worktree.abs_path.clone()),
1609                        root_name: ActiveValue::set(worktree.root_name.clone()),
1610                        visible: ActiveValue::set(worktree.visible),
1611                        scan_id: ActiveValue::set(0),
1612                        is_complete: ActiveValue::set(false),
1613                    }
1614                }))
1615                .on_conflict(
1616                    OnConflict::columns([worktree::Column::ProjectId, worktree::Column::Id])
1617                        .update_column(worktree::Column::RootName)
1618                        .to_owned(),
1619                )
1620                .exec(&tx)
1621                .await?;
1622            }
1623
1624            worktree::Entity::delete_many()
1625                .filter(
1626                    worktree::Column::ProjectId.eq(project.id).and(
1627                        worktree::Column::Id
1628                            .is_not_in(worktrees.iter().map(|worktree| worktree.id as i64)),
1629                    ),
1630                )
1631                .exec(&tx)
1632                .await?;
1633
1634            let guest_connection_ids = self.project_guest_connection_ids(project.id, &tx).await?;
1635            let room = self.get_room(project.room_id, &tx).await?;
1636            self.commit_room_transaction(project.room_id, tx, (room, guest_connection_ids))
1637                .await
1638        })
1639        .await
1640    }
1641
1642    pub async fn update_worktree(
1643        &self,
1644        update: &proto::UpdateWorktree,
1645        connection_id: ConnectionId,
1646    ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1647        self.transact(|tx| async move {
1648            let project_id = ProjectId::from_proto(update.project_id);
1649            let worktree_id = update.worktree_id as i64;
1650
1651            // Ensure the update comes from the host.
1652            let project = project::Entity::find_by_id(project_id)
1653                .filter(project::Column::HostConnectionId.eq(connection_id.0))
1654                .one(&tx)
1655                .await?
1656                .ok_or_else(|| anyhow!("no such project"))?;
1657            let room_id = project.room_id;
1658
1659            // Update metadata.
1660            worktree::Entity::update(worktree::ActiveModel {
1661                id: ActiveValue::set(worktree_id),
1662                project_id: ActiveValue::set(project_id),
1663                root_name: ActiveValue::set(update.root_name.clone()),
1664                scan_id: ActiveValue::set(update.scan_id as i64),
1665                is_complete: ActiveValue::set(update.is_last_update),
1666                abs_path: ActiveValue::set(update.abs_path.clone()),
1667                ..Default::default()
1668            })
1669            .exec(&tx)
1670            .await?;
1671
1672            if !update.updated_entries.is_empty() {
1673                worktree_entry::Entity::insert_many(update.updated_entries.iter().map(|entry| {
1674                    let mtime = entry.mtime.clone().unwrap_or_default();
1675                    worktree_entry::ActiveModel {
1676                        project_id: ActiveValue::set(project_id),
1677                        worktree_id: ActiveValue::set(worktree_id),
1678                        id: ActiveValue::set(entry.id as i64),
1679                        is_dir: ActiveValue::set(entry.is_dir),
1680                        path: ActiveValue::set(entry.path.clone()),
1681                        inode: ActiveValue::set(entry.inode as i64),
1682                        mtime_seconds: ActiveValue::set(mtime.seconds as i64),
1683                        mtime_nanos: ActiveValue::set(mtime.nanos as i32),
1684                        is_symlink: ActiveValue::set(entry.is_symlink),
1685                        is_ignored: ActiveValue::set(entry.is_ignored),
1686                    }
1687                }))
1688                .on_conflict(
1689                    OnConflict::columns([
1690                        worktree_entry::Column::ProjectId,
1691                        worktree_entry::Column::WorktreeId,
1692                        worktree_entry::Column::Id,
1693                    ])
1694                    .update_columns([
1695                        worktree_entry::Column::IsDir,
1696                        worktree_entry::Column::Path,
1697                        worktree_entry::Column::Inode,
1698                        worktree_entry::Column::MtimeSeconds,
1699                        worktree_entry::Column::MtimeNanos,
1700                        worktree_entry::Column::IsSymlink,
1701                        worktree_entry::Column::IsIgnored,
1702                    ])
1703                    .to_owned(),
1704                )
1705                .exec(&tx)
1706                .await?;
1707            }
1708
1709            if !update.removed_entries.is_empty() {
1710                worktree_entry::Entity::delete_many()
1711                    .filter(
1712                        worktree_entry::Column::ProjectId
1713                            .eq(project_id)
1714                            .and(worktree_entry::Column::WorktreeId.eq(worktree_id))
1715                            .and(
1716                                worktree_entry::Column::Id
1717                                    .is_in(update.removed_entries.iter().map(|id| *id as i64)),
1718                            ),
1719                    )
1720                    .exec(&tx)
1721                    .await?;
1722            }
1723
1724            let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?;
1725            self.commit_room_transaction(room_id, tx, connection_ids)
1726                .await
1727        })
1728        .await
1729    }
1730
1731    pub async fn update_diagnostic_summary(
1732        &self,
1733        update: &proto::UpdateDiagnosticSummary,
1734        connection_id: ConnectionId,
1735    ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1736        self.transact(|tx| async {
1737            let project_id = ProjectId::from_proto(update.project_id);
1738            let worktree_id = update.worktree_id as i64;
1739            let summary = update
1740                .summary
1741                .as_ref()
1742                .ok_or_else(|| anyhow!("invalid summary"))?;
1743
1744            // Ensure the update comes from the host.
1745            let project = project::Entity::find_by_id(project_id)
1746                .one(&tx)
1747                .await?
1748                .ok_or_else(|| anyhow!("no such project"))?;
1749            if project.host_connection_id != connection_id.0 as i32 {
1750                return Err(anyhow!("can't update a project hosted by someone else"))?;
1751            }
1752
1753            // Update summary.
1754            worktree_diagnostic_summary::Entity::insert(worktree_diagnostic_summary::ActiveModel {
1755                project_id: ActiveValue::set(project_id),
1756                worktree_id: ActiveValue::set(worktree_id),
1757                path: ActiveValue::set(summary.path.clone()),
1758                language_server_id: ActiveValue::set(summary.language_server_id as i64),
1759                error_count: ActiveValue::set(summary.error_count as i32),
1760                warning_count: ActiveValue::set(summary.warning_count as i32),
1761                ..Default::default()
1762            })
1763            .on_conflict(
1764                OnConflict::columns([
1765                    worktree_diagnostic_summary::Column::ProjectId,
1766                    worktree_diagnostic_summary::Column::WorktreeId,
1767                    worktree_diagnostic_summary::Column::Path,
1768                ])
1769                .update_columns([
1770                    worktree_diagnostic_summary::Column::LanguageServerId,
1771                    worktree_diagnostic_summary::Column::ErrorCount,
1772                    worktree_diagnostic_summary::Column::WarningCount,
1773                ])
1774                .to_owned(),
1775            )
1776            .exec(&tx)
1777            .await?;
1778
1779            let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?;
1780            self.commit_room_transaction(project.room_id, tx, connection_ids)
1781                .await
1782        })
1783        .await
1784    }
1785
1786    pub async fn start_language_server(
1787        &self,
1788        update: &proto::StartLanguageServer,
1789        connection_id: ConnectionId,
1790    ) -> Result<RoomGuard<Vec<ConnectionId>>> {
1791        self.transact(|tx| async {
1792            let project_id = ProjectId::from_proto(update.project_id);
1793            let server = update
1794                .server
1795                .as_ref()
1796                .ok_or_else(|| anyhow!("invalid language server"))?;
1797
1798            // Ensure the update comes from the host.
1799            let project = project::Entity::find_by_id(project_id)
1800                .one(&tx)
1801                .await?
1802                .ok_or_else(|| anyhow!("no such project"))?;
1803            if project.host_connection_id != connection_id.0 as i32 {
1804                return Err(anyhow!("can't update a project hosted by someone else"))?;
1805            }
1806
1807            // Add the newly-started language server.
1808            language_server::Entity::insert(language_server::ActiveModel {
1809                project_id: ActiveValue::set(project_id),
1810                id: ActiveValue::set(server.id as i64),
1811                name: ActiveValue::set(server.name.clone()),
1812                ..Default::default()
1813            })
1814            .on_conflict(
1815                OnConflict::columns([
1816                    language_server::Column::ProjectId,
1817                    language_server::Column::Id,
1818                ])
1819                .update_column(language_server::Column::Name)
1820                .to_owned(),
1821            )
1822            .exec(&tx)
1823            .await?;
1824
1825            let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?;
1826            self.commit_room_transaction(project.room_id, tx, connection_ids)
1827                .await
1828        })
1829        .await
1830    }
1831
    /// Adds `connection_id` as a guest collaborator of `project_id`.
    ///
    /// Returns a snapshot of the project and the replica id assigned to the new collaborator.
    /// The snapshot contains all collaborators (the new one last), every worktree with its
    /// entries and diagnostic summaries, and the project's language servers.
    ///
    /// Fails with "must join a room first" when the connection is not found as an answering
    /// room participant. Fails with "no such project" when the project does not exist or
    /// belongs to a different room than the participant.
    pub async fn join_project(
        &self,
        project_id: ProjectId,
        connection_id: ConnectionId,
    ) -> Result<RoomGuard<(Project, ReplicaId)>> {
        self.transact(|tx| async move {
            // The joining connection must already be a participant of some room.
            let participant = room_participant::Entity::find()
                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
                .one(&tx)
                .await?
                .ok_or_else(|| anyhow!("must join a room first"))?;

            let project = project::Entity::find_by_id(project_id)
                .one(&tx)
                .await?
                .ok_or_else(|| anyhow!("no such project"))?;
            // Projects from other rooms report the same error as missing ones.
            if project.room_id != participant.room_id {
                return Err(anyhow!("no such project"))?;
            }

            let mut collaborators = project
                .find_related(project_collaborator::Entity)
                .all(&tx)
                .await?;
            // Assign the smallest replica id >= 1 not already in use. The host's collaborator
            // row is inserted with ReplicaId(0) when a project is shared.
            let replica_ids = collaborators
                .iter()
                .map(|c| c.replica_id)
                .collect::<HashSet<_>>();
            let mut replica_id = ReplicaId(1);
            while replica_ids.contains(&replica_id) {
                replica_id.0 += 1;
            }
            let new_collaborator = project_collaborator::ActiveModel {
                project_id: ActiveValue::set(project_id),
                connection_id: ActiveValue::set(connection_id.0 as i32),
                connection_epoch: ActiveValue::set(self.epoch),
                user_id: ActiveValue::set(participant.user_id),
                replica_id: ActiveValue::set(replica_id),
                is_host: ActiveValue::set(false),
                ..Default::default()
            }
            .insert(&tx)
            .await?;
            collaborators.push(new_collaborator);

            // Worktrees keyed by id, so entries and summaries below can be attached to them.
            let db_worktrees = project.find_related(worktree::Entity).all(&tx).await?;
            let mut worktrees = db_worktrees
                .into_iter()
                .map(|db_worktree| {
                    (
                        db_worktree.id as u64,
                        Worktree {
                            id: db_worktree.id as u64,
                            abs_path: db_worktree.abs_path,
                            root_name: db_worktree.root_name,
                            visible: db_worktree.visible,
                            entries: Default::default(),
                            diagnostic_summaries: Default::default(),
                            scan_id: db_worktree.scan_id as u64,
                            is_complete: db_worktree.is_complete,
                        },
                    )
                })
                .collect::<BTreeMap<_, _>>();

            // Populate worktree entries.
            // Rows are streamed, and rows whose worktree id matches no loaded worktree are skipped.
            {
                let mut db_entries = worktree_entry::Entity::find()
                    .filter(worktree_entry::Column::ProjectId.eq(project_id))
                    .stream(&tx)
                    .await?;
                while let Some(db_entry) = db_entries.next().await {
                    let db_entry = db_entry?;
                    if let Some(worktree) = worktrees.get_mut(&(db_entry.worktree_id as u64)) {
                        worktree.entries.push(proto::Entry {
                            id: db_entry.id as u64,
                            is_dir: db_entry.is_dir,
                            path: db_entry.path,
                            inode: db_entry.inode as u64,
                            mtime: Some(proto::Timestamp {
                                seconds: db_entry.mtime_seconds as u64,
                                nanos: db_entry.mtime_nanos as u32,
                            }),
                            is_symlink: db_entry.is_symlink,
                            is_ignored: db_entry.is_ignored,
                        });
                    }
                }
            }

            // Populate worktree diagnostic summaries.
            // Same approach as the entries: streamed, and orphan rows are skipped.
            {
                let mut db_summaries = worktree_diagnostic_summary::Entity::find()
                    .filter(worktree_diagnostic_summary::Column::ProjectId.eq(project_id))
                    .stream(&tx)
                    .await?;
                while let Some(db_summary) = db_summaries.next().await {
                    let db_summary = db_summary?;
                    if let Some(worktree) = worktrees.get_mut(&(db_summary.worktree_id as u64)) {
                        worktree
                            .diagnostic_summaries
                            .push(proto::DiagnosticSummary {
                                path: db_summary.path,
                                language_server_id: db_summary.language_server_id as u64,
                                error_count: db_summary.error_count as u32,
                                warning_count: db_summary.warning_count as u32,
                            });
                    }
                }
            }

            // Populate language servers.
            let language_servers = project
                .find_related(language_server::Entity)
                .all(&tx)
                .await?;

            // `replica_id as ReplicaId` below is an identity cast; the value already has that type.
            self.commit_room_transaction(
                project.room_id,
                tx,
                (
                    Project {
                        collaborators,
                        worktrees,
                        language_servers: language_servers
                            .into_iter()
                            .map(|language_server| proto::LanguageServer {
                                id: language_server.id as u64,
                                name: language_server.name,
                            })
                            .collect(),
                    },
                    replica_id as ReplicaId,
                ),
            )
            .await
        })
        .await
    }
1971
1972    pub async fn leave_project(
1973        &self,
1974        project_id: ProjectId,
1975        connection_id: ConnectionId,
1976    ) -> Result<RoomGuard<LeftProject>> {
1977        self.transact(|tx| async move {
1978            let result = project_collaborator::Entity::delete_many()
1979                .filter(
1980                    project_collaborator::Column::ProjectId
1981                        .eq(project_id)
1982                        .and(project_collaborator::Column::ConnectionId.eq(connection_id.0)),
1983                )
1984                .exec(&tx)
1985                .await?;
1986            if result.rows_affected == 0 {
1987                Err(anyhow!("not a collaborator on this project"))?;
1988            }
1989
1990            let project = project::Entity::find_by_id(project_id)
1991                .one(&tx)
1992                .await?
1993                .ok_or_else(|| anyhow!("no such project"))?;
1994            let collaborators = project
1995                .find_related(project_collaborator::Entity)
1996                .all(&tx)
1997                .await?;
1998            let connection_ids = collaborators
1999                .into_iter()
2000                .map(|collaborator| ConnectionId(collaborator.connection_id as u32))
2001                .collect();
2002
2003            self.commit_room_transaction(
2004                project.room_id,
2005                tx,
2006                LeftProject {
2007                    id: project_id,
2008                    host_user_id: project.host_user_id,
2009                    host_connection_id: ConnectionId(project.host_connection_id as u32),
2010                    connection_ids,
2011                },
2012            )
2013            .await
2014        })
2015        .await
2016    }
2017
2018    pub async fn project_collaborators(
2019        &self,
2020        project_id: ProjectId,
2021        connection_id: ConnectionId,
2022    ) -> Result<Vec<project_collaborator::Model>> {
2023        self.transact(|tx| async move {
2024            let collaborators = project_collaborator::Entity::find()
2025                .filter(project_collaborator::Column::ProjectId.eq(project_id))
2026                .all(&tx)
2027                .await?;
2028
2029            if collaborators
2030                .iter()
2031                .any(|collaborator| collaborator.connection_id == connection_id.0 as i32)
2032            {
2033                Ok(collaborators)
2034            } else {
2035                Err(anyhow!("no such project"))?
2036            }
2037        })
2038        .await
2039    }
2040
2041    pub async fn project_connection_ids(
2042        &self,
2043        project_id: ProjectId,
2044        connection_id: ConnectionId,
2045    ) -> Result<HashSet<ConnectionId>> {
2046        self.transact(|tx| async move {
2047            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
2048            enum QueryAs {
2049                ConnectionId,
2050            }
2051
2052            let mut db_connection_ids = project_collaborator::Entity::find()
2053                .select_only()
2054                .column_as(
2055                    project_collaborator::Column::ConnectionId,
2056                    QueryAs::ConnectionId,
2057                )
2058                .filter(project_collaborator::Column::ProjectId.eq(project_id))
2059                .into_values::<i32, QueryAs>()
2060                .stream(&tx)
2061                .await?;
2062
2063            let mut connection_ids = HashSet::default();
2064            while let Some(connection_id) = db_connection_ids.next().await {
2065                connection_ids.insert(ConnectionId(connection_id? as u32));
2066            }
2067
2068            if connection_ids.contains(&connection_id) {
2069                Ok(connection_ids)
2070            } else {
2071                Err(anyhow!("no such project"))?
2072            }
2073        })
2074        .await
2075    }
2076
2077    async fn project_guest_connection_ids(
2078        &self,
2079        project_id: ProjectId,
2080        tx: &DatabaseTransaction,
2081    ) -> Result<Vec<ConnectionId>> {
2082        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
2083        enum QueryAs {
2084            ConnectionId,
2085        }
2086
2087        let mut db_guest_connection_ids = project_collaborator::Entity::find()
2088            .select_only()
2089            .column_as(
2090                project_collaborator::Column::ConnectionId,
2091                QueryAs::ConnectionId,
2092            )
2093            .filter(
2094                project_collaborator::Column::ProjectId
2095                    .eq(project_id)
2096                    .and(project_collaborator::Column::IsHost.eq(false)),
2097            )
2098            .into_values::<i32, QueryAs>()
2099            .stream(tx)
2100            .await?;
2101
2102        let mut guest_connection_ids = Vec::new();
2103        while let Some(connection_id) = db_guest_connection_ids.next().await {
2104            guest_connection_ids.push(ConnectionId(connection_id? as u32));
2105        }
2106        Ok(guest_connection_ids)
2107    }
2108
2109    // access tokens
2110
2111    pub async fn create_access_token_hash(
2112        &self,
2113        user_id: UserId,
2114        access_token_hash: &str,
2115        max_access_token_count: usize,
2116    ) -> Result<()> {
2117        self.transact(|tx| async {
2118            let tx = tx;
2119
2120            access_token::ActiveModel {
2121                user_id: ActiveValue::set(user_id),
2122                hash: ActiveValue::set(access_token_hash.into()),
2123                ..Default::default()
2124            }
2125            .insert(&tx)
2126            .await?;
2127
2128            access_token::Entity::delete_many()
2129                .filter(
2130                    access_token::Column::Id.in_subquery(
2131                        Query::select()
2132                            .column(access_token::Column::Id)
2133                            .from(access_token::Entity)
2134                            .and_where(access_token::Column::UserId.eq(user_id))
2135                            .order_by(access_token::Column::Id, sea_orm::Order::Desc)
2136                            .limit(10000)
2137                            .offset(max_access_token_count as u64)
2138                            .to_owned(),
2139                    ),
2140                )
2141                .exec(&tx)
2142                .await?;
2143            tx.commit().await?;
2144            Ok(())
2145        })
2146        .await
2147    }
2148
2149    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
2150        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
2151        enum QueryAs {
2152            Hash,
2153        }
2154
2155        self.transact(|tx| async move {
2156            Ok(access_token::Entity::find()
2157                .select_only()
2158                .column(access_token::Column::Hash)
2159                .filter(access_token::Column::UserId.eq(user_id))
2160                .order_by_desc(access_token::Column::Id)
2161                .into_values::<_, QueryAs>()
2162                .all(&tx)
2163                .await?)
2164        })
2165        .await
2166    }
2167
2168    async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
2169    where
2170        F: Send + Fn(DatabaseTransaction) -> Fut,
2171        Fut: Send + Future<Output = Result<T>>,
2172    {
2173        let body = async {
2174            loop {
2175                let tx = self.pool.begin().await?;
2176
2177                // In Postgres, serializable transactions are opt-in
2178                if let DatabaseBackend::Postgres = self.pool.get_database_backend() {
2179                    tx.execute(Statement::from_string(
2180                        DatabaseBackend::Postgres,
2181                        "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(),
2182                    ))
2183                    .await?;
2184                }
2185
2186                match f(tx).await {
2187                    Ok(result) => return Ok(result),
2188                    Err(error) => match error {
2189                        Error::Database2(
2190                            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
2191                            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
2192                        ) if error
2193                            .as_database_error()
2194                            .and_then(|error| error.code())
2195                            .as_deref()
2196                            == Some("40001") =>
2197                        {
2198                            // Retry (don't break the loop)
2199                        }
2200                        error @ _ => return Err(error),
2201                    },
2202                }
2203            }
2204        };
2205
2206        #[cfg(test)]
2207        {
2208            if let Some(background) = self.background.as_ref() {
2209                background.simulate_random_delay().await;
2210            }
2211
2212            self.runtime.as_ref().unwrap().block_on(body)
2213        }
2214
2215        #[cfg(not(test))]
2216        {
2217            body.await
2218        }
2219    }
2220}
2221
/// A value produced while holding a room lock, bundled with that lock's guard.
///
/// The lock is released when the `RoomGuard` is dropped. Fields drop in declaration order,
/// so `data` is dropped while the lock is still held. The `Rc` marker makes the type
/// neither `Send` nor `Sync`. NOTE(review): the guard is presumably taken from the per-room
/// mutexes in `Database::rooms` by `commit_room_transaction`; confirm.
pub struct RoomGuard<T> {
    data: T,
    _guard: OwnedMutexGuard<()>,
    _not_send: PhantomData<Rc<()>>,
}
2227
2228impl<T> Deref for RoomGuard<T> {
2229    type Target = T;
2230
2231    fn deref(&self) -> &T {
2232        &self.data
2233    }
2234}
2235
2236impl<T> DerefMut for RoomGuard<T> {
2237    fn deref_mut(&mut self) -> &mut T {
2238        &mut self.data
2239    }
2240}
2241
/// Serializable input describing a user to create.
///
/// NOTE(review): `invite_count` presumably is the number of invites granted to the new
/// user; confirm against the user-creation code.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    pub invite_count: i32,
}
2248
/// Outcome of creating a user.
///
/// NOTE(review): `inviting_user_id` and `signup_device_id` are optional and presumably only
/// set when the user came from an invite or signup; confirm with the creating code.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    pub metrics_id: String,
    pub inviting_user_id: Option<UserId>,
    pub signup_device_id: Option<String>,
}
2256
2257fn random_invite_code() -> String {
2258    nanoid::nanoid!(16)
2259}
2260
2261fn random_email_confirmation_code() -> String {
2262    nanoid::nanoid!(64)
2263}
2264
/// Declares `$name` as an `i32`-backed newtype id that can be used as a sea-orm column value.
///
/// The generated type serializes transparently as its inner `i32`. It converts into
/// `sea_query::Value::Int`, can be read from query results, can be built from any integer
/// `Value` that fits in `i32`, and can be parsed from a `u64` primary key.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            Serialize,
            Deserialize,
        )]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Converts a protobuf id. `as i32` silently truncates values outside the `i32` range.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts to the protobuf representation. `as u64` sign-extends negative ids.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }

        impl From<$name> for sea_query::Value {
            fn from(value: $name) -> Self {
                sea_query::Value::Int(Some(value.0))
            }
        }

        impl sea_orm::TryGetable for $name {
            fn try_get(
                res: &sea_orm::QueryResult,
                pre: &str,
                col: &str,
            ) -> Result<Self, sea_orm::TryGetError> {
                Ok(Self(i32::try_get(res, pre, col)?))
            }
        }

        // Accepts any non-NULL integer `Value` variant whose value fits in `i32`.
        // NULLs, out-of-range integers and non-integer variants are rejected.
        impl sea_query::ValueType for $name {
            fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
                match v {
                    Value::TinyInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Int(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::TinyUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Unsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    _ => Err(sea_query::ValueTypeErr),
                }
            }

            fn type_name() -> String {
                stringify!($name).into()
            }

            fn array_type() -> sea_query::ArrayType {
                sea_query::ArrayType::Int
            }

            fn column_type() -> sea_query::ColumnType {
                sea_query::ColumnType::Integer(None)
            }
        }

        impl sea_orm::TryFromU64 for $name {
            /// Narrows a `u64` key to this id. Fails when the value exceeds `i32::MAX`.
            fn try_from_u64(n: u64) -> Result<Self, DbErr> {
                Ok(Self(n.try_into().map_err(|_| {
                    // The conversion goes u64 -> $name, which is what the message reports.
                    DbErr::ConvertFromU64(concat!(
                        "error converting u64 to ",
                        stringify!($name)
                    ))
                })?))
            }
        }

        impl sea_query::Nullable for $name {
            fn null() -> Value {
                Value::Int(None)
            }
        }
    };
}
2383
2384id_type!(AccessTokenId);
2385id_type!(ContactId);
2386id_type!(RoomId);
2387id_type!(RoomParticipantId);
2388id_type!(ProjectId);
2389id_type!(ProjectCollaboratorId);
2390id_type!(ReplicaId);
2391id_type!(SignupId);
2392id_type!(UserId);
2393
/// Data describing a room after a participant has left it.
///
/// NOTE(review): semantics inferred from the type and field names; confirm
/// against the code that constructs this value.
pub struct LeftRoom {
    /// Protobuf representation of the room.
    pub room: proto::Room,
    /// Projects keyed by id; presumably those the departing participant was
    /// part of — verify against the constructor.
    pub left_projects: HashMap<ProjectId, LeftProject>,
    /// Ids of users whose calls were canceled; presumably pending calls tied
    /// to the departing participant — verify.
    pub canceled_calls_to_user_ids: Vec<UserId>,
}
2399
/// A project's collaborators, worktrees and language servers, gathered together.
pub struct Project {
    /// Collaborator rows for the project.
    pub collaborators: Vec<project_collaborator::Model>,
    /// Worktrees keyed by a `u64`; presumably the same value as
    /// [`Worktree::id`] — confirm.
    pub worktrees: BTreeMap<u64, Worktree>,
    /// Language servers, in protobuf form.
    pub language_servers: Vec<proto::LanguageServer>,
}
2405
/// A project affected by a participant leaving, together with its host and
/// the connections associated with it.
///
/// NOTE(review): semantics inferred from field names; confirm against callers.
pub struct LeftProject {
    /// Id of the project.
    pub id: ProjectId,
    /// User id of the project's host.
    pub host_user_id: UserId,
    /// Connection id of the project's host.
    pub host_connection_id: ConnectionId,
    /// Connection ids associated with the project; presumably the peers that
    /// should be notified — verify.
    pub connection_ids: Vec<ConnectionId>,
}
2412
/// A worktree of a [`Project`], including its entries and diagnostic summaries.
pub struct Worktree {
    /// Worktree identifier (`u64`, as used in protobuf messages).
    pub id: u64,
    /// Path of the worktree; presumably absolute on the host machine — confirm.
    pub abs_path: String,
    /// Name of the worktree's root.
    pub root_name: String,
    /// Visibility flag for the worktree.
    pub visible: bool,
    /// Entries of the worktree, in protobuf form.
    pub entries: Vec<proto::Entry>,
    /// Diagnostic summaries for the worktree, in protobuf form.
    pub diagnostic_summaries: Vec<proto::DiagnosticSummary>,
    /// Scan counter; presumably identifies the host's latest scan — confirm.
    pub scan_id: u64,
    /// Whether the stored data is complete; presumably refers to `scan_id`'s
    /// scan — confirm.
    pub is_complete: bool,
}
2423
// In test builds, re-export the public test helpers (such as `TestDb`) so they
// are reachable directly from this module.
#[cfg(test)]
pub use test::*;
2426
#[cfg(test)]
/// Test-only helpers for creating throwaway databases.
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sea_orm::ConnectionTrait;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A disposable [`Database`] for tests.
    ///
    /// Built either by [`TestDb::sqlite`] (in-memory SQLite) or by
    /// [`TestDb::postgres`] (a freshly created, randomly named Postgres
    /// database). Dropping a Postgres-backed `TestDb` also drops its database.
    pub struct TestDb {
        /// The database under test. Both constructors set this to `Some`;
        /// `Drop` takes it.
        pub db: Option<Arc<Database>>,
        // NOTE(review): neither constructor here ever sets this field; confirm
        // whether anything still reads it before removing it.
        pub connection: Option<sqlx::AnyConnection>,
    }

    impl TestDb {
        /// Creates a test database on an in-memory SQLite pool and applies the
        /// checked-in SQLite test schema to it.
        ///
        /// # Panics
        ///
        /// Panics if the Tokio runtime cannot be built, the connection fails,
        /// or the schema cannot be applied.
        pub fn sqlite(background: Arc<Background>) -> Self {
            let runtime = Self::build_runtime();

            let db = runtime.block_on(async {
                let mut options = ConnectOptions::new("sqlite::memory:".to_string());
                options.max_connections(5);
                let db = Database::new(options).await.unwrap();
                let sql = include_str!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/migrations.sqlite/20221109000000_test_schema.sql"
                ));
                db.pool
                    .execute(sea_orm::Statement::from_string(
                        db.pool.get_database_backend(),
                        sql.into(),
                    ))
                    .await
                    .unwrap();
                db
            });

            Self::from_parts(db, background, runtime)
        }

        /// Creates a randomly named Postgres database on `localhost`, applies
        /// the migrations found in the crate's `migrations` directory, and
        /// wraps it as a test database.
        ///
        /// A process-wide lock serializes creation and migration across
        /// concurrent callers.
        ///
        /// # Panics
        ///
        /// Panics if the database cannot be created, connected to, or migrated.
        pub fn postgres(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Held until this function returns, i.e. through creation and migration.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = Self::build_runtime();

            let db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let mut options = ConnectOptions::new(url);
                options
                    .max_connections(5)
                    .idle_timeout(Duration::from_secs(0));
                let db = Database::new(options).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            Self::from_parts(db, background, runtime)
        }

        /// Returns the database under test.
        ///
        /// # Panics
        ///
        /// Panics if `self.db` is `None` (within this module only `Drop` clears it).
        pub fn db(&self) -> &Arc<Database> {
            self.db.as_ref().unwrap()
        }

        /// Builds the current-thread Tokio runtime (I/O and timers enabled)
        /// that is stored on the `Database` and reused by `Drop` for teardown.
        fn build_runtime() -> tokio::runtime::Runtime {
            tokio::runtime::Builder::new_current_thread()
                .enable_io()
                .enable_time()
                .build()
                .unwrap()
        }

        /// Stores `background` and `runtime` on `db` and wraps it in a `TestDb`.
        fn from_parts(
            mut db: Database,
            background: Arc<Background>,
            runtime: tokio::runtime::Runtime,
        ) -> Self {
            db.background = Some(background);
            db.runtime = Some(runtime);
            Self {
                db: Some(Arc::new(db)),
                connection: None,
            }
        }
    }

    impl Drop for TestDb {
        /// For Postgres-backed databases, disconnects the database's other
        /// sessions and drops the database. Other backends are left alone.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            if let DatabaseBackend::Postgres = db.pool.get_database_backend() {
                db.runtime.as_ref().unwrap().block_on(async {
                    use util::ResultExt;
                    // Postgres refuses DROP DATABASE while other sessions are
                    // connected, so terminate every session on this database
                    // except the one issuing the query.
                    // NOTE(review): the excluded session stays in the pool;
                    // confirm it is closed before `drop_database` runs.
                    let query = "
                        SELECT pg_terminate_backend(pg_stat_activity.pid)
                        FROM pg_stat_activity
                        WHERE
                            pg_stat_activity.datname = current_database() AND
                            pid <> pg_backend_pid();
                    ";
                    // Cleanup is best-effort: failures go to `log_err` rather
                    // than panicking inside `drop`.
                    db.pool
                        .execute(sea_orm::Statement::from_string(
                            db.pool.get_database_backend(),
                            query.into(),
                        ))
                        .await
                        .log_err();
                    sqlx::Postgres::drop_database(db.options.get_url())
                        .await
                        .log_err();
                })
            }
        }
    }
}