db2.rs

  1mod access_token;
  2mod project;
  3mod project_collaborator;
  4mod room;
  5mod room_participant;
  6#[cfg(test)]
  7mod tests;
  8mod user;
  9mod worktree;
 10
 11use crate::{Error, Result};
 12use anyhow::anyhow;
 13use collections::HashMap;
 14use dashmap::DashMap;
 15use futures::StreamExt;
 16use rpc::{proto, ConnectionId};
 17use sea_orm::{
 18    entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
 19    TransactionTrait,
 20};
 21use sea_orm::{ActiveValue, ConnectionTrait, IntoActiveModel, QueryOrder, QuerySelect};
 22use sea_query::{OnConflict, Query};
 23use serde::{Deserialize, Serialize};
 24use sqlx::migrate::{Migrate, Migration, MigrationSource};
 25use sqlx::Connection;
 26use std::ops::{Deref, DerefMut};
 27use std::path::Path;
 28use std::time::Duration;
 29use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
 30use tokio::sync::{Mutex, OwnedMutexGuard};
 31
 32pub use user::Model as User;
 33
/// Handle to the server's database, built around a sea-orm connection.
pub struct Database {
    /// Options the connection was opened with; `migrate` reads the URL from here.
    options: ConnectOptions,
    /// Connection on which every transaction in this file is started.
    pool: DatabaseConnection,
    /// One async mutex per room; locked when a room transaction is committed.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    /// Test-only executor; when set, `transact` awaits a simulated random delay on it.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    /// Test-only Tokio runtime that `transact` blocks on to drive each transaction.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
 43
 44impl Database {
 45    pub async fn new(options: ConnectOptions) -> Result<Self> {
 46        Ok(Self {
 47            options: options.clone(),
 48            pool: sea_orm::Database::connect(options).await?,
 49            rooms: DashMap::with_capacity(16384),
 50            #[cfg(test)]
 51            background: None,
 52            #[cfg(test)]
 53            runtime: None,
 54        })
 55    }
 56
 57    pub async fn migrate(
 58        &self,
 59        migrations_path: &Path,
 60        ignore_checksum_mismatch: bool,
 61    ) -> anyhow::Result<Vec<(Migration, Duration)>> {
 62        let migrations = MigrationSource::resolve(migrations_path)
 63            .await
 64            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 65
 66        let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
 67
 68        connection.ensure_migrations_table().await?;
 69        let applied_migrations: HashMap<_, _> = connection
 70            .list_applied_migrations()
 71            .await?
 72            .into_iter()
 73            .map(|m| (m.version, m))
 74            .collect();
 75
 76        let mut new_migrations = Vec::new();
 77        for migration in migrations {
 78            match applied_migrations.get(&migration.version) {
 79                Some(applied_migration) => {
 80                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 81                    {
 82                        Err(anyhow!(
 83                            "checksum mismatch for applied migration {}",
 84                            migration.description
 85                        ))?;
 86                    }
 87                }
 88                None => {
 89                    let elapsed = connection.apply(&migration).await?;
 90                    new_migrations.push((migration, elapsed));
 91                }
 92            }
 93        }
 94
 95        Ok(new_migrations)
 96    }
 97
 98    pub async fn create_user(
 99        &self,
100        email_address: &str,
101        admin: bool,
102        params: NewUserParams,
103    ) -> Result<NewUserResult> {
104        self.transact(|tx| async {
105            let user = user::Entity::insert(user::ActiveModel {
106                email_address: ActiveValue::set(Some(email_address.into())),
107                github_login: ActiveValue::set(params.github_login.clone()),
108                github_user_id: ActiveValue::set(Some(params.github_user_id)),
109                admin: ActiveValue::set(admin),
110                metrics_id: ActiveValue::set(Uuid::new_v4()),
111                ..Default::default()
112            })
113            .on_conflict(
114                OnConflict::column(user::Column::GithubLogin)
115                    .update_column(user::Column::GithubLogin)
116                    .to_owned(),
117            )
118            .exec_with_returning(&tx)
119            .await?;
120
121            tx.commit().await?;
122
123            Ok(NewUserResult {
124                user_id: user.id,
125                metrics_id: user.metrics_id.to_string(),
126                signup_device_id: None,
127                inviting_user_id: None,
128            })
129        })
130        .await
131    }
132
133    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
134        self.transact(|tx| async {
135            let tx = tx;
136            Ok(user::Entity::find()
137                .filter(user::Column::Id.is_in(ids.iter().copied()))
138                .all(&tx)
139                .await?)
140        })
141        .await
142    }
143
144    pub async fn get_user_by_github_account(
145        &self,
146        github_login: &str,
147        github_user_id: Option<i32>,
148    ) -> Result<Option<User>> {
149        self.transact(|tx| async {
150            let tx = tx;
151            if let Some(github_user_id) = github_user_id {
152                if let Some(user_by_github_user_id) = user::Entity::find()
153                    .filter(user::Column::GithubUserId.eq(github_user_id))
154                    .one(&tx)
155                    .await?
156                {
157                    let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
158                    user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
159                    Ok(Some(user_by_github_user_id.update(&tx).await?))
160                } else if let Some(user_by_github_login) = user::Entity::find()
161                    .filter(user::Column::GithubLogin.eq(github_login))
162                    .one(&tx)
163                    .await?
164                {
165                    let mut user_by_github_login = user_by_github_login.into_active_model();
166                    user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
167                    Ok(Some(user_by_github_login.update(&tx).await?))
168                } else {
169                    Ok(None)
170                }
171            } else {
172                Ok(user::Entity::find()
173                    .filter(user::Column::GithubLogin.eq(github_login))
174                    .one(&tx)
175                    .await?)
176            }
177        })
178        .await
179    }
180
181    pub async fn share_project(
182        &self,
183        room_id: RoomId,
184        connection_id: ConnectionId,
185        worktrees: &[proto::WorktreeMetadata],
186    ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
187        self.transact(|tx| async move {
188            let participant = room_participant::Entity::find()
189                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
190                .one(&tx)
191                .await?
192                .ok_or_else(|| anyhow!("could not find participant"))?;
193            if participant.room_id != room_id {
194                return Err(anyhow!("shared project on unexpected room"))?;
195            }
196
197            let project = project::ActiveModel {
198                room_id: ActiveValue::set(participant.room_id),
199                host_user_id: ActiveValue::set(participant.user_id),
200                host_connection_id: ActiveValue::set(connection_id.0 as i32),
201                ..Default::default()
202            }
203            .insert(&tx)
204            .await?;
205
206            worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel {
207                id: ActiveValue::set(worktree.id as i32),
208                project_id: ActiveValue::set(project.id),
209                abs_path: ActiveValue::set(worktree.abs_path.clone()),
210                root_name: ActiveValue::set(worktree.root_name.clone()),
211                visible: ActiveValue::set(worktree.visible),
212                scan_id: ActiveValue::set(0),
213                is_complete: ActiveValue::set(false),
214            }))
215            .exec(&tx)
216            .await?;
217
218            project_collaborator::ActiveModel {
219                project_id: ActiveValue::set(project.id),
220                connection_id: ActiveValue::set(connection_id.0 as i32),
221                user_id: ActiveValue::set(participant.user_id),
222                replica_id: ActiveValue::set(0),
223                is_host: ActiveValue::set(true),
224                ..Default::default()
225            }
226            .insert(&tx)
227            .await?;
228
229            let room = self.get_room(room_id, &tx).await?;
230            self.commit_room_transaction(room_id, tx, (project.id, room))
231                .await
232        })
233        .await
234    }
235
236    async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
237        let db_room = room::Entity::find_by_id(room_id)
238            .one(tx)
239            .await?
240            .ok_or_else(|| anyhow!("could not find room"))?;
241
242        let mut db_participants = db_room
243            .find_related(room_participant::Entity)
244            .stream(tx)
245            .await?;
246        let mut participants = HashMap::default();
247        let mut pending_participants = Vec::new();
248        while let Some(db_participant) = db_participants.next().await {
249            let db_participant = db_participant?;
250            if let Some(answering_connection_id) = db_participant.answering_connection_id {
251                let location = match (
252                    db_participant.location_kind,
253                    db_participant.location_project_id,
254                ) {
255                    (Some(0), Some(project_id)) => {
256                        Some(proto::participant_location::Variant::SharedProject(
257                            proto::participant_location::SharedProject {
258                                id: project_id.to_proto(),
259                            },
260                        ))
261                    }
262                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
263                        Default::default(),
264                    )),
265                    _ => Some(proto::participant_location::Variant::External(
266                        Default::default(),
267                    )),
268                };
269                participants.insert(
270                    answering_connection_id,
271                    proto::Participant {
272                        user_id: db_participant.user_id.to_proto(),
273                        peer_id: answering_connection_id as u32,
274                        projects: Default::default(),
275                        location: Some(proto::ParticipantLocation { variant: location }),
276                    },
277                );
278            } else {
279                pending_participants.push(proto::PendingParticipant {
280                    user_id: db_participant.user_id.to_proto(),
281                    calling_user_id: db_participant.calling_user_id.to_proto(),
282                    initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()),
283                });
284            }
285        }
286
287        let mut db_projects = db_room
288            .find_related(project::Entity)
289            .find_with_related(worktree::Entity)
290            .stream(tx)
291            .await?;
292
293        while let Some(row) = db_projects.next().await {
294            let (db_project, db_worktree) = row?;
295            if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
296                let project = if let Some(project) = participant
297                    .projects
298                    .iter_mut()
299                    .find(|project| project.id == db_project.id.to_proto())
300                {
301                    project
302                } else {
303                    participant.projects.push(proto::ParticipantProject {
304                        id: db_project.id.to_proto(),
305                        worktree_root_names: Default::default(),
306                    });
307                    participant.projects.last_mut().unwrap()
308                };
309
310                if let Some(db_worktree) = db_worktree {
311                    project.worktree_root_names.push(db_worktree.root_name);
312                }
313            }
314        }
315
316        Ok(proto::Room {
317            id: db_room.id.to_proto(),
318            live_kit_room: db_room.live_kit_room,
319            participants: participants.into_values().collect(),
320            pending_participants,
321        })
322    }
323
324    async fn commit_room_transaction<T>(
325        &self,
326        room_id: RoomId,
327        tx: DatabaseTransaction,
328        data: T,
329    ) -> Result<RoomGuard<T>> {
330        let lock = self.rooms.entry(room_id).or_default().clone();
331        let _guard = lock.lock_owned().await;
332        tx.commit().await?;
333        Ok(RoomGuard {
334            data,
335            _guard,
336            _not_send: PhantomData,
337        })
338    }
339
340    pub async fn create_access_token_hash(
341        &self,
342        user_id: UserId,
343        access_token_hash: &str,
344        max_access_token_count: usize,
345    ) -> Result<()> {
346        self.transact(|tx| async {
347            let tx = tx;
348
349            access_token::ActiveModel {
350                user_id: ActiveValue::set(user_id),
351                hash: ActiveValue::set(access_token_hash.into()),
352                ..Default::default()
353            }
354            .insert(&tx)
355            .await?;
356
357            access_token::Entity::delete_many()
358                .filter(
359                    access_token::Column::Id.in_subquery(
360                        Query::select()
361                            .column(access_token::Column::Id)
362                            .from(access_token::Entity)
363                            .and_where(access_token::Column::UserId.eq(user_id))
364                            .order_by(access_token::Column::Id, sea_orm::Order::Desc)
365                            .limit(10000)
366                            .offset(max_access_token_count as u64)
367                            .to_owned(),
368                    ),
369                )
370                .exec(&tx)
371                .await?;
372            tx.commit().await?;
373            Ok(())
374        })
375        .await
376    }
377
378    pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
379        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
380        enum QueryAs {
381            Hash,
382        }
383
384        self.transact(|tx| async move {
385            Ok(access_token::Entity::find()
386                .select_only()
387                .column(access_token::Column::Hash)
388                .filter(access_token::Column::UserId.eq(user_id))
389                .order_by_desc(access_token::Column::Id)
390                .into_values::<_, QueryAs>()
391                .all(&tx)
392                .await?)
393        })
394        .await
395    }
396
397    async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
398    where
399        F: Send + Fn(DatabaseTransaction) -> Fut,
400        Fut: Send + Future<Output = Result<T>>,
401    {
402        let body = async {
403            loop {
404                let tx = self.pool.begin().await?;
405
406                // In Postgres, serializable transactions are opt-in
407                if let sea_orm::DatabaseBackend::Postgres = self.pool.get_database_backend() {
408                    tx.execute(sea_orm::Statement::from_string(
409                        sea_orm::DatabaseBackend::Postgres,
410                        "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(),
411                    ))
412                    .await?;
413                }
414
415                match f(tx).await {
416                    Ok(result) => return Ok(result),
417                    Err(error) => match error {
418                        Error::Database2(
419                            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
420                            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
421                        ) if error
422                            .as_database_error()
423                            .and_then(|error| error.code())
424                            .as_deref()
425                            == Some("40001") =>
426                        {
427                            // Retry (don't break the loop)
428                        }
429                        error @ _ => return Err(error),
430                    },
431                }
432            }
433        };
434
435        #[cfg(test)]
436        {
437            if let Some(background) = self.background.as_ref() {
438                background.simulate_random_delay().await;
439            }
440
441            self.runtime.as_ref().unwrap().block_on(body)
442        }
443
444        #[cfg(not(test))]
445        {
446            body.await
447        }
448    }
449}
450
/// A value bundled with an owned lock guard, so the lock stays held for as
/// long as the value is alive.
pub struct RoomGuard<T> {
    /// The wrapped payload.
    data: T,
    /// Owned mutex guard; never read, released when this struct is dropped.
    _guard: OwnedMutexGuard<()>,
    /// `Rc` marker: makes `RoomGuard` neither `Send` nor `Sync`.
    _not_send: PhantomData<Rc<()>>,
}
456
/// Gives shared access to the wrapped value.
impl<T> Deref for RoomGuard<T> {
    type Target = T;

