1mod project;
2mod project_collaborator;
3mod room;
4mod room_participant;
5#[cfg(test)]
6mod tests;
7mod user;
8mod worktree;
9
10use crate::{Error, Result};
11use anyhow::anyhow;
12use collections::HashMap;
13use dashmap::DashMap;
14use futures::StreamExt;
15use rpc::{proto, ConnectionId};
16use sea_orm::ActiveValue;
17use sea_orm::{
18 entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr,
19 TransactionTrait,
20};
21use sea_query::OnConflict;
22use serde::{Deserialize, Serialize};
23use sqlx::migrate::{Migrate, Migration, MigrationSource};
24use sqlx::Connection;
25use std::ops::{Deref, DerefMut};
26use std::path::Path;
27use std::time::Duration;
28use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc};
29use tokio::sync::{Mutex, OwnedMutexGuard};
30
31pub use user::Model as User;
32
/// Handle to the collab database: a sea-orm connection pool plus one async mutex per room.
///
/// Room-mutating operations commit through `commit_room_transaction`, which holds the
/// room's mutex while committing and hands it back inside a `RoomGuard`.
///
/// Field order is also drop order; keep it stable.
pub struct Database {
    // Connection URL, kept so `migrate` can open its own sqlx connection.
    url: String,
    // sea-orm connection pool used by `transact`.
    pool: DatabaseConnection,
    // Lazily created per-room locks (`entry(..).or_default()` in `commit_room_transaction`).
    // NOTE(review): no visible code ever removes entries from this map.
    rooms: DashMap<RoomId, Arc<Mutex<()>>>,
    // Test-only executor used by `transact` to inject a simulated random delay.
    #[cfg(test)]
    background: Option<std::sync::Arc<gpui::executor::Background>>,
    // Test-only tokio runtime on which `transact` drives its body via `block_on`.
    #[cfg(test)]
    runtime: Option<tokio::runtime::Runtime>,
}
42
43impl Database {
44 pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
45 let mut options = ConnectOptions::new(url.into());
46 options.max_connections(max_connections);
47 Ok(Self {
48 url: url.into(),
49 pool: sea_orm::Database::connect(options).await?,
50 rooms: DashMap::with_capacity(16384),
51 #[cfg(test)]
52 background: None,
53 #[cfg(test)]
54 runtime: None,
55 })
56 }
57
58 pub async fn migrate(
59 &self,
60 migrations_path: &Path,
61 ignore_checksum_mismatch: bool,
62 ) -> anyhow::Result<(sqlx::AnyConnection, Vec<(Migration, Duration)>)> {
63 let migrations = MigrationSource::resolve(migrations_path)
64 .await
65 .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?;
66
67 let mut connection = sqlx::AnyConnection::connect(&self.url).await?;
68
69 connection.ensure_migrations_table().await?;
70 let applied_migrations: HashMap<_, _> = connection
71 .list_applied_migrations()
72 .await?
73 .into_iter()
74 .map(|m| (m.version, m))
75 .collect();
76
77 let mut new_migrations = Vec::new();
78 for migration in migrations {
79 match applied_migrations.get(&migration.version) {
80 Some(applied_migration) => {
81 if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch
82 {
83 Err(anyhow!(
84 "checksum mismatch for applied migration {}",
85 migration.description
86 ))?;
87 }
88 }
89 None => {
90 let elapsed = connection.apply(&migration).await?;
91 new_migrations.push((migration, elapsed));
92 }
93 }
94 }
95
96 Ok((connection, new_migrations))
97 }
98
99 pub async fn create_user(
100 &self,
101 email_address: &str,
102 admin: bool,
103 params: NewUserParams,
104 ) -> Result<NewUserResult> {
105 self.transact(|tx| async {
106 let user = user::Entity::insert(user::ActiveModel {
107 email_address: ActiveValue::set(Some(email_address.into())),
108 github_login: ActiveValue::set(params.github_login.clone()),
109 github_user_id: ActiveValue::set(Some(params.github_user_id)),
110 admin: ActiveValue::set(admin),
111 metrics_id: ActiveValue::set(Uuid::new_v4()),
112 ..Default::default()
113 })
114 .on_conflict(
115 OnConflict::column(user::Column::GithubLogin)
116 .update_column(user::Column::GithubLogin)
117 .to_owned(),
118 )
119 .exec_with_returning(&tx)
120 .await?;
121
122 tx.commit().await?;
123
124 Ok(NewUserResult {
125 user_id: user.id,
126 metrics_id: user.metrics_id.to_string(),
127 signup_device_id: None,
128 inviting_user_id: None,
129 })
130 })
131 .await
132 }
133
134 pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<user::Model>> {
135 self.transact(|tx| async {
136 let tx = tx;
137 Ok(user::Entity::find()
138 .filter(user::Column::Id.is_in(ids.iter().copied()))
139 .all(&tx)
140 .await?)
141 })
142 .await
143 }
144
145 pub async fn share_project(
146 &self,
147 room_id: RoomId,
148 connection_id: ConnectionId,
149 worktrees: &[proto::WorktreeMetadata],
150 ) -> Result<RoomGuard<(ProjectId, proto::Room)>> {
151 self.transact(|tx| async move {
152 let participant = room_participant::Entity::find()
153 .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0))
154 .one(&tx)
155 .await?
156 .ok_or_else(|| anyhow!("could not find participant"))?;
157 if participant.room_id != room_id {
158 return Err(anyhow!("shared project on unexpected room"))?;
159 }
160
161 let project = project::ActiveModel {
162 room_id: ActiveValue::set(participant.room_id),
163 host_user_id: ActiveValue::set(participant.user_id),
164 host_connection_id: ActiveValue::set(connection_id.0 as i32),
165 ..Default::default()
166 }
167 .insert(&tx)
168 .await?;
169
170 worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel {
171 id: ActiveValue::set(worktree.id as i32),
172 project_id: ActiveValue::set(project.id),
173 abs_path: ActiveValue::set(worktree.abs_path.clone()),
174 root_name: ActiveValue::set(worktree.root_name.clone()),
175 visible: ActiveValue::set(worktree.visible),
176 scan_id: ActiveValue::set(0),
177 is_complete: ActiveValue::set(false),
178 }))
179 .exec(&tx)
180 .await?;
181
182 project_collaborator::ActiveModel {
183 project_id: ActiveValue::set(project.id),
184 connection_id: ActiveValue::set(connection_id.0 as i32),
185 user_id: ActiveValue::set(participant.user_id),
186 replica_id: ActiveValue::set(0),
187 is_host: ActiveValue::set(true),
188 ..Default::default()
189 }
190 .insert(&tx)
191 .await?;
192
193 let room = self.get_room(room_id, &tx).await?;
194 self.commit_room_transaction(room_id, tx, (project.id, room))
195 .await
196 })
197 .await
198 }
199
200 async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result<proto::Room> {
201 let db_room = room::Entity::find_by_id(room_id)
202 .one(tx)
203 .await?
204 .ok_or_else(|| anyhow!("could not find room"))?;
205
206 let mut db_participants = db_room
207 .find_related(room_participant::Entity)
208 .stream(tx)
209 .await?;
210 let mut participants = HashMap::default();
211 let mut pending_participants = Vec::new();
212 while let Some(db_participant) = db_participants.next().await {
213 let db_participant = db_participant?;
214 if let Some(answering_connection_id) = db_participant.answering_connection_id {
215 let location = match (
216 db_participant.location_kind,
217 db_participant.location_project_id,
218 ) {
219 (Some(0), Some(project_id)) => {
220 Some(proto::participant_location::Variant::SharedProject(
221 proto::participant_location::SharedProject {
222 id: project_id.to_proto(),
223 },
224 ))
225 }
226 (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
227 Default::default(),
228 )),
229 _ => Some(proto::participant_location::Variant::External(
230 Default::default(),
231 )),
232 };
233 participants.insert(
234 answering_connection_id,
235 proto::Participant {
236 user_id: db_participant.user_id.to_proto(),
237 peer_id: answering_connection_id as u32,
238 projects: Default::default(),
239 location: Some(proto::ParticipantLocation { variant: location }),
240 },
241 );
242 } else {
243 pending_participants.push(proto::PendingParticipant {
244 user_id: db_participant.user_id.to_proto(),
245 calling_user_id: db_participant.calling_user_id.to_proto(),
246 initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()),
247 });
248 }
249 }
250
251 let mut db_projects = db_room
252 .find_related(project::Entity)
253 .find_with_related(worktree::Entity)
254 .stream(tx)
255 .await?;
256
257 while let Some(row) = db_projects.next().await {
258 let (db_project, db_worktree) = row?;
259 if let Some(participant) = participants.get_mut(&db_project.host_connection_id) {
260 let project = if let Some(project) = participant
261 .projects
262 .iter_mut()
263 .find(|project| project.id == db_project.id.to_proto())
264 {
265 project
266 } else {
267 participant.projects.push(proto::ParticipantProject {
268 id: db_project.id.to_proto(),
269 worktree_root_names: Default::default(),
270 });
271 participant.projects.last_mut().unwrap()
272 };
273
274 if let Some(db_worktree) = db_worktree {
275 project.worktree_root_names.push(db_worktree.root_name);
276 }
277 }
278 }
279
280 Ok(proto::Room {
281 id: db_room.id.to_proto(),
282 live_kit_room: db_room.live_kit_room,
283 participants: participants.into_values().collect(),
284 pending_participants,
285 })
286 }
287
288 async fn commit_room_transaction<T>(
289 &self,
290 room_id: RoomId,
291 tx: DatabaseTransaction,
292 data: T,
293 ) -> Result<RoomGuard<T>> {
294 let lock = self.rooms.entry(room_id).or_default().clone();
295 let _guard = lock.lock_owned().await;
296 tx.commit().await?;
297 Ok(RoomGuard {
298 data,
299 _guard,
300 _not_send: PhantomData,
301 })
302 }
303
304 async fn transact<F, Fut, T>(&self, f: F) -> Result<T>
305 where
306 F: Send + Fn(DatabaseTransaction) -> Fut,
307 Fut: Send + Future<Output = Result<T>>,
308 {
309 let body = async {
310 loop {
311 let tx = self.pool.begin().await?;
312 match f(tx).await {
313 Ok(result) => return Ok(result),
314 Err(error) => match error {
315 Error::Database2(
316 DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
317 | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
318 ) if error
319 .as_database_error()
320 .and_then(|error| error.code())
321 .as_deref()
322 == Some("40001") =>
323 {
324 // Retry (don't break the loop)
325 }
326 error @ _ => return Err(error),
327 },
328 }
329 }
330 };
331
332 #[cfg(test)]
333 {
334 if let Some(background) = self.background.as_ref() {
335 background.simulate_random_delay().await;
336 }
337
338 self.runtime.as_ref().unwrap().block_on(body)
339 }
340
341 #[cfg(not(test))]
342 {
343 body.await
344 }
345 }
346}
347
/// Result of a room transaction, returned while the room's lock is still held.
///
/// Produced by `Database::commit_room_transaction`. The room mutex stays locked until the
/// guard is dropped, and the guard derefs to `data`.
pub struct RoomGuard<T> {
    // Declared first, so it is dropped before the lock is released.
    data: T,
    // Owned guard on the room's mutex; releases the lock when dropped.
    _guard: OwnedMutexGuard<()>,
    // `Rc<()>` is `!Send`, which makes the guard `!Send`.
    _not_send: PhantomData<Rc<()>>,
}
353
354impl<T> Deref for RoomGuard<T> {
355 type Target = T;
356
357 fn deref(&self) -> &T {
358 &self.data
359 }
360}
361
362impl<T> DerefMut for RoomGuard<T> {
363 fn deref_mut(&mut self) -> &mut T {
364 &mut self.data
365 }
366}
367
/// Parameters for creating a user (see `Database::create_user`).
///
/// Field order is the serde serialization order; keep it stable.
#[derive(Debug, Serialize, Deserialize)]
pub struct NewUserParams {
    // Stored as the user's `github_login`; conflicts on this column reuse the existing user.
    pub github_login: String,
    // Stored as the user's `github_user_id`.
    pub github_user_id: i32,
    // NOTE(review): not read by `create_user` in this file; presumably consumed elsewhere.
    pub invite_count: i32,
}
374
/// Outcome of creating a user.
///
/// Field order is the derived `Debug` output order; keep it stable.
#[derive(Debug)]
pub struct NewUserResult {
    // Id of the inserted (or pre-existing, on GitHub-login conflict) user.
    pub user_id: UserId,
    // The user's metrics UUID rendered as a string.
    pub metrics_id: String,
    // Always `None` when produced by `Database::create_user`.
    pub inviting_user_id: Option<UserId>,
    // Always `None` when produced by `Database::create_user`.
    pub signup_device_id: Option<String>,
}
382
383fn random_invite_code() -> String {
384 nanoid::nanoid!(16)
385}
386
387fn random_email_confirmation_code() -> String {
388 nanoid::nanoid!(64)
389}
390
/// Declares `pub struct $name(pub i32)`, an `i32`-backed id newtype.
///
/// The generated type implements what is needed to use it with sqlx and serde
/// (transparently), with sea-orm/sea-query values and columns, and with the `u64`
/// protobuf representation.
macro_rules! id_type {
    ($name:ident) => {
        #[derive(
            Clone,
            Copy,
            Debug,
            Default,
            PartialEq,
            Eq,
            PartialOrd,
            Ord,
            Hash,
            sqlx::Type,
            Serialize,
            Deserialize,
        )]
        #[sqlx(transparent)]
        #[serde(transparent)]
        pub struct $name(pub i32);

        // `#[allow(unused)]`: not every id type uses every helper.
        impl $name {
            #[allow(unused)]
            pub const MAX: Self = Self(i32::MAX);

            /// Converts from the protobuf representation.
            /// NOTE(review): `as` wraps values above `i32::MAX` instead of rejecting them.
            #[allow(unused)]
            pub fn from_proto(value: u64) -> Self {
                Self(value as i32)
            }

            /// Converts to the protobuf representation (negative ids sign-extend).
            #[allow(unused)]
            pub fn to_proto(self) -> u64 {
                self.0 as u64
            }
        }

        impl std::fmt::Display for $name {
            fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                self.0.fmt(f)
            }
        }

        impl From<$name> for sea_query::Value {
            fn from(value: $name) -> Self {
                sea_query::Value::Int(Some(value.0))
            }
        }

        impl sea_orm::TryGetable for $name {
            fn try_get(
                res: &sea_orm::QueryResult,
                pre: &str,
                col: &str,
            ) -> Result<Self, sea_orm::TryGetError> {
                Ok(Self(i32::try_get(res, pre, col)?))
            }
        }

        // Accepts any non-null integer value that fits in an `i32`.
        impl sea_query::ValueType for $name {
            fn try_from(v: Value) -> Result<Self, sea_query::ValueTypeErr> {
                match v {
                    Value::TinyInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Int(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigInt(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::TinyUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::SmallUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::Unsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    Value::BigUnsigned(Some(int)) => {
                        Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?))
                    }
                    _ => Err(sea_query::ValueTypeErr),
                }
            }

            fn type_name() -> String {
                stringify!($name).into()
            }

            fn array_type() -> sea_query::ArrayType {
                sea_query::ArrayType::Int
            }

            fn column_type() -> sea_query::ColumnType {
                sea_query::ColumnType::Integer(None)
            }
        }

        // Fails when the `u64` does not fit in an `i32`.
        impl sea_orm::TryFromU64 for $name {
            fn try_from_u64(n: u64) -> Result<Self, DbErr> {
                Ok(Self(n.try_into().map_err(|_| {
                    DbErr::ConvertFromU64(concat!(
                        "error converting u64 to ",
                        stringify!($name)
                    ))
                })?))
            }
        }

        impl sea_query::Nullable for $name {
            fn null() -> Value {
                Value::Int(None)
            }
        }
    };
}
511
512id_type!(UserId);
513id_type!(RoomId);
514id_type!(RoomParticipantId);
515id_type!(ProjectId);
516id_type!(ProjectCollaboratorId);
517id_type!(WorktreeId);
518
519#[cfg(test)]
520pub use test::*;
521
#[cfg(test)]
mod test {
    use super::*;
    use gpui::executor::Background;
    use lazy_static::lazy_static;
    use parking_lot::Mutex;
    use rand::prelude::*;
    use sea_orm::ConnectionTrait;
    use sqlx::migrate::MigrateDatabase;
    use std::sync::Arc;

