db2.rs

  1mod project;
  2mod project_collaborator;
  3mod room;
  4mod room_participant;
  5#[cfg(test)]
  6mod tests;
  7mod user;
  8mod worktree;
  9
 10use crate::{Error, Result};
 11use anyhow::anyhow;
 12use collections::HashMap;
 13use dashmap::DashMap;
 14use futures::StreamExt;
 15use rpc::{proto, ConnectionId};
 16use sea_orm::ActiveValue;
 17use sea_orm::{
 18    entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
 19    TransactionTrait,
 20};
 21use sea_query::OnConflict;
 22use serde::{Deserialize, Serialize};
 23use sqlx::migrate::{Migrate, Migration, MigrationSource};
 24use sqlx::Connection;
 25use std::ops::{Deref, DerefMut};
 26use std::path::Path;
 27use std::time::Duration;
 28use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
 29use tokio::sync::{Mutex, OwnedMutexGuard};
 30
 31pub use user::Model as User;
 32
 33pub struct Database {
 34    url: String,
 35    pool: DatabaseConnection,
 36    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
 37    #[cfg(test)]
 38    background: Option<std::sync::Arc<gpui::executor::Background>>,
 39    #[cfg(test)]
 40    runtime: Option<tokio::runtime::Runtime>,
 41}
 42
 43impl Database {
 44    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
 45        let mut options = ConnectOptions::new(url.into());
 46        options.max_connections(max_connections);
 47        Ok(Self {
 48            url: url.into(),
 49            pool: sea_orm::Database::connect(options).await?,
 50            rooms: DashMap::with_capacity(16384),
 51            #[cfg(test)]
 52            background: None,
 53            #[cfg(test)]
 54            runtime: None,
 55        })
 56    }
 57
 58    pub async fn migrate(
 59        &self,
 60        migrations_path: &Path,
 61        ignore_checksum_mismatch: bool,
 62    ) -> anyhow::Result<(sqlx::AnyConnection, Vec<(Migration, Duration)>)> {
 63        let migrations = MigrationSource::resolve(migrations_path)
 64            .await
 65            .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
 66
 67        let mut connection = sqlx::AnyConnection::connect(&self.url).await?;
 68
 69        connection.ensure_migrations_table().await?;
 70        let applied_migrations: HashMap<_, _> = connection
 71            .list_applied_migrations()
 72            .await?
 73            .into_iter()
 74            .map(|m| (m.version, m))
 75            .collect();
 76
 77        let mut new_migrations = Vec::new();
 78        for migration in migrations {
 79            match applied_migrations.get(&migration.version) {
 80                Some(applied_migration) => {
 81                    if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
 82                    {
 83                        Err(anyhow!(
 84                            "checksum mismatch for applied migration {}",
 85                            migration.description
 86                        ))?;
 87                    }
 88                }
 89                None => {
 90                    let elapsed = connection.apply(&migration).await?;
 91                    new_migrations.push((migration, elapsed));
 92                }
 93            }
 94        }
 95
 96        Ok((connection, new_migrations))
 97    }
 98
 99    pub async fn create_user(
100        &self,
101        email_address: &str,
102        admin: bool,
103        params: NewUserParams,
104    ) -> Result<NewUserResult> {
105        self.transact(|tx| async {
106            let user = user::Entity::insert(user::ActiveModel {
107                email_address: ActiveValue::set(Some(email_address.into())),
108                github_login: ActiveValue::set(params.github_login.clone()),
109                github_user_id: ActiveValue::set(Some(params.github_user_id)),
110                admin: ActiveValue::set(admin),
111                metrics_id: ActiveValue::set(Uuid::new_v4()),
112                ..Default::default()
113            })
114            .on_conflict(
115                OnConflict::column(user::Column::GithubLogin)
116                    .update_column(user::Column::GithubLogin)
117                    .to_owned(),
118            )
119            .exec_with_returning(&tx)
120            .await?;
121
122            tx.commit().await?;
123
124            Ok(NewUserResult {
125                user_id: user.id,
126                metrics_id: user.metrics_id.to_string(),
127                signup_device_id: None,
128                inviting_user_id: None,
129            })
130        })
131        .await
132    }
133
134    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
135        self.transact(|tx| async {
136            let tx = tx;
137            Ok(user::Entity::find()
138                .filter(user::Column::Id.is_in(ids.iter().copied()))
139                .all(&tx)
140                .await?)
141        })
142        .await
143    }
144
145    pub async fn share_project(
146        &self,
147        room_id: RoomId,
148        connection_id: ConnectionId,
149        worktrees: &[proto::WorktreeMetadata],
150    ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
151        self.transact(|tx| async move {
152            let participant = room_participant::Entity::find()
153                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
154                .one(&tx)
155                .await?
156                .ok_or_else(|| anyhow!("could not find participant"))?;
157            if participant.room_id != room_id {
158                return Err(anyhow!("shared project on unexpected room"))?;
159            }
160
161            let project = project::ActiveModel {
162                room_id: ActiveValue::set(participant.room_id),
163                host_user_id: ActiveValue::set(participant.user_id),
164                host_connection_id: ActiveValue::set(connection_id.0 as i32),
165                ..Default::default()
166            }
167            .insert(&tx)
168            .await?;
169
170            worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel {
171                id: ActiveValue::set(worktree.id as i32),
172                project_id: ActiveValue::set(project.id),
173                abs_path: ActiveValue::set(worktree.abs_path.clone()),
174                root_name: ActiveValue::set(worktree.root_name.clone()),
175                visible: ActiveValue::set(worktree.visible),
176                scan_id: ActiveValue::set(0),
177                is_complete: ActiveValue::set(false),
178            }))
179            .exec(&tx)
180            .await?;
181
182            project_collaborator::ActiveModel {
183                project_id: ActiveValue::set(project.id),
184                connection_id: ActiveValue::set(connection_id.0 as i32),
185                user_id: ActiveValue::set(participant.user_id),
186                replica_id: ActiveValue::set(0),
187                is_host: ActiveValue::set(true),
188                ..Default::default()
189            }
190            .insert(&tx)
191            .await?;
192
193            let room = self.get_room(room_id, &tx).await?;
194            self.commit_room_transaction(room_id, tx, (project.id, room))
195                .await
196        })
197        .await
198    }
199
200    async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
201        let db_room = room::Entity::find_by_id(room_id)
202            .one(tx)
203            .await?
204            .ok_or_else(|| anyhow!("could not find room"))?;
205
206        let mut db_participants = db_room
207            .find_related(room_participant::Entity)
208            .stream(tx)
209            .await?;
210        let mut participants = HashMap::default();
211        let mut pending_participants = Vec::new();
212        while let Some(db_participant) = db_participants.next().await {
213            let db_participant = db_participant?;
214            if let Some(answering_connection_id) = db_participant.answering_connection_id {
215                let location = match (
216                    db_participant.location_kind,
217                    db_participant.location_project_id,
218                ) {
219                    (Some(0), Some(project_id)) => {
220                        Some(proto::participant_location::Variant::SharedProject(
221                            proto::participant_location::SharedProject {
222                                id: project_id.to_proto(),
223                            },
224                        ))
225                    }
226                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
227                        Default::default(),
228                    )),
229                    _ => Some(proto::participant_location::Variant::External(
230                        Default::default(),
231                    )),
232                };
233                participants.insert(
234                    answering_connection_id,
235                    proto::Participant {
236                        user_id: db_participant.user_id.to_proto(),
237                        peer_id: answering_connection_id as u32,
238                        projects: Default::default(),
239                        location: Some(proto::ParticipantLocation { variant: location }),
240                    },
241                );
242            } else {
243                pending_participants.push(proto::PendingParticipant {
244                    user_id: db_participant.user_id.to_proto(),
245                    calling_user_id: db_participant.calling_user_id.to_proto(),
246                    initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()),
247                });
248            }
249        }
250
251        let mut db_projects = db_room
252            .find_related(project::Entity)
253            .find_with_related(worktree::Entity)
254            .stream(tx)
255            .await?;
256
257        while let Some(row) = db_projects.next().await {
258            let (db_project, db_worktree) = row?;
259            if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
260                let project = if let Some(project) = participant
261                    .projects
262                    .iter_mut()
263                    .find(|project| project.id == db_project.id.to_proto())
264                {
265                    project
266                } else {
267                    participant.projects.push(proto::ParticipantProject {
268                        id: db_project.id.to_proto(),
269                        worktree_root_names: Default::default(),
270                    });
271                    participant.projects.last_mut().unwrap()
272                };
273
274                if let Some(db_worktree) = db_worktree {
275                    project.worktree_root_names.push(db_worktree.root_name);
276                }
277            }
278        }
279
280        Ok(proto::Room {
281            id: db_room.id.to_proto(),
282            live_kit_room: db_room.live_kit_room,
283            participants: participants.into_values().collect(),
284            pending_participants,
285        })
286    }
287
288    async fn commit_room_transaction<T>(
289        &self,
290        room_id: RoomId,
291        tx: DatabaseTransaction,
292        data: T,
293    ) -> Result<RoomGuard<T>> {
294        let lock = self.rooms.entry(room_id).or_default().clone();
295        let _guard = lock.lock_owned().await;
296        tx.commit().await?;
297        Ok(RoomGuard {
298            data,
299            _guard,
300            _not_send: PhantomData,
301        })
302    }
303
304    async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
305    where
306        F: Send + Fn(DatabaseTransaction) -> Fut,
307        Fut: Send + Future<Output = Result<T>>,
308    {
309        let body = async {
310            loop {
311                let tx = self.pool.begin().await?;
312                match f(tx).await {
313                    Ok(result) => return Ok(result),
314                    Err(error) => match error {
315                        Error::Database2(
316                            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
317                            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
318                        ) if error
319                            .as_database_error()
320                            .and_then(|error| error.code())
321                            .as_deref()
322                            == Some("40001") =>
323                        {
324                            // Retry (don't break the loop)
325                        }
326                        error @ _ => return Err(error),
327                    },
328                }
329            }
330        };
331
332        #[cfg(test)]
333        {
334            if let Some(background) = self.background.as_ref() {
335                background.simulate_random_delay().await;
336            }
337
338            self.runtime.as_ref().unwrap().block_on(body)
339        }
340
341        #[cfg(not(test))]
342        {
343            body.await
344        }
345    }
346}
347
348pub struct RoomGuard<T> {
349    data: T,
350    _guard: OwnedMutexGuard<()>,
351    _not_send: PhantomData<Rc<()>>,
352}
353
354impl<T> Deref for RoomGuard<T> {
355    type Target = T;
356
357    fn deref(&self) -> &T {
358        &self.data
359    }
360}
361
362impl<T> DerefMut for RoomGuard<T> {
363    fn deref_mut(&mut self) -> &mut T {
364        &mut self.data
365    }
366}
367
368#[derive(Debug, Serialize, Deserialize)]
369pub struct NewUserParams {
370    pub github_login: String,
371    pub github_user_id: i32,
372    pub invite_count: i32,
373}
374
375#[derive(Debug)]
376pub struct NewUserResult {
377    pub user_id: UserId,
378    pub metrics_id: String,
379    pub inviting_user_id: Option<UserId>,
380    pub signup_device_id: Option<String>,
381}
382
383fn random_invite_code() -> String {
384    nanoid::nanoid!(16)
385}
386
387fn random_email_confirmation_code() -> String {
388    nanoid::nanoid!(64)
389}
390
391macro_rules! id_type {
392    ($name:ident) => {
393        #[derive(
394            Clone,
395            Copy,
396            Debug,
397            Default,
398            PartialEq,
399            Eq,
400            PartialOrd,
401            Ord,
402            Hash,
403            sqlx::Type,
404            Serialize,
405            Deserialize,
406        )]
407        #[sqlx(transparent)]
408        #[serde(transparent)]
409        pub struct $name(pub i32);
410
411        impl $name {
412            #[allow(unused)]
413            pub const MAX: Self = Self(i32::MAX);
414
415            #[allow(unused)]
416            pub fn from_proto(value: u64) -> Self {
417                Self(value as i32)
418            }
419
420            #[allow(unused)]
421            pub fn to_proto(self) -> u64 {
422                self.0 as u64
423            }
424        }
425
426        impl std::fmt::Display for $name {
427            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
428                self.0.fmt(f)
429            }
430        }
431
432        impl From<$name> for sea_query::Value {
433            fn from(value: $name) -> Self {
434                sea_query::Value::Int(Some(value.0))
435            }
436        }
437
438        impl sea_orm::TryGetable for $name {
439            fn try_get(
440                res: &sea_orm::QueryResult,
441                pre: &str,
442                col: &str,
443            ) -> Result<Self, sea_orm::TryGetError> {
444                Ok(Self(i32::try_get(res, pre, col)?))
445            }
446        }
447
448        impl sea_query::ValueType for $name {
449            fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
450                match v {
451                    Value::TinyInt(Some(int)) => {
452                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
453                    }
454                    Value::SmallInt(Some(int)) => {
455                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
456                    }
457                    Value::Int(Some(int)) => {
458                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
459                    }
460                    Value::BigInt(Some(int)) => {
461                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
462                    }
463                    Value::TinyUnsigned(Some(int)) => {
464                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
465                    }
466                    Value::SmallUnsigned(Some(int)) => {
467                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
468                    }
469                    Value::Unsigned(Some(int)) => {
470                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
471                    }
472                    Value::BigUnsigned(Some(int)) => {
473                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
474                    }
475                    _ => Err(sea_query::ValueTypeErr),
476                }
477            }
478
479            fn type_name() -> String {
480                stringify!($name).into()
481            }
482
483            fn array_type() -> sea_query::ArrayType {
484                sea_query::ArrayType::Int
485            }
486
487            fn column_type() -> sea_query::ColumnType {
488                sea_query::ColumnType::Integer(None)
489            }
490        }
491
492        impl sea_orm::TryFromU64 for $name {
493            fn try_from_u64(n: u64) -> Result<Self, DbErr> {
494                Ok(Self(n.try_into().map_err(|_| {
495                    DbErr::ConvertFromU64(concat!(
496                        "error converting ",
497                        stringify!($name),
498                        " to u64"
499                    ))
500                })?))
501            }
502        }
503
504        impl sea_query::Nullable for $name {
505            fn null() -> Value {
506                Value::Int(None)
507            }
508        }
509    };
510}
511
512id_type!(UserId);
513id_type!(RoomId);
514id_type!(RoomParticipantId);
515id_type!(ProjectId);
516id_type!(ProjectCollaboratorId);
517id_type!(WorktreeId);
518
519#[cfg(test)]
520pub use test::*;
521
522#[cfg(test)]
523mod test {
524    use super::*;
525    use gpui::executor::Background;
526    use lazy_static::lazy_static;
527    use parking_lot::Mutex;
528    use rand::prelude::*;
529    use sea_orm::ConnectionTrait;
530    use sqlx::migrate::MigrateDatabase;
531    use std::sync::Arc;
532
533    pub struct TestDb {
534        pub db: Option<Arc<Database>>,
535        pub connection: Option<sqlx::AnyConnection>,
536    }
537
538    impl TestDb {
539        pub fn sqlite(background: Arc<Background>) -> Self {
540            let url = format!("sqlite::memory:");
541            let runtime = tokio::runtime::Builder::new_current_thread()
542                .enable_io()
543                .enable_time()
544                .build()
545                .unwrap();
546
547            let mut db = runtime.block_on(async {
548                let db = Database::new(&url, 5).await.unwrap();
549                let sql = include_str!(concat!(
550                    env!("CARGO_MANIFEST_DIR"),
551                    "/migrations.sqlite/20221109000000_test_schema.sql"
552                ));
553                db.pool
554                    .execute(sea_orm::Statement::from_string(
555                        db.pool.get_database_backend(),
556                        sql.into(),
557                    ))
558                    .await
559                    .unwrap();
560                db
561            });
562
563            db.background = Some(background);
564            db.runtime = Some(runtime);
565
566            Self {
567                db: Some(Arc::new(db)),
568                connection: None,
569            }
570        }
571
572        pub fn postgres(background: Arc<Background>) -> Self {
573            lazy_static! {
574                static ref LOCK: Mutex<()> = Mutex::new(());
575            }
576
577            let _guard = LOCK.lock();
578            let mut rng = StdRng::from_entropy();
579            let url = format!(
580                "postgres://postgres@localhost/zed-test-{}",
581                rng.gen::<u128>()
582            );
583            let runtime = tokio::runtime::Builder::new_current_thread()
584                .enable_io()
585                .enable_time()
586                .build()
587                .unwrap();
588
589            let mut db = runtime.block_on(async {
590                sqlx::Postgres::create_database(&url)
591                    .await
592                    .expect("failed to create test db");
593                let db = Database::new(&url, 5).await.unwrap();
594                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
595                db.migrate(Path::new(migrations_path), false).await.unwrap();
596                db
597            });
598
599            db.background = Some(background);
600            db.runtime = Some(runtime);
601
602            Self {
603                db: Some(Arc::new(db)),
604                connection: None,
605            }
606        }
607
608        pub fn db(&self) -> &Arc<Database> {
609            self.db.as_ref().unwrap()
610        }
611    }
612
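    // TODO: Implement `Drop` for the test database. The sketch below refers to `PostgresTestDb`, which is not defined in this module.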
614    // impl Drop for PostgresTestDb {
615    //     fn drop(&mut self) {
616    //         let db = self.db.take().unwrap();
617    //         db.teardown(&self.url);
618    //     }
619    // }
620}