1mod access_token;
2mod project;
3mod project_collaborator;
4mod room;
5mod room_participant;
6#[cfg(test)]
7mod tests;
8mod user;
9mod worktree;
10
11use crate::{Error, Result};
12use anyhow::anyhow;
13use collections::HashMap;
14use dashmap::DashMap;
15use futures::StreamExt;
16use rpc::{proto, ConnectionId};
17use sea_orm::{
18 entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
19 TransactionTrait,
20};
21use sea_orm::{ActiveValue, ConnectionTrait, IntoActiveModel, QueryOrder, QuerySelect};
22use sea_query::{OnConflict, Query};
23use serde::{Deserialize, Serialize};
24use sqlx::migrate::{Migrate, Migration, MigrationSource};
25use sqlx::Connection;
26use std::ops::{Deref, DerefMut};
27use std::path::Path;
28use std::time::Duration;
29use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
30use tokio::sync::{Mutex, OwnedMutexGuard};
31
32pub use user::Model as User;
33
/// Handle to the collab database: a sea-orm connection pool plus one async
/// lock per room, used by `commit_room_transaction` to serialize room commits.
pub struct Database {
    // Retained so `migrate` (and the test teardown) can reconnect to the same URL.
    options: ConnectOptions,
    pool: DatabaseConnection,
    // Per-room mutexes, created lazily by `commit_room_transaction`; the owned
    // guard is handed back to callers inside a `RoomGuard`.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    // Test-only: executor used by `transact` to inject a random delay.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    // Test-only: runtime on which `transact` drives its body via `block_on`.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
43
44impl Database {
45 pub async fn new(options: ConnectOptions) -> Result<Self> {
46 Ok(Self {
47 options: options.clone(),
48 pool: sea_orm::Database::connect(options).await?,
49 rooms: DashMap::with_capacity(16384),
50 #[cfg(test)]
51 background: None,
52 #[cfg(test)]
53 runtime: None,
54 })
55 }
56
57 pub async fn migrate(
58 &self,
59 migrations_path: &Path,
60 ignore_checksum_mismatch: bool,
61 ) -> anyhow::Result<Vec<(Migration, Duration)>> {
62 let migrations = MigrationSource::resolve(migrations_path)
63 .await
64 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
65
66 let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?;
67
68 connection.ensure_migrations_table().await?;
69 let applied_migrations: HashMap<_, _> = connection
70 .list_applied_migrations()
71 .await?
72 .into_iter()
73 .map(|m| (m.version, m))
74 .collect();
75
76 let mut new_migrations = Vec::new();
77 for migration in migrations {
78 match applied_migrations.get(&migration.version) {
79 Some(applied_migration) => {
80 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
81 {
82 Err(anyhow!(
83 "checksum mismatch for applied migration {}",
84 migration.description
85 ))?;
86 }
87 }
88 None => {
89 let elapsed = connection.apply(&migration).await?;
90 new_migrations.push((migration, elapsed));
91 }
92 }
93 }
94
95 Ok(new_migrations)
96 }
97
98 pub async fn create_user(
99 &self,
100 email_address: &str,
101 admin: bool,
102 params: NewUserParams,
103 ) -> Result<NewUserResult> {
104 self.transact(|tx| async {
105 let user = user::Entity::insert(user::ActiveModel {
106 email_address: ActiveValue::set(Some(email_address.into())),
107 github_login: ActiveValue::set(params.github_login.clone()),
108 github_user_id: ActiveValue::set(Some(params.github_user_id)),
109 admin: ActiveValue::set(admin),
110 metrics_id: ActiveValue::set(Uuid::new_v4()),
111 ..Default::default()
112 })
113 .on_conflict(
114 OnConflict::column(user::Column::GithubLogin)
115 .update_column(user::Column::GithubLogin)
116 .to_owned(),
117 )
118 .exec_with_returning(&tx)
119 .await?;
120
121 tx.commit().await?;
122
123 Ok(NewUserResult {
124 user_id: user.id,
125 metrics_id: user.metrics_id.to_string(),
126 signup_device_id: None,
127 inviting_user_id: None,
128 })
129 })
130 .await
131 }
132
133 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
134 self.transact(|tx| async {
135 let tx = tx;
136 Ok(user::Entity::find()
137 .filter(user::Column::Id.is_in(ids.iter().copied()))
138 .all(&tx)
139 .await?)
140 })
141 .await
142 }
143
144 pub async fn get_user_by_github_account(
145 &self,
146 github_login: &str,
147 github_user_id: Option<i32>,
148 ) -> Result<Option<User>> {
149 self.transact(|tx| async {
150 let tx = tx;
151 if let Some(github_user_id) = github_user_id {
152 if let Some(user_by_github_user_id) = user::Entity::find()
153 .filter(user::Column::GithubUserId.eq(github_user_id))
154 .one(&tx)
155 .await?
156 {
157 let mut user_by_github_user_id = user_by_github_user_id.into_active_model();
158 user_by_github_user_id.github_login = ActiveValue::set(github_login.into());
159 Ok(Some(user_by_github_user_id.update(&tx).await?))
160 } else if let Some(user_by_github_login) = user::Entity::find()
161 .filter(user::Column::GithubLogin.eq(github_login))
162 .one(&tx)
163 .await?
164 {
165 let mut user_by_github_login = user_by_github_login.into_active_model();
166 user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id));
167 Ok(Some(user_by_github_login.update(&tx).await?))
168 } else {
169 Ok(None)
170 }
171 } else {
172 Ok(user::Entity::find()
173 .filter(user::Column::GithubLogin.eq(github_login))
174 .one(&tx)
175 .await?)
176 }
177 })
178 .await
179 }
180
181 pub async fn share_project(
182 &self,
183 room_id: RoomId,
184 connection_id: ConnectionId,
185 worktrees: &[proto::WorktreeMetadata],
186 ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
187 self.transact(|tx| async move {
188 let participant = room_participant::Entity::find()
189 .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
190 .one(&tx)
191 .await?
192 .ok_or_else(|| anyhow!("could not find participant"))?;
193 if participant.room_id != room_id {
194 return Err(anyhow!("shared project on unexpected room"))?;
195 }
196
197 let project = project::ActiveModel {
198 room_id: ActiveValue::set(participant.room_id),
199 host_user_id: ActiveValue::set(participant.user_id),
200 host_connection_id: ActiveValue::set(connection_id.0 as i32),
201 ..Default::default()
202 }
203 .insert(&tx)
204 .await?;
205
206 worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel {
207 id: ActiveValue::set(worktree.id as i32),
208 project_id: ActiveValue::set(project.id),
209 abs_path: ActiveValue::set(worktree.abs_path.clone()),
210 root_name: ActiveValue::set(worktree.root_name.clone()),
211 visible: ActiveValue::set(worktree.visible),
212 scan_id: ActiveValue::set(0),
213 is_complete: ActiveValue::set(false),
214 }))
215 .exec(&tx)
216 .await?;
217
218 project_collaborator::ActiveModel {
219 project_id: ActiveValue::set(project.id),
220 connection_id: ActiveValue::set(connection_id.0 as i32),
221 user_id: ActiveValue::set(participant.user_id),
222 replica_id: ActiveValue::set(0),
223 is_host: ActiveValue::set(true),
224 ..Default::default()
225 }
226 .insert(&tx)
227 .await?;
228
229 let room = self.get_room(room_id, &tx).await?;
230 self.commit_room_transaction(room_id, tx, (project.id, room))
231 .await
232 })
233 .await
234 }
235
236 async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
237 let db_room = room::Entity::find_by_id(room_id)
238 .one(tx)
239 .await?
240 .ok_or_else(|| anyhow!("could not find room"))?;
241
242 let mut db_participants = db_room
243 .find_related(room_participant::Entity)
244 .stream(tx)
245 .await?;
246 let mut participants = HashMap::default();
247 let mut pending_participants = Vec::new();
248 while let Some(db_participant) = db_participants.next().await {
249 let db_participant = db_participant?;
250 if let Some(answering_connection_id) = db_participant.answering_connection_id {
251 let location = match (
252 db_participant.location_kind,
253 db_participant.location_project_id,
254 ) {
255 (Some(0), Some(project_id)) => {
256 Some(proto::participant_location::Variant::SharedProject(
257 proto::participant_location::SharedProject {
258 id: project_id.to_proto(),
259 },
260 ))
261 }
262 (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
263 Default::default(),
264 )),
265 _ => Some(proto::participant_location::Variant::External(
266 Default::default(),
267 )),
268 };
269 participants.insert(
270 answering_connection_id,
271 proto::Participant {
272 user_id: db_participant.user_id.to_proto(),
273 peer_id: answering_connection_id as u32,
274 projects: Default::default(),
275 location: Some(proto::ParticipantLocation { variant: location }),
276 },
277 );
278 } else {
279 pending_participants.push(proto::PendingParticipant {
280 user_id: db_participant.user_id.to_proto(),
281 calling_user_id: db_participant.calling_user_id.to_proto(),
282 initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()),
283 });
284 }
285 }
286
287 let mut db_projects = db_room
288 .find_related(project::Entity)
289 .find_with_related(worktree::Entity)
290 .stream(tx)
291 .await?;
292
293 while let Some(row) = db_projects.next().await {
294 let (db_project, db_worktree) = row?;
295 if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
296 let project = if let Some(project) = participant
297 .projects
298 .iter_mut()
299 .find(|project| project.id == db_project.id.to_proto())
300 {
301 project
302 } else {
303 participant.projects.push(proto::ParticipantProject {
304 id: db_project.id.to_proto(),
305 worktree_root_names: Default::default(),
306 });
307 participant.projects.last_mut().unwrap()
308 };
309
310 if let Some(db_worktree) = db_worktree {
311 project.worktree_root_names.push(db_worktree.root_name);
312 }
313 }
314 }
315
316 Ok(proto::Room {
317 id: db_room.id.to_proto(),
318 live_kit_room: db_room.live_kit_room,
319 participants: participants.into_values().collect(),
320 pending_participants,
321 })
322 }
323
324 async fn commit_room_transaction<T>(
325 &self,
326 room_id: RoomId,
327 tx: DatabaseTransaction,
328 data: T,
329 ) -> Result<RoomGuard<T>> {
330 let lock = self.rooms.entry(room_id).or_default().clone();
331 let _guard = lock.lock_owned().await;
332 tx.commit().await?;
333 Ok(RoomGuard {
334 data,
335 _guard,
336 _not_send: PhantomData,
337 })
338 }
339
340 pub async fn create_access_token_hash(
341 &self,
342 user_id: UserId,
343 access_token_hash: &str,
344 max_access_token_count: usize,
345 ) -> Result<()> {
346 self.transact(|tx| async {
347 let tx = tx;
348
349 access_token::ActiveModel {
350 user_id: ActiveValue::set(user_id),
351 hash: ActiveValue::set(access_token_hash.into()),
352 ..Default::default()
353 }
354 .insert(&tx)
355 .await?;
356
357 access_token::Entity::delete_many()
358 .filter(
359 access_token::Column::Id.in_subquery(
360 Query::select()
361 .column(access_token::Column::Id)
362 .from(access_token::Entity)
363 .and_where(access_token::Column::UserId.eq(user_id))
364 .order_by(access_token::Column::Id, sea_orm::Order::Desc)
365 .limit(10000)
366 .offset(max_access_token_count as u64)
367 .to_owned(),
368 ),
369 )
370 .exec(&tx)
371 .await?;
372 tx.commit().await?;
373 Ok(())
374 })
375 .await
376 }
377
378 pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
379 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
380 enum QueryAs {
381 Hash,
382 }
383
384 self.transact(|tx| async move {
385 Ok(access_token::Entity::find()
386 .select_only()
387 .column(access_token::Column::Hash)
388 .filter(access_token::Column::UserId.eq(user_id))
389 .order_by_desc(access_token::Column::Id)
390 .into_values::<_, QueryAs>()
391 .all(&tx)
392 .await?)
393 })
394 .await
395 }
396
397 async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
398 where
399 F: Send + Fn(DatabaseTransaction) -> Fut,
400 Fut: Send + Future<Output = Result<T>>,
401 {
402 let body = async {
403 loop {
404 let tx = self.pool.begin().await?;
405
406 // In Postgres, serializable transactions are opt-in
407 if let sea_orm::DatabaseBackend::Postgres = self.pool.get_database_backend() {
408 tx.execute(sea_orm::Statement::from_string(
409 sea_orm::DatabaseBackend::Postgres,
410 "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(),
411 ))
412 .await?;
413 }
414
415 match f(tx).await {
416 Ok(result) => return Ok(result),
417 Err(error) => match error {
418 Error::Database2(
419 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
420 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
421 ) if error
422 .as_database_error()
423 .and_then(|error| error.code())
424 .as_deref()
425 == Some("40001") =>
426 {
427 // Retry (don't break the loop)
428 }
429 error @ _ => return Err(error),
430 },
431 }
432 }
433 };
434
435 #[cfg(test)]
436 {
437 if let Some(background) = self.background.as_ref() {
438 background.simulate_random_delay().await;
439 }
440
441 self.runtime.as_ref().unwrap().block_on(body)
442 }
443
444 #[cfg(not(test))]
445 {
446 body.await
447 }
448 }
449}
450
/// A value produced by a room transaction, bundled with that room's lock.
///
/// The room's mutex stays locked for as long as the guard is alive. The
/// wrapped value is reachable through the `Deref`/`DerefMut` impls.
pub struct RoomGuard<T> {
    data: T,
    // Owned guard on the room's mutex from `Database::rooms`; unlocks on drop.
    // Struct fields drop in declaration order, so `data` is dropped before the
    // lock is released. Keep this field after `data`.
    _guard: OwnedMutexGuard<()>,
    // `Rc<()>` is `!Send`, which makes the whole guard `!Send` as well.
    _not_send: PhantomData<Rc<()>>,
}
456
457impl<T> Deref for RoomGuard<T> {
458 type Target = T;
459
460 fn deref(&self) -> &T {
461 &self.data
462 }
463}
464
465impl<T> DerefMut for RoomGuard<T> {
466 fn deref_mut(&mut self) -> &mut T {
467 &mut self.data
468 }
469}
470
/// Caller-supplied details for `Database::create_user`.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    pub github_login: String,
    pub github_user_id: i32,
    // NOTE(review): not read by `create_user` in this file; presumably used elsewhere.
    pub invite_count: i32,
}
477
/// Outcome of `Database::create_user`.
#[derive(Debug)]
pub struct NewUserResult {
    pub user_id: UserId,
    // String form of the user's metrics UUID.
    pub metrics_id: String,
    // Always `None` when produced by `create_user` in this file.
    pub inviting_user_id: Option<UserId>,
    // Always `None` when produced by `create_user` in this file.
    pub signup_device_id: Option<String>,
}
485
486fn random_invite_code() -> String {
487 nanoid::nanoid!(16)
488}
489
490fn random_email_confirmation_code() -> String {
491 nanoid::nanoid!(64)
492}
493
/// Declares `$name` as a `Copy` newtype around `i32` for use as a database id.
///
/// The generated type gets:
/// - the sqlx, serde, sea-query and sea-orm conversions it needs to appear in
///   entities and queries;
/// - `from_proto`/`to_proto` helpers for the `u64` ids used in protobuf messages.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        impl $name {
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            // NOTE(review): `as` silently truncates values above `i32::MAX`.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            // NOTE(review): a negative id would wrap to a huge `u64` here.
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }

        impl From<$name> for sea_query::Value {
            fn from(value: $name) -> Self {
                sea_query::Value::Int(Some(value.0))
            }
        }

        impl sea_orm::TryGetable for $name {
            fn try_get(
                res: &sea_orm::QueryResult,
                pre: &str,
                col: &str,
            ) -> Result<Self, sea_orm::TryGetError> {
                Ok(Self(i32::try_get(res, pre, col)?))
            }
        }

        impl sea_query::ValueType for $name {
            // Accepts any non-null integer value that fits in an `i32`.
            fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
                match v {
                    Value::TinyInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Int(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::TinyUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Unsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    _ => Err(sea_query::ValueTypeErr),
                }
            }

            fn type_name() -> String {
                stringify!($name).into()
            }

            fn array_type() -> sea_query::ArrayType {
                sea_query::ArrayType::Int
            }

            fn column_type() -> sea_query::ColumnType {
                sea_query::ColumnType::Integer(None)
            }
        }

        impl sea_orm::TryFromU64 for $name {
            // Fails when `n` does not fit in an `i32`.
            fn try_from_u64(n: u64) -> Result<Self, DbErr> {
                Ok(Self(n.try_into().map_err(|_| {
                    DbErr::ConvertFromU64(concat!(
                        "error converting u64 to ",
                        stringify!($name)
                    ))
                })?))
            }
        }

        impl sea_query::Nullable for $name {
            fn null() -> Value {
                Value::Int(None)
            }
        }
    };
}
614
615id_type!(AccessTokenId);
616id_type!(UserId);
617id_type!(RoomId);
618id_type!(RoomParticipantId);
619id_type!(ProjectId);
620id_type!(ProjectCollaboratorId);
621id_type!(WorktreeId);
622
623#[cfg(test)]
624pub use test::*;
625
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sea_orm::ConnectionTrait;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A database created for a single test; Postgres databases are dropped
    /// again when this value is dropped.
    pub struct TestDb {
        pub db: Option<Arc<Database>>,
        // NOTE(review): never populated in this module.
        pub connection: Option<sqlx::AnyConnection>,
    }

    /// Builds the current-thread tokio runtime (I/O and timers enabled) that a
    /// test `Database` uses to drive its queries.
    fn build_test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    impl TestDb {
        /// Creates an in-memory SQLite database initialised from the
        /// checked-in SQLite test schema.
        pub fn sqlite(background: Arc<Background>) -> Self {
            let url = "sqlite::memory:".to_string();
            let runtime = build_test_runtime();

            let db = runtime.block_on(async {
                let mut options = ConnectOptions::new(url);
                options.max_connections(5);
                let db = Database::new(options).await.unwrap();
                let sql = include_str!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/migrations.sqlite/20221109000000_test_schema.sql"
                ));
                db.pool
                    .execute(sea_orm::Statement::from_string(
                        db.pool.get_database_backend(),
                        sql.into(),
                    ))
                    .await
                    .unwrap();
                db
            });

            Self::from_database(db, background, runtime)
        }

        /// Creates a randomly named Postgres database on localhost and runs the
        /// regular migrations against it.
        pub fn postgres(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            // Process-wide lock, held for the rest of this function, so test
            // databases are created one at a time.
            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = build_test_runtime();

            let db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let mut options = ConnectOptions::new(url);
                options
                    .max_connections(5)
                    .idle_timeout(Duration::from_secs(0));
                let db = Database::new(options).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            Self::from_database(db, background, runtime)
        }

        /// Installs the test-only executor and runtime into `db` and wraps it.
        fn from_database(
            mut db: Database,
            background: Arc<Background>,
            runtime: tokio::runtime::Runtime,
        ) -> Self {
            db.background = Some(background);
            db.runtime = Some(runtime);
            Self {
                db: Some(Arc::new(db)),
                connection: None,
            }
        }

        /// Returns the wrapped database.
        pub fn db(&self) -> &Arc<Database> {
            self.db.as_ref().unwrap()
        }
    }

    impl Drop for TestDb {
        /// For Postgres, terminates every other session on the test database
        /// and then drops it. Failures are logged, not propagated.
        fn drop(&mut self) {
            let db = self.db.take().unwrap();
            if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() {
                db.runtime.as_ref().unwrap().block_on(async {
                    use util::ResultExt;
                    let query = "
                        SELECT pg_terminate_backend(pg_stat_activity.pid)
                        FROM pg_stat_activity
                        WHERE
                            pg_stat_activity.datname = current_database() AND
                            pid <> pg_backend_pid();
                    ";
                    db.pool
                        .execute(sea_orm::Statement::from_string(
                            db.pool.get_database_backend(),
                            query.into(),
                        ))
                        .await
                        .log_err();
                    sqlx::Postgres::drop_database(db.options.get_url())
                        .await
                        .log_err();
                })
            }
        }
    }
}