    /// A `Database` configured for tests: it has a deterministic background executor and
    /// a dedicated tokio runtime.
    pub struct TestDb {
        pub db: Option<Arc<Database>>,
        // Not populated by either constructor below.
        pub connection: Option<sqlx::AnyConnection>,
    }

    impl TestDb {
        /// Creates an in-memory SQLite database (pool of up to 5 connections) and applies
        /// the test schema from `migrations.sqlite/20221109000000_test_schema.sql`.
        pub fn sqlite(background: Arc<Background>) -> Self {
            let url = "sqlite::memory:";
            let runtime = build_test_runtime();

            let db = runtime.block_on(async {
                let db = Database::new(url, 5).await.unwrap();
                let sql = include_str!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/migrations.sqlite/20221109000000_test_schema.sql"
                ));
                db.pool
                    .execute(sea_orm::Statement::from_string(
                        db.pool.get_database_backend(),
                        sql.into(),
                    ))
                    .await
                    .unwrap();
                db
            });

            Self::from_database(db, background, runtime)
        }

        /// Creates a uniquely named `zed-test-<random>` Postgres database on localhost and
        /// runs the migrations in `migrations` against it.
        ///
        /// A process-wide lock serializes database creation across tests.
        pub fn postgres(background: Arc<Background>) -> Self {
            lazy_static! {
                static ref LOCK: Mutex<()> = Mutex::new(());
            }

            let _guard = LOCK.lock();
            let mut rng = StdRng::from_entropy();
            let url = format!(
                "postgres://postgres@localhost/zed-test-{}",
                rng.gen::<u128>()
            );
            let runtime = build_test_runtime();

            let db = runtime.block_on(async {
                sqlx::Postgres::create_database(&url)
                    .await
                    .expect("failed to create test db");
                let db = Database::new(&url, 5).await.unwrap();
                let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
                db.migrate(Path::new(migrations_path), false).await.unwrap();
                db
            });

            Self::from_database(db, background, runtime)
        }

        /// Returns the wrapped database.
        ///
        /// # Panics
        ///
        /// Panics if `db` has been taken out of this `TestDb`.
        pub fn db(&self) -> &Arc<Database> {
            self.db.as_ref().unwrap()
        }

        /// Installs the test-only executor and runtime that `Database::transact` uses,
        /// then wraps `db`.
        fn from_database(
            mut db: Database,
            background: Arc<Background>,
            runtime: tokio::runtime::Runtime,
        ) -> Self {
            db.background = Some(background);
            db.runtime = Some(runtime);

            Self {
                db: Some(Arc::new(db)),
                connection: None,
            }
        }
    }

    /// Builds the single-threaded tokio runtime (I/O and timers enabled) that drives
    /// test database work.
    fn build_test_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .enable_time()
            .build()
            .unwrap()
    }

    // TODO: Implement drop. Every `TestDb::postgres` currently leaves its `zed-test-*`
    // database behind. The sketch below is stale: there is no `PostgresTestDb`,
    // `TestDb` stores no url, and no `teardown` method is visible.
    // impl Drop for PostgresTestDb {
    //     fn drop(&mut self) {
    //         let db = self.db.take().unwrap();
    //         db.teardown(&self.url);
    //     }
    // }
}