    /// Returns a reference to the guarded `data`.
    fn deref(&self) -> &T {
        &self.data
    }
}
464
/// Gives exclusive access to the wrapped value.
impl<T> DerefMut for RoomGuard<T> {
    /// Returns a mutable reference to the guarded `data`.
    fn deref_mut(&mut self) -> &mut T {
        &mut self.data
    }
}
470
/// GitHub-derived inputs for creating a user.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    /// GitHub login stored on the new user row.
    pub github_login: String,
    /// GitHub numeric user id stored on the new user row.
    pub github_user_id: i32,
    /// NOTE(review): not read by `create_user` in this file — confirm intended use.
    pub invite_count: i32,
}
477
/// Outcome of creating a user.
#[derive(Debug)]
pub struct NewUserResult {
    /// Id of the stored user row.
    pub user_id: UserId,
    /// The user's metrics id rendered as a string.
    pub metrics_id: String,
    /// Always `None` when produced by `create_user` in this file.
    pub inviting_user_id: Option<UserId>,
    /// Always `None` when produced by `create_user` in this file.
    pub signup_device_id: Option<String>,
}
485
/// Generates a random invite code using `nanoid` with a length argument of 16.
fn random_invite_code() -> String {
    nanoid::nanoid!(16)
}
489
/// Generates a random email confirmation code using `nanoid` with a length
/// argument of 64.
fn random_email_confirmation_code() -> String {
    nanoid::nanoid!(64)
}
493
/// Declares a newtype around `i32` for use as a database id.
///
/// The generated type is serde- and sqlx-transparent and implements the
/// sea-orm / sea-query traits needed to use it directly as an integer column
/// value (`TryGetable`, `ValueType`, `TryFromU64`, `Nullable`, and conversion
/// into `sea_query::Value`).
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            /// Largest representable id.
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Builds an id from its protobuf representation using an `as`
            /// cast, so values that do not fit in `i32` are truncated.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts the id to its protobuf representation using an `as`
            /// cast.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }

        impl From<$name> for sea_query::Value {
            fn from(value: $name) -> Self {
                sea_query::Value::Int(Some(value.0))
            }
        }

        impl sea_orm::TryGetable for $name {
            fn try_get(
                res: &sea_orm::QueryResult,
                pre: &str,
                col: &str,
            ) -> Result<Self, sea_orm::TryGetError> {
                Ok(Self(i32::try_get(res, pre, col)?))
            }
        }

        impl sea_query::ValueType for $name {
            // Accepts any non-null integer value that fits in `i32`; the arms
            // cannot be merged because each variant carries a different type.
            fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
                match v {
                    Value::TinyInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Int(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::TinyUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Unsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    _ => Err(sea_query::ValueTypeErr),
                }
            }

            fn type_name() -> String {
                stringify!($name).into()
            }

            fn array_type() -> sea_query::ArrayType {
                sea_query::ArrayType::Int
            }

            fn column_type() -> sea_query::ColumnType {
                sea_query::ColumnType::Integer(None)
            }
        }

        impl sea_orm::TryFromU64 for $name {
            fn try_from_u64(n: u64) -> Result<Self, DbErr> {
                // `ConvertFromU64` carries the name of the type that could not
                // be produced from the `u64`, so pass just the type name.
                Ok(Self(n.try_into().map_err(|_| {
                    DbErr::ConvertFromU64(stringify!($name))
                })?))
            }
        }

        impl sea_query::Nullable for $name {
            fn null() -> Value {
                Value::Int(None)
            }
        }
    };
}
614
// Typed id wrappers generated by `id_type!`, one per kind of database row.
id_type!(AccessTokenId);
id_type!(UserId);
id_type!(RoomId);
id_type!(RoomParticipantId);
id_type!(ProjectId);
id_type!(ProjectCollaboratorId);
id_type!(WorktreeId);
622
623#[cfg(test)]
624pub use test::*;
625
626#[cfg(test)]
627mod test {
628    use super::*;
629    use gpui::executor::Background;
630    use lazy_static::lazy_static;
631    use parking_lot::Mutex;
632    use rand::prelude::*;
633    use sea_orm::ConnectionTrait;
634    use sqlx::migrate::MigrateDatabase;
635    use std::sync::Arc;
636
    /// Owns a `Database` created for a single test.
    pub struct TestDb {
        /// The database under test; taken out of the option in `Drop`.
        pub db: Option<Arc<Database>>,
        /// Set to `None` by both constructors in this module.
        pub connection: Option<sqlx::AnyConnection>,
    }
641
642    impl TestDb {
643        pub fn sqlite(background: Arc<Background>) -> Self {
644            let url = format!("sqlite::memory:");
645            let runtime = tokio::runtime::Builder::new_current_thread()
646                .enable_io()
647                .enable_time()
648                .build()
649                .unwrap();
650
651            let mut db = runtime.block_on(async {
652                let mut options = ConnectOptions::new(url);
653                options.max_connections(5);
654                let db = Database::new(options).await.unwrap();
655                let sql = include_str!(concat!(
656                    env!("CARGO_MANIFEST_DIR"),
657                    "/migrations.sqlite/20221109000000_test_schema.sql"
658                ));
659                db.pool
660                    .execute(sea_orm::Statement::from_string(
661                        db.pool.get_database_backend(),
662                        sql.into(),
663                    ))
664                    .await
665                    .unwrap();
666                db
667            });
668
669            db.background = Some(background);
670            db.runtime = Some(runtime);
671
672            Self {
673                db: Some(Arc::new(db)),
674                connection: None,
675            }
676        }
677
678        pub fn postgres(background: Arc<Background>) -> Self {
679            lazy_static! {
680                static ref LOCK: Mutex<()> = Mutex::new(());
681            }
682
683            let _guard = LOCK.lock();
684            let mut rng = StdRng::from_entropy();
685            let url = format!(
686                "postgres://postgres@localhost/zed-test-{}",
687                rng.gen::<u128>()
688            );
689            let runtime = tokio::runtime::Builder::new_current_thread()
690                .enable_io()
691                .enable_time()
692                .build()
693                .unwrap();
694
695            let mut db = runtime.block_on(async {
696                sqlx::Postgres::create_database(&url)
697                    .await
698                    .expect("failed to create test db");
699                let mut options = ConnectOptions::new(url);
700                options
701                    .max_connections(5)
702                    .idle_timeout(Duration::from_secs(0));
703                let db = Database::new(options).await.unwrap();
704                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
705                db.migrate(Path::new(migrations_path), false).await.unwrap();
706                db
707            });
708
709            db.background = Some(background);
710            db.runtime = Some(runtime);
711
712            Self {
713                db: Some(Arc::new(db)),
714                connection: None,
715            }
716        }
717
        /// Returns the database under test. Panics if it has already been taken.
        pub fn db(&self) -> &Arc<Database> {
            self.db.as_ref().unwrap()
        }
721    }
722
723    impl Drop for TestDb {
724        fn drop(&mut self) {
725            let db = self.db.take().unwrap();
726            if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
727                db.runtime.as_ref().unwrap().block_on(async {
728                    use util::ResultExt;
729                    let query = "
730                        SELECT pg_terminate_backend(pg_stat_activity.pid)
731                        FROM pg_stat_activity
732                        WHERE
733                            pg_stat_activity.datname = current_database() AND
734                            pid <> pg_backend_pid();
735                    ";
736                    db.pool
737                        .execute(sea_orm::Statement::from_string(
738                            db.pool.get_database_backend(),
739                            query.into(),
740                        ))
741                        .await
742                        .log_err();
743                    sqlx::Postgres::drop_database(db.options.get_url())
744                        .await
745                        .log_err();
746                })
747            }
748        }
749    }
750}