diff --git a/Cargo.lock b/Cargo.lock index d3cc50f9912164bea93ff818f12ab68b7e27b914..9a4d3a78d15e7b4ec2baf268ec43d90c9884fb41 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "Inflector" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3" + [[package]] name = "activity_indicator" version = "0.1.0" @@ -107,6 +113,12 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "aliasable" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" + [[package]] name = "ambient-authority" version = "0.0.1" @@ -562,6 +574,19 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "bae" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33b8de67cc41132507eeece2584804efcb15f85ba516e34c944b7667f480397a" +dependencies = [ + "heck 0.3.3", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "base64" version = "0.13.1" @@ -650,6 +675,51 @@ dependencies = [ "futures-lite", ] +[[package]] +name = "borsh" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15bf3650200d8bffa99015595e10f1fbd17de07abbc25bb067da79e769939bfa" +dependencies = [ + "borsh-derive", + "hashbrown 0.11.2", +] + +[[package]] +name = "borsh-derive" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6441c552f230375d18e3cc377677914d2ca2b0d36e52129fe15450a2dce46775" +dependencies = [ + "borsh-derive-internal", + "borsh-schema-derive-internal", + "proc-macro-crate", + "proc-macro2", + "syn", +] + +[[package]] +name = "borsh-derive-internal" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5449c28a7b352f2d1e592a8a28bf139bc71afb0764a14f3c02500935d8c44065" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "borsh-schema-derive-internal" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdbd5696d8bfa21d53d9fe39a714a18538bad11492a42d066dbbc395fb1951c0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "breadcrumbs" version = "0.1.0" @@ -693,6 +763,27 @@ version = "3.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "572f695136211188308f16ad2ca5c851a712c464060ae6974944458eb83880ba" +[[package]] +name = "bytecheck" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d11cac2c12b5adc6570dad2ee1b87eff4955dac476fe12d81e5fdd352e52406f" +dependencies = [ + "bytecheck_derive", + "ptr_meta", +] + +[[package]] +name = "bytecheck_derive" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13e576ebe98e605500b3c8041bb888e966653577172df6dd97398714eb30b9bf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "bytemuck" version = "1.12.3" @@ -850,6 +941,7 @@ dependencies = [ "js-sys", "num-integer", "num-traits", + "serde", "time 0.1.45", "wasm-bindgen", "winapi 0.3.9", @@ -1041,7 +1133,6 @@ name = "collab" version = "0.2.5" dependencies = [ "anyhow", - "async-trait", "async-tungstenite", "axum", "axum-extra", @@ -1051,6 +1142,7 @@ dependencies = [ "client", "collections", "ctor", + "dashmap", 
"editor", "env_logger", "envy", @@ -1074,6 +1166,8 @@ dependencies = [ "reqwest", "rpc", "scrypt", + "sea-orm", + "sea-query", "serde", "serde_json", "settings", @@ -1546,6 +1640,19 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" +dependencies = [ + "cfg-if 1.0.0", + "hashbrown 0.12.3", + "lock_api", + "once_cell", + "parking_lot_core 0.9.5", +] + [[package]] name = "data-url" version = "0.1.1" @@ -3107,9 +3214,9 @@ dependencies = [ [[package]] name = "libsqlite3-sys" -version = "0.25.2" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29f835d03d717946d28b1d1ed632eb6f0e24a299388ee623d0c23118d3e8a7fa" +checksum = "898745e570c7d0453cc1fbc4a701eb6c662ed54e8fec8b7d14be137ebeeb9d14" dependencies = [ "cc", "pkg-config", @@ -3858,6 +3965,29 @@ version = "6.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" +[[package]] +name = "ouroboros" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbb50b356159620db6ac971c6d5c9ab788c9cc38a6f49619fca2a27acb062ca" +dependencies = [ + "aliasable", + "ouroboros_macro", +] + +[[package]] +name = "ouroboros_macro" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a0d9d1a6191c4f391f87219d1ea42b23f09ee84d64763cd05ee6ea88d9f384d" +dependencies = [ + "Inflector", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "outline" version = "0.1.0" @@ -4216,6 +4346,15 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "proc-macro-crate" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d6ea3c4595b96363c13943497db34af4460fb474a95c43f4446ad341b8c9785" +dependencies = [ + "toml", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -4461,6 +4600,26 @@ dependencies = [ "cc", ] +[[package]] +name = "ptr_meta" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0738ccf7ea06b608c10564b31debd4f5bc5e197fc8bfe088f68ae5ce81e7a4f1" +dependencies = [ + "ptr_meta_derive", +] + +[[package]] +name = "ptr_meta_derive" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16b845dbfca988fa33db069c0e230574d15a3088f147a87b64c7589eb662c9ac" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pulldown-cmark" version = "0.9.2" @@ -4697,6 +4856,15 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "rend" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79af64b4b6362ffba04eef3a4e10829718a4896dac19daa741851c86781edf95" +dependencies = [ + "bytecheck", +] + [[package]] name = "reqwest" version = "0.11.13" @@ -4774,6 +4942,31 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "rkyv" +version = "0.7.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cec2b3485b07d96ddfd3134767b8a447b45ea4eb91448d0a35180ec0ffd5ed15" +dependencies = [ + "bytecheck", + "hashbrown 0.12.3", + "ptr_meta", + "rend", + "rkyv_derive", + "seahash", +] + +[[package]] +name = 
"rkyv_derive" +version = "0.7.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6eaedadc88b53e36dd32d940ed21ae4d850d5916f2581526921f553a72ac34c4" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "rmp" version = "0.8.11" @@ -4901,6 +5094,24 @@ dependencies = [ "walkdir", ] +[[package]] +name = "rust_decimal" +version = "1.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33c321ee4e17d2b7abe12b5d20c1231db708dd36185c8a21e9de5fed6da4dbe9" +dependencies = [ + "arrayvec 0.7.2", + "borsh", + "bytecheck", + "byteorder", + "bytes 1.3.0", + "num-traits", + "rand 0.8.5", + "rkyv", + "serde", + "serde_json", +] + [[package]] name = "rustc-demangle" version = "0.1.21" @@ -4972,6 +5183,12 @@ dependencies = [ "base64", ] +[[package]] +name = "rustversion" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97477e48b4cf8603ad5f7aaf897467cf42ab4218a38ef76fb14c2d6773a6d6a8" + [[package]] name = "rustybuzz" version = "0.3.0" @@ -5113,6 +5330,109 @@ dependencies = [ "untrusted", ] +[[package]] +name = "sea-orm" +version = "0.10.5" +source = "git+https://github.com/zed-industries/sea-orm?rev=18f4c691085712ad014a51792af75a9044bacee6#18f4c691085712ad014a51792af75a9044bacee6" +dependencies = [ + "async-stream", + "async-trait", + "chrono", + "futures 0.3.25", + "futures-util", + "log", + "ouroboros", + "rust_decimal", + "sea-orm-macros", + "sea-query", + "sea-query-binder", + "sea-strum", + "serde", + "serde_json", + "sqlx", + "thiserror", + "time 0.3.17", + "tracing", + "url", + "uuid 1.2.2", +] + +[[package]] +name = "sea-orm-macros" +version = "0.10.5" +source = "git+https://github.com/zed-industries/sea-orm?rev=18f4c691085712ad014a51792af75a9044bacee6#18f4c691085712ad014a51792af75a9044bacee6" +dependencies = [ + "bae", + "heck 0.3.3", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "sea-query" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4f0fc4d8e44e1d51c739a68d336252a18bc59553778075d5e32649be6ec92ed" +dependencies = [ + "chrono", + "rust_decimal", + "sea-query-derive", + "serde_json", + "time 0.3.17", + "uuid 1.2.2", +] + +[[package]] +name = "sea-query-binder" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c2585b89c985cfacfe0ec9fc9e7bb055b776c1a2581c4e3c6185af2b8bf8865" +dependencies = [ + "chrono", + "rust_decimal", + "sea-query", + "serde_json", + "sqlx", + "time 0.3.17", + "uuid 1.2.2", +] + +[[package]] +name = "sea-query-derive" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34cdc022b4f606353fe5dc85b09713a04e433323b70163e81513b141c6ae6eb5" +dependencies = [ + "heck 0.3.3", + "proc-macro2", + "quote", + "syn", + "thiserror", +] + +[[package]] +name = "sea-strum" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "391d06a6007842cfe79ac6f7f53911b76dfd69fc9a6769f1cf6569d12ce20e1b" +dependencies = [ + "sea-strum_macros", +] + +[[package]] +name = "sea-strum_macros" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69b4397b825df6ccf1e98bcdabef3bbcfc47ff5853983467850eeab878384f21" +dependencies = [ + "heck 0.3.3", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + [[package]] name = "seahash" version = "4.1.0" @@ -5626,8 +5946,9 @@ dependencies = [ [[package]] name = "sqlx" -version = "0.6.2" 
-source = "git+https://github.com/launchbadge/sqlx?rev=4b7053807c705df312bcb9b6281e184bf7534eb3#4b7053807c705df312bcb9b6281e184bf7534eb3" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "788841def501aabde58d3666fcea11351ec3962e6ea75dbcd05c84a71d68bcd1" dependencies = [ "sqlx-core", "sqlx-macros", @@ -5636,7 +5957,8 @@ dependencies = [ [[package]] name = "sqlx-core" version = "0.6.2" -source = "git+https://github.com/launchbadge/sqlx?rev=4b7053807c705df312bcb9b6281e184bf7534eb3#4b7053807c705df312bcb9b6281e184bf7534eb3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcbc16ddba161afc99e14d1713a453747a2b07fc097d2009f4c300ec99286105" dependencies = [ "ahash", "atoi", @@ -5644,6 +5966,7 @@ dependencies = [ "bitflags", "byteorder", "bytes 1.3.0", + "chrono", "crc", "crossbeam-queue", "dirs 4.0.0", @@ -5667,10 +5990,12 @@ dependencies = [ "log", "md-5", "memchr", + "num-bigint", "once_cell", "paste", "percent-encoding", "rand 0.8.5", + "rust_decimal", "rustls 0.20.7", "rustls-pemfile", "serde", @@ -5693,7 +6018,8 @@ dependencies = [ [[package]] name = "sqlx-macros" version = "0.6.2" -source = "git+https://github.com/launchbadge/sqlx?rev=4b7053807c705df312bcb9b6281e184bf7534eb3#4b7053807c705df312bcb9b6281e184bf7534eb3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b850fa514dc11f2ee85be9d055c512aa866746adfacd1cb42d867d68e6a5b0d9" dependencies = [ "dotenvy", "either", @@ -5712,7 +6038,8 @@ dependencies = [ [[package]] name = "sqlx-rt" version = "0.6.2" -source = "git+https://github.com/launchbadge/sqlx?rev=4b7053807c705df312bcb9b6281e184bf7534eb3#4b7053807c705df312bcb9b6281e184bf7534eb3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24c5b2d25fa654cc5f841750b8e1cdedbe21189bf9a9382ee90bfa9dd3562396" dependencies = [ "once_cell", "tokio", @@ -6870,6 +7197,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "422ee0de9031b5b948b97a8fc04e3aa35230001a722ddd27943e0be31564ce4c" dependencies = [ "getrandom 0.2.8", + "serde", ] [[package]] diff --git a/crates/call/src/call.rs b/crates/call/src/call.rs index 6b72eb61da37398036869ed0a1f554442a3b52ae..803fbb906adc53ac03cb1826a1f139931a83f8e1 100644 --- a/crates/call/src/call.rs +++ b/crates/call/src/call.rs @@ -22,7 +22,7 @@ pub fn init(client: Arc, user_store: ModelHandle, cx: &mut Mu #[derive(Clone)] pub struct IncomingCall { pub room_id: u64, - pub caller: Arc, + pub calling_user: Arc, pub participants: Vec>, pub initial_project: Option, } @@ -78,9 +78,9 @@ impl ActiveCall { user_store.get_users(envelope.payload.participant_user_ids, cx) }) .await?, - caller: user_store + calling_user: user_store .update(&mut cx, |user_store, cx| { - user_store.get_user(envelope.payload.caller_user_id, cx) + user_store.get_user(envelope.payload.calling_user_id, cx) }) .await?, initial_project: envelope.payload.initial_project, @@ -110,13 +110,13 @@ impl ActiveCall { pub fn invite( &mut self, - recipient_user_id: u64, + called_user_id: u64, initial_project: Option>, cx: &mut ModelContext, ) -> Task> { let client = self.client.clone(); let user_store = self.user_store.clone(); - if !self.pending_invites.insert(recipient_user_id) { + if !self.pending_invites.insert(called_user_id) { return Task::ready(Err(anyhow!("user was already invited"))); } @@ -136,13 +136,13 @@ impl ActiveCall { }; room.update(&mut cx, |room, cx| { - room.call(recipient_user_id, initial_project_id, cx) + room.call(called_user_id, 
initial_project_id, cx) }) .await?; } else { let room = cx .update(|cx| { - Room::create(recipient_user_id, initial_project, client, user_store, cx) + Room::create(called_user_id, initial_project, client, user_store, cx) }) .await?; @@ -155,7 +155,7 @@ impl ActiveCall { let result = invite.await; this.update(&mut cx, |this, cx| { - this.pending_invites.remove(&recipient_user_id); + this.pending_invites.remove(&called_user_id); cx.notify(); }); result @@ -164,7 +164,7 @@ impl ActiveCall { pub fn cancel_invite( &mut self, - recipient_user_id: u64, + called_user_id: u64, cx: &mut ModelContext, ) -> Task> { let room_id = if let Some(room) = self.room() { @@ -178,7 +178,7 @@ impl ActiveCall { client .request(proto::CancelCall { room_id, - recipient_user_id, + called_user_id, }) .await?; anyhow::Ok(()) diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index 7d5153950d76d16f3e3185835eed22ac430fda97..f8a55a3a931a9d349cb4c1a38db753d9e92846cd 100644 --- a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -10,7 +10,7 @@ use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext use live_kit_client::{LocalTrackPublication, LocalVideoTrack, RemoteVideoTrackUpdate}; use postage::stream::Stream; use project::Project; -use std::{mem, os::unix::prelude::OsStrExt, sync::Arc}; +use std::{mem, sync::Arc}; use util::{post_inc, ResultExt}; #[derive(Clone, Debug, PartialEq, Eq)] @@ -53,7 +53,7 @@ impl Entity for Room { fn release(&mut self, _: &mut MutableAppContext) { if self.status.is_online() { - self.client.send(proto::LeaveRoom { id: self.id }).log_err(); + self.client.send(proto::LeaveRoom {}).log_err(); } } } @@ -149,7 +149,7 @@ impl Room { } pub(crate) fn create( - recipient_user_id: u64, + called_user_id: u64, initial_project: Option>, client: Arc, user_store: ModelHandle, @@ -182,7 +182,7 @@ impl Room { match room .update(&mut cx, |room, cx| { room.leave_when_empty = true; - room.call(recipient_user_id, initial_project_id, cx) + room.call(called_user_id, initial_project_id, cx) }) .await { @@ -241,7 +241,7 @@ impl Room { self.participant_user_ids.clear(); self.subscriptions.clear(); self.live_kit.take(); - self.client.send(proto::LeaveRoom { id: self.id })?; + self.client.send(proto::LeaveRoom {})?; Ok(()) } @@ -294,6 +294,11 @@ impl Room { .position(|participant| Some(participant.user_id) == self.client.user_id()); let local_participant = local_participant_ix.map(|ix| room.participants.swap_remove(ix)); + let pending_participant_user_ids = room + .pending_participants + .iter() + .map(|p| p.user_id) + .collect::>(); let remote_participant_user_ids = room .participants .iter() @@ -303,7 +308,7 @@ impl Room { self.user_store.update(cx, move |user_store, cx| { ( user_store.get_users(remote_participant_user_ids, cx), - user_store.get_users(room.pending_participant_user_ids, cx), + user_store.get_users(pending_participant_user_ids, cx), ) }); self.pending_room_update = Some(cx.spawn(|this, mut cx| async move { @@ -487,7 +492,7 @@ impl Room { pub(crate) fn call( &mut self, - recipient_user_id: u64, + called_user_id: u64, initial_project_id: Option, cx: &mut ModelContext, ) -> Task> { @@ -503,7 +508,7 @@ impl Room { let result = client .request(proto::Call { room_id, - recipient_user_id, + called_user_id, initial_project_id, }) .await; @@ -538,7 +543,7 @@ impl Room { id: worktree.id().to_proto(), root_name: worktree.root_name().into(), visible: worktree.is_visible(), - abs_path: worktree.abs_path().as_os_str().as_bytes().to_vec(), + abs_path: 
worktree.abs_path().to_string_lossy().into(), } }) .collect(), diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index f9b3a88545e1a611fe611e14a86e55c07a6be371..5e10f9ea8f4374d3e0607be2a0f68fc5259bc431 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -16,8 +16,7 @@ use gpui::{ actions, serde_json::{self, Value}, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AppContext, - AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext, - ViewHandle, + AsyncAppContext, Entity, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle, }; use http::HttpClient; use lazy_static::lazy_static; @@ -32,6 +31,7 @@ use std::{ convert::TryFrom, fmt::Write as _, future::Future, + marker::PhantomData, path::PathBuf, sync::{Arc, Weak}, time::{Duration, Instant}, @@ -171,7 +171,7 @@ struct ClientState { entity_id_extractors: HashMap u64>, _reconnect_task: Option>, reconnect_interval: Duration, - entities_by_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakEntityHandle>, + entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>, models_by_message_type: HashMap, entity_types_by_message_type: HashMap, #[allow(clippy::type_complexity)] @@ -181,7 +181,7 @@ struct ClientState { dyn Send + Sync + Fn( - AnyEntityHandle, + Subscriber, Box, &Arc, AsyncAppContext, @@ -190,12 +190,13 @@ struct ClientState { >, } -enum AnyWeakEntityHandle { +enum WeakSubscriber { Model(AnyWeakModelHandle), View(AnyWeakViewHandle), + Pending(Vec>), } -enum AnyEntityHandle { +enum Subscriber { Model(AnyModelHandle), View(AnyViewHandle), } @@ -253,6 +254,54 @@ impl Drop for Subscription { } } +pub struct PendingEntitySubscription { + client: Arc, + remote_id: u64, + _entity_type: PhantomData, + consumed: bool, +} + +impl PendingEntitySubscription { + pub fn set_model(mut self, model: &ModelHandle, cx: &mut AsyncAppContext) -> Subscription { + self.consumed = true; + let mut state = self.client.state.write(); + let id = (TypeId::of::(), self.remote_id); + let Some(WeakSubscriber::Pending(messages)) = + state.entities_by_type_and_remote_id.remove(&id) + else { + unreachable!() + }; + + state + .entities_by_type_and_remote_id + .insert(id, WeakSubscriber::Model(model.downgrade().into())); + drop(state); + for message in messages { + self.client.handle_message(message, cx); + } + Subscription::Entity { + client: Arc::downgrade(&self.client), + id, + } + } +} + +impl Drop for PendingEntitySubscription { + fn drop(&mut self) { + if !self.consumed { + let mut state = self.client.state.write(); + if let Some(WeakSubscriber::Pending(messages)) = state + .entities_by_type_and_remote_id + .remove(&(TypeId::of::(), self.remote_id)) + { + for message in messages { + log::info!("unhandled message {}", message.payload_type_name()); + } + } + } + } +} + impl Client { pub fn new(http: Arc, cx: &AppContext) -> Arc { Arc::new(Self { @@ -348,7 +397,11 @@ impl Client { let this = self.clone(); let reconnect_interval = state.reconnect_interval; state._reconnect_task = Some(cx.spawn(|cx| async move { + #[cfg(any(test, feature = "test-support"))] + let mut rng = StdRng::seed_from_u64(0); + #[cfg(not(any(test, feature = "test-support")))] let mut rng = StdRng::from_entropy(); + let mut delay = INITIAL_RECONNECTION_DELAY; while let Err(error) = this.authenticate_and_connect(true, &cx).await { log::error!("failed to connect {}", error); @@ -386,26 +439,28 @@ impl Client { self.state .write() .entities_by_type_and_remote_id - 
.insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into())); + .insert(id, WeakSubscriber::View(cx.weak_handle().into())); Subscription::Entity { client: Arc::downgrade(self), id, } } - pub fn add_model_for_remote_entity( + pub fn subscribe_to_entity( self: &Arc, remote_id: u64, - cx: &mut ModelContext, - ) -> Subscription { + ) -> PendingEntitySubscription { let id = (TypeId::of::(), remote_id); self.state .write() .entities_by_type_and_remote_id - .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into())); - Subscription::Entity { - client: Arc::downgrade(self), - id, + .insert(id, WeakSubscriber::Pending(Default::default())); + + PendingEntitySubscription { + client: self.clone(), + remote_id, + consumed: false, + _entity_type: PhantomData, } } @@ -433,7 +488,7 @@ impl Client { let prev_handler = state.message_handlers.insert( message_type_id, Arc::new(move |handle, envelope, client, cx| { - let handle = if let AnyEntityHandle::Model(handle) = handle { + let handle = if let Subscriber::Model(handle) = handle { handle } else { unreachable!(); @@ -487,7 +542,7 @@ impl Client { F: 'static + Future>, { self.add_entity_message_handler::(move |handle, message, client, cx| { - if let AnyEntityHandle::View(handle) = handle { + if let Subscriber::View(handle) = handle { handler(handle.downcast::().unwrap(), message, client, cx) } else { unreachable!(); @@ -506,7 +561,7 @@ impl Client { F: 'static + Future>, { self.add_entity_message_handler::(move |handle, message, client, cx| { - if let AnyEntityHandle::Model(handle) = handle { + if let Subscriber::Model(handle) = handle { handler(handle.downcast::().unwrap(), message, client, cx) } else { unreachable!(); @@ -521,7 +576,7 @@ impl Client { H: 'static + Send + Sync - + Fn(AnyEntityHandle, TypedEnvelope, Arc, AsyncAppContext) -> F, + + Fn(Subscriber, TypedEnvelope, Arc, AsyncAppContext) -> F, F: 'static + Future>, { let model_type_id = TypeId::of::(); @@ -783,94 +838,8 @@ impl Client { let cx = cx.clone(); let this = self.clone(); async move { - let mut message_id = 0_usize; while let Some(message) = incoming.next().await { - let mut state = this.state.write(); - message_id += 1; - let type_name = message.payload_type_name(); - let payload_type_id = message.payload_type_id(); - let sender_id = message.original_sender_id().map(|id| id.0); - - let model = state - .models_by_message_type - .get(&payload_type_id) - .and_then(|model| model.upgrade(&cx)) - .map(AnyEntityHandle::Model) - .or_else(|| { - let entity_type_id = - *state.entity_types_by_message_type.get(&payload_type_id)?; - let entity_id = state - .entity_id_extractors - .get(&message.payload_type_id()) - .map(|extract_entity_id| { - (extract_entity_id)(message.as_ref()) - })?; - - let entity = state - .entities_by_type_and_remote_id - .get(&(entity_type_id, entity_id))?; - if let Some(entity) = entity.upgrade(&cx) { - Some(entity) - } else { - state - .entities_by_type_and_remote_id - .remove(&(entity_type_id, entity_id)); - None - } - }); - - let model = if let Some(model) = model { - model - } else { - log::info!("unhandled message {}", type_name); - continue; - }; - - let handler = state.message_handlers.get(&payload_type_id).cloned(); - // Dropping the state prevents deadlocks if the handler interacts with rpc::Client. - // It also ensures we don't hold the lock while yielding back to the executor, as - // that might cause the executor thread driving this future to block indefinitely. 
- drop(state); - - if let Some(handler) = handler { - let future = handler(model, message, &this, cx.clone()); - let client_id = this.id; - log::debug!( - "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}", - client_id, - message_id, - sender_id, - type_name - ); - cx.foreground() - .spawn(async move { - match future.await { - Ok(()) => { - log::debug!( - "rpc message handled. client_id:{}, message_id:{}, sender_id:{:?}, type:{}", - client_id, - message_id, - sender_id, - type_name - ); - } - Err(error) => { - log::error!( - "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}", - client_id, - message_id, - sender_id, - type_name, - error - ); - } - } - }) - .detach(); - } else { - log::info!("unhandled message {}", type_name); - } - + this.handle_message(message, &cx); // Don't starve the main thread when receiving lots of messages at once. smol::future::yield_now().await; } @@ -1217,6 +1186,97 @@ impl Client { self.peer.respond_with_error(receipt, error) } + fn handle_message( + self: &Arc, + message: Box, + cx: &AsyncAppContext, + ) { + let mut state = self.state.write(); + let type_name = message.payload_type_name(); + let payload_type_id = message.payload_type_id(); + let sender_id = message.original_sender_id().map(|id| id.0); + + let mut subscriber = None; + + if let Some(message_model) = state + .models_by_message_type + .get(&payload_type_id) + .and_then(|model| model.upgrade(cx)) + { + subscriber = Some(Subscriber::Model(message_model)); + } else if let Some((extract_entity_id, entity_type_id)) = + state.entity_id_extractors.get(&payload_type_id).zip( + state + .entity_types_by_message_type + .get(&payload_type_id) + .copied(), + ) + { + let entity_id = (extract_entity_id)(message.as_ref()); + + match state + .entities_by_type_and_remote_id + .get_mut(&(entity_type_id, entity_id)) + { + Some(WeakSubscriber::Pending(pending)) => { + pending.push(message); + return; + } + Some(weak_subscriber @ _) => subscriber = weak_subscriber.upgrade(cx), + _ => {} + } + } + + let subscriber = if let Some(subscriber) = subscriber { + subscriber + } else { + log::info!("unhandled message {}", type_name); + return; + }; + + let handler = state.message_handlers.get(&payload_type_id).cloned(); + // Dropping the state prevents deadlocks if the handler interacts with rpc::Client. + // It also ensures we don't hold the lock while yielding back to the executor, as + // that might cause the executor thread driving this future to block indefinitely. + drop(state); + + if let Some(handler) = handler { + let future = handler(subscriber, message, &self, cx.clone()); + let client_id = self.id; + log::debug!( + "rpc message received. client_id:{}, sender_id:{:?}, type:{}", + client_id, + sender_id, + type_name + ); + cx.foreground() + .spawn(async move { + match future.await { + Ok(()) => { + log::debug!( + "rpc message handled. client_id:{}, sender_id:{:?}, type:{}", + client_id, + sender_id, + type_name + ); + } + Err(error) => { + log::error!( + "error handling message. 
client_id:{}, sender_id:{:?}, type:{}, error:{:?}", + client_id, + sender_id, + type_name, + error + ); + } + } + }) + .detach(); + } else { + log::info!("unhandled message {}", type_name); + } + } + pub fn start_telemetry(&self) { self.telemetry.start(); } @@ -1230,11 +1290,12 @@ impl Client { } } -impl AnyWeakEntityHandle { - fn upgrade(&self, cx: &AsyncAppContext) -> Option { +impl WeakSubscriber { + fn upgrade(&self, cx: &AsyncAppContext) -> Option { match self { - AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model), - AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View), + WeakSubscriber::Model(handle) => handle.upgrade(cx).map(Subscriber::Model), + WeakSubscriber::View(handle) => handle.upgrade(cx).map(Subscriber::View), + WeakSubscriber::Pending(_) => None, } } } @@ -1479,11 +1540,17 @@ mod tests { subscription: None, }); - let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx)); - let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx)); + let _subscription1 = client + .subscribe_to_entity(1) + .set_model(&model1, &mut cx.to_async()); + let _subscription2 = client + .subscribe_to_entity(2) + .set_model(&model2, &mut cx.to_async()); // Ensure dropping a subscription for the same entity type still allows receiving of // messages for other entity IDs of the same type. - let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx)); + let subscription3 = client + .subscribe_to_entity(3) + .set_model(&model3, &mut cx.to_async()); drop(subscription3); server.send(proto::JoinProject { project_id: 1 }); diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 09f379526eec23d44f2057e48b2fb7d7b27e2d17..8725642ae52a4244234dad2c364ebc3294673dce 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -19,12 +19,12 @@ rpc = { path = "../rpc" } util = { path = "../util" } anyhow = "1.0.40" -async-trait = "0.1.50" async-tungstenite = "0.16" axum = { version = "0.5", features = ["json", "headers", "ws"] } axum-extra = { version = "0.3", features = ["erased-json"] } base64 = "0.13" clap = { version = "3.1", features = ["derive"], optional = true } +dashmap = "5.4" envy = "0.4.2" futures = "0.3" hyper = "0.14" @@ -36,9 +36,13 @@ prometheus = "0.13" rand = "0.8" reqwest = { version = "0.11", features = ["json"], optional = true } scrypt = "0.7" +# Remove fork dependency when a version with https://github.com/SeaQL/sea-orm/pull/1283 is released. 
+sea-orm = { git = "https://github.com/zed-industries/sea-orm", rev = "18f4c691085712ad014a51792af75a9044bacee6", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls"] } +sea-query = "0.27" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" sha-1 = "0.9" +sqlx = { version = "0.6", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid", "any"] } time = { version = "0.3", features = ["serde", "serde-well-known"] } tokio = { version = "1", features = ["full"] } tokio-tungstenite = "0.17" @@ -49,11 +53,6 @@ tracing = "0.1.34" tracing-log = "0.1.3" tracing-subscriber = { version = "0.3.11", features = ["env-filter", "json"] } -[dependencies.sqlx] -git = "https://github.com/launchbadge/sqlx" -rev = "4b7053807c705df312bcb9b6281e184bf7534eb3" -features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid"] - [dev-dependencies] collections = { path = "../collections", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] } @@ -76,13 +75,10 @@ env_logger = "0.9" log = { version = "0.4.16", features = ["kv_unstable_serde"] } util = { path = "../util" } lazy_static = "1.4" +sea-orm = { git = "https://github.com/zed-industries/sea-orm", rev = "18f4c691085712ad014a51792af75a9044bacee6", features = ["sqlx-sqlite"] } serde_json = { version = "1.0", features = ["preserve_order"] } +sqlx = { version = "0.6", features = ["sqlite"] } unindent = "0.1" -[dev-dependencies.sqlx] -git = "https://github.com/launchbadge/sqlx" -rev = "4b7053807c705df312bcb9b6281e184bf7534eb3" -features = ["sqlite"] - [features] seed-support = ["clap", "lipsum", "reqwest"] diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 63d2661de5d5d2262b371de651b434f6fe1a6c38..90fd8ace122ff0a6e28b879634b574e6876951a0 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -1,4 +1,4 @@ -CREATE TABLE IF NOT EXISTS "users" ( +CREATE TABLE "users" ( "id" INTEGER PRIMARY KEY, "github_login" VARCHAR, "admin" BOOLEAN, @@ -8,7 +8,7 @@ CREATE TABLE IF NOT EXISTS "users" ( "inviter_id" INTEGER REFERENCES users (id), "connected_once" BOOLEAN NOT NULL DEFAULT false, "created_at" TIMESTAMP NOT NULL DEFAULT now, - "metrics_id" VARCHAR(255), + "metrics_id" TEXT, "github_user_id" INTEGER ); CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login"); @@ -16,14 +16,14 @@ CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code"); CREATE INDEX "index_users_on_email_address" ON "users" ("email_address"); CREATE INDEX "index_users_on_github_user_id" ON "users" ("github_user_id"); -CREATE TABLE IF NOT EXISTS "access_tokens" ( +CREATE TABLE "access_tokens" ( "id" INTEGER PRIMARY KEY, "user_id" INTEGER REFERENCES users (id), "hash" VARCHAR(128) ); CREATE INDEX "index_access_tokens_user_id" ON "access_tokens" ("user_id"); -CREATE TABLE IF NOT EXISTS "contacts" ( +CREATE TABLE "contacts" ( "id" INTEGER PRIMARY KEY, "user_id_a" INTEGER REFERENCES users (id) NOT NULL, "user_id_b" INTEGER REFERENCES users (id) NOT NULL, @@ -34,8 +34,96 @@ CREATE TABLE IF NOT EXISTS "contacts" ( CREATE UNIQUE INDEX "index_contacts_user_ids" ON "contacts" ("user_id_a", "user_id_b"); CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b"); -CREATE TABLE IF NOT EXISTS "projects" ( +CREATE TABLE "rooms" ( "id" INTEGER PRIMARY KEY, + "live_kit_room" VARCHAR NOT NULL +); + 
+CREATE TABLE "projects" ( + "id" INTEGER PRIMARY KEY, + "room_id" INTEGER REFERENCES rooms (id) NOT NULL, "host_user_id" INTEGER REFERENCES users (id) NOT NULL, - "unregistered" BOOLEAN NOT NULL DEFAULT false + "host_connection_id" INTEGER NOT NULL, + "host_connection_epoch" TEXT NOT NULL +); +CREATE INDEX "index_projects_on_host_connection_epoch" ON "projects" ("host_connection_epoch"); + +CREATE TABLE "worktrees" ( + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "id" INTEGER NOT NULL, + "root_name" VARCHAR NOT NULL, + "abs_path" VARCHAR NOT NULL, + "visible" BOOL NOT NULL, + "scan_id" INTEGER NOT NULL, + "is_complete" BOOL NOT NULL, + PRIMARY KEY(project_id, id) +); +CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); + +CREATE TABLE "worktree_entries" ( + "project_id" INTEGER NOT NULL, + "worktree_id" INTEGER NOT NULL, + "id" INTEGER NOT NULL, + "is_dir" BOOL NOT NULL, + "path" VARCHAR NOT NULL, + "inode" INTEGER NOT NULL, + "mtime_seconds" INTEGER NOT NULL, + "mtime_nanos" INTEGER NOT NULL, + "is_symlink" BOOL NOT NULL, + "is_ignored" BOOL NOT NULL, + PRIMARY KEY(project_id, worktree_id, id), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE +); +CREATE INDEX "index_worktree_entries_on_project_id" ON "worktree_entries" ("project_id"); +CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); + +CREATE TABLE "worktree_diagnostic_summaries" ( + "project_id" INTEGER NOT NULL, + "worktree_id" INTEGER NOT NULL, + "path" VARCHAR NOT NULL, + "language_server_id" INTEGER NOT NULL, + "error_count" INTEGER NOT NULL, + "warning_count" INTEGER NOT NULL, + PRIMARY KEY(project_id, worktree_id, path), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE +); +CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id" ON "worktree_diagnostic_summaries" ("project_id"); +CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id_and_worktree_id" ON "worktree_diagnostic_summaries" ("project_id", "worktree_id"); + +CREATE TABLE "language_servers" ( + "id" INTEGER NOT NULL, + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "name" VARCHAR NOT NULL, + PRIMARY KEY(project_id, id) +); +CREATE INDEX "index_language_servers_on_project_id" ON "language_servers" ("project_id"); + +CREATE TABLE "project_collaborators" ( + "id" INTEGER PRIMARY KEY, + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "connection_id" INTEGER NOT NULL, + "connection_epoch" TEXT NOT NULL, + "user_id" INTEGER NOT NULL, + "replica_id" INTEGER NOT NULL, + "is_host" BOOLEAN NOT NULL +); +CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); +CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id"); +CREATE INDEX "index_project_collaborators_on_connection_epoch" ON "project_collaborators" ("connection_epoch"); + +CREATE TABLE "room_participants" ( + "id" INTEGER PRIMARY KEY, + "room_id" INTEGER NOT NULL REFERENCES rooms (id), + "user_id" INTEGER NOT NULL REFERENCES users (id), + "answering_connection_id" INTEGER, + "answering_connection_epoch" TEXT, + "location_kind" INTEGER, + "location_project_id" INTEGER, + "initial_project_id" INTEGER, + "calling_user_id" INTEGER NOT NULL REFERENCES users (id), + "calling_connection_id" INTEGER NOT NULL, + 
"calling_connection_epoch" TEXT NOT NULL ); +CREATE UNIQUE INDEX "index_room_participants_on_user_id" ON "room_participants" ("user_id"); +CREATE INDEX "index_room_participants_on_answering_connection_epoch" ON "room_participants" ("answering_connection_epoch"); +CREATE INDEX "index_room_participants_on_calling_connection_epoch" ON "room_participants" ("calling_connection_epoch"); diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql new file mode 100644 index 0000000000000000000000000000000000000000..5e8bada2f9492b91212108e0eae1b0b99d53b63a --- /dev/null +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -0,0 +1,91 @@ +CREATE TABLE IF NOT EXISTS "rooms" ( + "id" SERIAL PRIMARY KEY, + "live_kit_room" VARCHAR NOT NULL +); + +ALTER TABLE "projects" + ADD "room_id" INTEGER REFERENCES rooms (id), + ADD "host_connection_id" INTEGER, + ADD "host_connection_epoch" UUID, + DROP COLUMN "unregistered"; +CREATE INDEX "index_projects_on_host_connection_epoch" ON "projects" ("host_connection_epoch"); + +CREATE TABLE "worktrees" ( + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "id" INT8 NOT NULL, + "root_name" VARCHAR NOT NULL, + "abs_path" VARCHAR NOT NULL, + "visible" BOOL NOT NULL, + "scan_id" INT8 NOT NULL, + "is_complete" BOOL NOT NULL, + PRIMARY KEY(project_id, id) +); +CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); + +CREATE TABLE "worktree_entries" ( + "project_id" INTEGER NOT NULL, + "worktree_id" INT8 NOT NULL, + "id" INT8 NOT NULL, + "is_dir" BOOL NOT NULL, + "path" VARCHAR NOT NULL, + "inode" INT8 NOT NULL, + "mtime_seconds" INT8 NOT NULL, + "mtime_nanos" INTEGER NOT NULL, + "is_symlink" BOOL NOT NULL, + "is_ignored" BOOL NOT NULL, + PRIMARY KEY(project_id, worktree_id, id), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE +); +CREATE INDEX "index_worktree_entries_on_project_id" ON "worktree_entries" ("project_id"); +CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); + +CREATE TABLE "worktree_diagnostic_summaries" ( + "project_id" INTEGER NOT NULL, + "worktree_id" INT8 NOT NULL, + "path" VARCHAR NOT NULL, + "language_server_id" INT8 NOT NULL, + "error_count" INTEGER NOT NULL, + "warning_count" INTEGER NOT NULL, + PRIMARY KEY(project_id, worktree_id, path), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE +); +CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id" ON "worktree_diagnostic_summaries" ("project_id"); +CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id_and_worktree_id" ON "worktree_diagnostic_summaries" ("project_id", "worktree_id"); + +CREATE TABLE "language_servers" ( + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "id" INT8 NOT NULL, + "name" VARCHAR NOT NULL, + PRIMARY KEY(project_id, id) +); +CREATE INDEX "index_language_servers_on_project_id" ON "language_servers" ("project_id"); + +CREATE TABLE "project_collaborators" ( + "id" SERIAL PRIMARY KEY, + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "connection_id" INTEGER NOT NULL, + "connection_epoch" UUID NOT NULL, + "user_id" INTEGER NOT NULL, + "replica_id" INTEGER NOT NULL, + "is_host" BOOLEAN NOT NULL +); +CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); 
+CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id"); +CREATE INDEX "index_project_collaborators_on_connection_epoch" ON "project_collaborators" ("connection_epoch"); + +CREATE TABLE "room_participants" ( + "id" SERIAL PRIMARY KEY, + "room_id" INTEGER NOT NULL REFERENCES rooms (id), + "user_id" INTEGER NOT NULL REFERENCES users (id), + "answering_connection_id" INTEGER, + "answering_connection_epoch" UUID, + "location_kind" INTEGER, + "location_project_id" INTEGER, + "initial_project_id" INTEGER, + "calling_user_id" INTEGER NOT NULL REFERENCES users (id), + "calling_connection_id" INTEGER NOT NULL, + "calling_connection_epoch" UUID NOT NULL +); +CREATE UNIQUE INDEX "index_room_participants_on_user_id" ON "room_participants" ("user_id"); +CREATE INDEX "index_room_participants_on_answering_connection_epoch" ON "room_participants" ("answering_connection_epoch"); +CREATE INDEX "index_room_participants_on_calling_connection_epoch" ON "room_participants" ("calling_connection_epoch"); diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index eb750bed550a8e935a837567127f3ebfe20ccb00..921b4189e824dab10e6692991654c2de223ac096 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -1,6 +1,6 @@ use crate::{ auth, - db::{Invite, NewUserParams, Signup, User, UserId, WaitlistSummary}, + db::{Invite, NewSignup, NewUserParams, User, UserId, WaitlistSummary}, rpc::{self, ResultExt}, AppState, Error, Result, }; @@ -204,7 +204,7 @@ async fn create_user( #[derive(Deserialize)] struct UpdateUserParams { admin: Option, - invite_count: Option, + invite_count: Option, } async fn update_user( @@ -335,7 +335,7 @@ async fn get_user_for_invite_code( } async fn create_signup( - Json(params): Json, + Json(params): Json, Extension(app): Extension>, ) -> Result<()> { app.db.create_signup(¶ms).await?; diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 63f032f7e65d17f454d603b26c6206c81eacdf65..0c9cf33a6b94b369a9ea47e92e254ffd87e151ab 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -75,7 +75,7 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; -pub async fn create_access_token(db: &db::DefaultDb, user_id: UserId) -> Result { +pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result { let access_token = rpc::auth::random_token(); let access_token_hash = hash_access_token(&access_token).context("failed to hash access token")?; diff --git a/crates/collab/src/bin/seed.rs b/crates/collab/src/bin/seed.rs index 324ccdc0c63716813115a7e645a277c314980d23..dfd2ae3a21656fa4b1e8273de748f2765612dc10 100644 --- a/crates/collab/src/bin/seed.rs +++ b/crates/collab/src/bin/seed.rs @@ -1,12 +1,8 @@ -use collab::{Error, Result}; -use db::DefaultDb; +use collab::db; +use db::{ConnectOptions, Database}; use serde::{de::DeserializeOwned, Deserialize}; use std::fmt::Write; -#[allow(unused)] -#[path = "../db.rs"] -mod db; - #[derive(Debug, Deserialize)] struct GitHubUser { id: i32, @@ -17,7 +13,7 @@ struct GitHubUser { #[tokio::main] async fn main() { let database_url = std::env::var("DATABASE_URL").expect("missing DATABASE_URL env var"); - let db = DefaultDb::new(&database_url, 5) + let db = Database::new(ConnectOptions::new(database_url)) .await .expect("failed to connect to postgres database"); let github_token = std::env::var("GITHUB_TOKEN").expect("missing GITHUB_TOKEN env var"); diff --git 
a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 47eb803887ca513a38729288abf31c85eac70c44..915acb00eb504f792c3dcd1bef873c6db546dae1 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1,214 +1,633 @@ +mod access_token; +mod contact; +mod language_server; +mod project; +mod project_collaborator; +mod room; +mod room_participant; +mod signup; +#[cfg(test)] +mod tests; +mod user; +mod worktree; +mod worktree_diagnostic_summary; +mod worktree_entry; + use crate::{Error, Result}; use anyhow::anyhow; -use axum::http::StatusCode; -use collections::HashMap; +use collections::{BTreeMap, HashMap, HashSet}; +pub use contact::Contact; +use dashmap::DashMap; use futures::StreamExt; -use serde::{Deserialize, Serialize}; -use sqlx::{ - migrate::{Migrate as _, Migration, MigrationSource}, - types::Uuid, - FromRow, +use hyper::StatusCode; +use rpc::{proto, ConnectionId}; +pub use sea_orm::ConnectOptions; +use sea_orm::{ + entity::prelude::*, ActiveValue, ConnectionTrait, DatabaseConnection, DatabaseTransaction, + DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, + Statement, TransactionTrait, }; -use std::{path::Path, time::Duration}; -use time::{OffsetDateTime, PrimitiveDateTime}; - -#[cfg(test)] -pub type DefaultDb = Db; - -#[cfg(not(test))] -pub type DefaultDb = Db; - -pub struct Db { - pool: sqlx::Pool, +use sea_query::{Alias, Expr, OnConflict, Query}; +use serde::{Deserialize, Serialize}; +pub use signup::{Invite, NewSignup, WaitlistSummary}; +use sqlx::migrate::{Migrate, Migration, MigrationSource}; +use sqlx::Connection; +use std::ops::{Deref, DerefMut}; +use std::path::Path; +use std::time::Duration; +use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc}; +use tokio::sync::{Mutex, OwnedMutexGuard}; +pub use user::Model as User; + +pub struct Database { + options: ConnectOptions, + pool: DatabaseConnection, + rooms: DashMap>>, #[cfg(test)] background: Option>, #[cfg(test)] runtime: Option, + epoch: Uuid, } -macro_rules! test_support { - ($self:ident, { $($token:tt)* }) => {{ - let body = async { - $($token)* - }; - - if cfg!(test) { - #[cfg(not(test))] - unreachable!(); - +impl Database { + pub async fn new(options: ConnectOptions) -> Result { + Ok(Self { + options: options.clone(), + pool: sea_orm::Database::connect(options).await?, + rooms: DashMap::with_capacity(16384), #[cfg(test)] - if let Some(background) = $self.background.as_ref() { - background.simulate_random_delay().await; - } - + background: None, #[cfg(test)] - $self.runtime.as_ref().unwrap().block_on(body) - } else { - body.await - } - }}; -} + runtime: None, + epoch: Uuid::new_v4(), + }) + } -pub trait RowsAffected { - fn rows_affected(&self) -> u64; -} + pub async fn migrate( + &self, + migrations_path: &Path, + ignore_checksum_mismatch: bool, + ) -> anyhow::Result> { + let migrations = MigrationSource::resolve(migrations_path) + .await + .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; -#[cfg(test)] -impl RowsAffected for sqlx::sqlite::SqliteQueryResult { - fn rows_affected(&self) -> u64 { - self.rows_affected() - } -} + let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?; -impl RowsAffected for sqlx::postgres::PgQueryResult { - fn rows_affected(&self) -> u64 { - self.rows_affected() - } -} + connection.ensure_migrations_table().await?; + let applied_migrations: HashMap<_, _> = connection + .list_applied_migrations() + .await? 
+ .into_iter() + .map(|m| (m.version, m)) + .collect(); -#[cfg(test)] -impl Db { - pub async fn new(url: &str, max_connections: u32) -> Result { - use std::str::FromStr as _; - let options = sqlx::sqlite::SqliteConnectOptions::from_str(url) - .unwrap() - .create_if_missing(true) - .shared_cache(true); - let pool = sqlx::sqlite::SqlitePoolOptions::new() - .min_connections(2) - .max_connections(max_connections) - .connect_with(options) - .await?; - Ok(Self { - pool, - background: None, - runtime: None, - }) - } + let mut new_migrations = Vec::new(); + for migration in migrations { + match applied_migrations.get(&migration.version) { + Some(applied_migration) => { + if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch + { + Err(anyhow!( + "checksum mismatch for applied migration {}", + migration.description + ))?; + } + } + None => { + let elapsed = connection.apply(&migration).await?; + new_migrations.push((migration, elapsed)); + } + } + } - pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { - test_support!(self, { - let query = " - SELECT users.* - FROM users - WHERE users.id IN (SELECT value from json_each($1)) - "; - Ok(sqlx::query_as(query) - .bind(&serde_json::json!(ids)) - .fetch_all(&self.pool) - .await?) - }) + Ok(new_migrations) } - pub async fn get_user_metrics_id(&self, id: UserId) -> Result { - test_support!(self, { - let query = " - SELECT metrics_id - FROM users - WHERE id = $1 - "; - Ok(sqlx::query_scalar(query) - .bind(id) - .fetch_one(&self.pool) - .await?) + pub async fn clear_stale_data(&self) -> Result<()> { + self.transaction(|tx| async move { + project_collaborator::Entity::delete_many() + .filter(project_collaborator::Column::ConnectionEpoch.ne(self.epoch)) + .exec(&*tx) + .await?; + room_participant::Entity::delete_many() + .filter( + room_participant::Column::AnsweringConnectionEpoch + .ne(self.epoch) + .or(room_participant::Column::CallingConnectionEpoch.ne(self.epoch)), + ) + .exec(&*tx) + .await?; + project::Entity::delete_many() + .filter(project::Column::HostConnectionEpoch.ne(self.epoch)) + .exec(&*tx) + .await?; + room::Entity::delete_many() + .filter( + room::Column::Id.not_in_subquery( + Query::select() + .column(room_participant::Column::RoomId) + .from(room_participant::Entity) + .distinct() + .to_owned(), + ), + ) + .exec(&*tx) + .await?; + Ok(()) }) + .await } + // users + pub async fn create_user( &self, email_address: &str, admin: bool, params: NewUserParams, ) -> Result { - test_support!(self, { - let query = " - INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login - RETURNING id, metrics_id - "; + self.transaction(|tx| async { + let tx = tx; + let user = user::Entity::insert(user::ActiveModel { + email_address: ActiveValue::set(Some(email_address.into())), + github_login: ActiveValue::set(params.github_login.clone()), + github_user_id: ActiveValue::set(Some(params.github_user_id)), + admin: ActiveValue::set(admin), + metrics_id: ActiveValue::set(Uuid::new_v4()), + ..Default::default() + }) + .on_conflict( + OnConflict::column(user::Column::GithubLogin) + .update_column(user::Column::GithubLogin) + .to_owned(), + ) + .exec_with_returning(&*tx) + .await?; - let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) - .bind(email_address) - .bind(params.github_login) - .bind(params.github_user_id) - .bind(admin) - .bind(Uuid::new_v4().to_string()) - 
.fetch_one(&self.pool) - .await?; Ok(NewUserResult { - user_id, - metrics_id, + user_id: user.id, + metrics_id: user.metrics_id.to_string(), signup_device_id: None, inviting_user_id: None, }) }) + .await } - pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result> { - unimplemented!() + pub async fn get_user_by_id(&self, id: UserId) -> Result> { + self.transaction(|tx| async move { Ok(user::Entity::find_by_id(id).one(&*tx).await?) }) + .await } - pub async fn create_user_from_invite( + pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { + self.transaction(|tx| async { + let tx = tx; + Ok(user::Entity::find() + .filter(user::Column::Id.is_in(ids.iter().copied())) + .all(&*tx) + .await?) + }) + .await + } + + pub async fn get_user_by_github_account( &self, - _invite: &Invite, - _user: NewUserParams, - ) -> Result> { - unimplemented!() + github_login: &str, + github_user_id: Option, + ) -> Result> { + self.transaction(|tx| async move { + let tx = &*tx; + if let Some(github_user_id) = github_user_id { + if let Some(user_by_github_user_id) = user::Entity::find() + .filter(user::Column::GithubUserId.eq(github_user_id)) + .one(tx) + .await? + { + let mut user_by_github_user_id = user_by_github_user_id.into_active_model(); + user_by_github_user_id.github_login = ActiveValue::set(github_login.into()); + Ok(Some(user_by_github_user_id.update(tx).await?)) + } else if let Some(user_by_github_login) = user::Entity::find() + .filter(user::Column::GithubLogin.eq(github_login)) + .one(tx) + .await? + { + let mut user_by_github_login = user_by_github_login.into_active_model(); + user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id)); + Ok(Some(user_by_github_login.update(tx).await?)) + } else { + Ok(None) + } + } else { + Ok(user::Entity::find() + .filter(user::Column::GithubLogin.eq(github_login)) + .one(tx) + .await?) + } + }) + .await } - pub async fn create_signup(&self, _signup: &Signup) -> Result<()> { - unimplemented!() + pub async fn get_all_users(&self, page: u32, limit: u32) -> Result> { + self.transaction(|tx| async move { + Ok(user::Entity::find() + .order_by_asc(user::Column::GithubLogin) + .limit(limit as u64) + .offset(page as u64 * limit as u64) + .all(&*tx) + .await?) + }) + .await } - pub async fn create_invite_from_code( + pub async fn get_users_with_no_invites( &self, - _code: &str, - _email_address: &str, - _device_id: Option<&str>, - ) -> Result { - unimplemented!() + invited_by_another_user: bool, + ) -> Result> { + self.transaction(|tx| async move { + Ok(user::Entity::find() + .filter( + user::Column::InviteCount + .eq(0) + .and(if invited_by_another_user { + user::Column::InviterId.is_not_null() + } else { + user::Column::InviterId.is_null() + }), + ) + .all(&*tx) + .await?) + }) + .await } - pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> { - unimplemented!() + pub async fn get_user_metrics_id(&self, id: UserId) -> Result { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + MetricsId, + } + + self.transaction(|tx| async move { + let metrics_id: Uuid = user::Entity::find_by_id(id) + .select_only() + .column(user::Column::MetricsId) + .into_values::<_, QueryAs>() + .one(&*tx) + .await? 
+                .ok_or_else(|| anyhow!("could not find user"))?;
+            Ok(metrics_id.to_string())
+        })
+        .await
+    }
+
+    pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
+        self.transaction(|tx| async move {
+            user::Entity::update_many()
+                .filter(user::Column::Id.eq(id))
+                .set(user::ActiveModel {
+                    admin: ActiveValue::set(is_admin),
+                    ..Default::default()
+                })
+                .exec(&*tx)
+                .await?;
+            Ok(())
+        })
+        .await
     }
-}
-impl Db {
-    pub async fn new(url: &str, max_connections: u32) -> Result<Self> {
-        let pool = sqlx::postgres::PgPoolOptions::new()
-            .max_connections(max_connections)
-            .connect(url)
+    pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> {
+        self.transaction(|tx| async move {
+            user::Entity::update_many()
+                .filter(user::Column::Id.eq(id))
+                .set(user::ActiveModel {
+                    connected_once: ActiveValue::set(connected_once),
+                    ..Default::default()
+                })
+                .exec(&*tx)
+                .await?;
+            Ok(())
+        })
+        .await
+    }
+
+    pub async fn destroy_user(&self, id: UserId) -> Result<()> {
+        self.transaction(|tx| async move {
+            access_token::Entity::delete_many()
+                .filter(access_token::Column::UserId.eq(id))
+                .exec(&*tx)
+                .await?;
+            user::Entity::delete_by_id(id).exec(&*tx).await?;
+            Ok(())
+        })
+        .await
+    }
+
+    // contacts
+
+    pub async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
+        #[derive(Debug, FromQueryResult)]
+        struct ContactWithUserBusyStatuses {
+            user_id_a: UserId,
+            user_id_b: UserId,
+            a_to_b: bool,
+            accepted: bool,
+            should_notify: bool,
+            user_a_busy: bool,
+            user_b_busy: bool,
+        }
+
+        self.transaction(|tx| async move {
+            let user_a_participant = Alias::new("user_a_participant");
+            let user_b_participant = Alias::new("user_b_participant");
+            let mut db_contacts = contact::Entity::find()
+                .column_as(
+                    Expr::tbl(user_a_participant.clone(), room_participant::Column::Id)
+                        .is_not_null(),
+                    "user_a_busy",
+                )
+                .column_as(
+                    Expr::tbl(user_b_participant.clone(), room_participant::Column::Id)
+                        .is_not_null(),
+                    "user_b_busy",
+                )
+                .filter(
+                    contact::Column::UserIdA
+                        .eq(user_id)
+                        .or(contact::Column::UserIdB.eq(user_id)),
+                )
+                .join_as(
+                    JoinType::LeftJoin,
+                    contact::Relation::UserARoomParticipant.def(),
+                    user_a_participant,
+                )
+                .join_as(
+                    JoinType::LeftJoin,
+                    contact::Relation::UserBRoomParticipant.def(),
+                    user_b_participant,
+                )
+                .into_model::<ContactWithUserBusyStatuses>()
+                .stream(&*tx)
+                .await?;
+
+            let mut contacts = Vec::new();
+            while let Some(db_contact) = db_contacts.next().await {
+                let db_contact = db_contact?;
+                if db_contact.user_id_a == user_id {
+                    if db_contact.accepted {
+                        contacts.push(Contact::Accepted {
+                            user_id: db_contact.user_id_b,
+                            should_notify: db_contact.should_notify && db_contact.a_to_b,
+                            busy: db_contact.user_b_busy,
+                        });
+                    } else if db_contact.a_to_b {
+                        contacts.push(Contact::Outgoing {
+                            user_id: db_contact.user_id_b,
+                        })
+                    } else {
+                        contacts.push(Contact::Incoming {
+                            user_id: db_contact.user_id_b,
+                            should_notify: db_contact.should_notify,
+                        });
+                    }
+                } else if db_contact.accepted {
+                    contacts.push(Contact::Accepted {
+                        user_id: db_contact.user_id_a,
+                        should_notify: db_contact.should_notify && !db_contact.a_to_b,
+                        busy: db_contact.user_a_busy,
+                    });
+                } else if db_contact.a_to_b {
+                    contacts.push(Contact::Incoming {
+                        user_id: db_contact.user_id_a,
+                        should_notify: db_contact.should_notify,
+                    });
+                } else {
+                    contacts.push(Contact::Outgoing {
+                        user_id: db_contact.user_id_a,
+                    });
+                }
+            }
+
+            contacts.sort_unstable_by_key(|contact| contact.user_id());
+
+            Ok(contacts)
+        })
+        .await
+    }
+
+    pub async fn is_user_busy(&self, user_id: UserId) -> Result<bool> {
+        self.transaction(|tx| async move {
+            let participant = room_participant::Entity::find()
+                .filter(room_participant::Column::UserId.eq(user_id))
+                .one(&*tx)
+                .await?;
+            Ok(participant.is_some())
+        })
+        .await
+    }
+
+    pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
+        self.transaction(|tx| async move {
+            let (id_a, id_b) = if user_id_1 < user_id_2 {
+                (user_id_1, user_id_2)
+            } else {
+                (user_id_2, user_id_1)
+            };
+
+            Ok(contact::Entity::find()
+                .filter(
+                    contact::Column::UserIdA
+                        .eq(id_a)
+                        .and(contact::Column::UserIdB.eq(id_b))
+                        .and(contact::Column::Accepted.eq(true)),
+                )
+                .one(&*tx)
+                .await?
+                .is_some())
+        })
+        .await
+    }
+
+    pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
+        self.transaction(|tx| async move {
+            let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
+                (sender_id, receiver_id, true)
+            } else {
+                (receiver_id, sender_id, false)
+            };
+
+            let rows_affected = contact::Entity::insert(contact::ActiveModel {
+                user_id_a: ActiveValue::set(id_a),
+                user_id_b: ActiveValue::set(id_b),
+                a_to_b: ActiveValue::set(a_to_b),
+                accepted: ActiveValue::set(false),
+                should_notify: ActiveValue::set(true),
+                ..Default::default()
+            })
+            .on_conflict(
+                OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB])
+                    .values([
+                        (contact::Column::Accepted, true.into()),
+                        (contact::Column::ShouldNotify, false.into()),
+                    ])
+                    .action_and_where(
+                        contact::Column::Accepted.eq(false).and(
+                            contact::Column::AToB
+                                .eq(a_to_b)
+                                .and(contact::Column::UserIdA.eq(id_b))
+                                .or(contact::Column::AToB
+                                    .ne(a_to_b)
+                                    .and(contact::Column::UserIdA.eq(id_a))),
+                        ),
+                    )
+                    .to_owned(),
+            )
+            .exec_without_returning(&*tx)
             .await?;
-        Ok(Self {
-            pool,
-            #[cfg(test)]
-            background: None,
-            #[cfg(test)]
-            runtime: None,
+
+            if rows_affected == 1 {
+                Ok(())
+            } else {
+                Err(anyhow!("contact already requested"))?
+            }
         })
+        .await
     }
-    #[cfg(test)]
-    pub fn teardown(&self, url: &str) {
-        self.runtime.as_ref().unwrap().block_on(async {
-            use util::ResultExt;
-            let query = "
-                SELECT pg_terminate_backend(pg_stat_activity.pid)
-                FROM pg_stat_activity
-                WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid();
-            ";
-            sqlx::query(query).execute(&self.pool).await.log_err();
-            self.pool.close().await;
-            ::drop_database(url)
-                .await
-                .log_err();
+    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> {
+        self.transaction(|tx| async move {
+            let (id_a, id_b) = if responder_id < requester_id {
+                (responder_id, requester_id)
+            } else {
+                (requester_id, responder_id)
+            };
+
+            let result = contact::Entity::delete_many()
+                .filter(
+                    contact::Column::UserIdA
+                        .eq(id_a)
+                        .and(contact::Column::UserIdB.eq(id_b)),
+                )
+                .exec(&*tx)
+                .await?;
+
+            if result.rows_affected == 1 {
+                Ok(())
+            } else {
+                Err(anyhow!("no such contact"))?
+            }
         })
+        .await
+    }
+
+    pub async fn dismiss_contact_notification(
+        &self,
+        user_id: UserId,
+        contact_user_id: UserId,
+    ) -> Result<()> {
+        self.transaction(|tx| async move {
+            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
+                (user_id, contact_user_id, true)
+            } else {
+                (contact_user_id, user_id, false)
+            };
+
+            let result = contact::Entity::update_many()
+                .set(contact::ActiveModel {
+                    should_notify: ActiveValue::set(false),
+                    ..Default::default()
+                })
+                .filter(
+                    contact::Column::UserIdA
+                        .eq(id_a)
+                        .and(contact::Column::UserIdB.eq(id_b))
+                        .and(
+                            contact::Column::AToB
+                                .eq(a_to_b)
+                                .and(contact::Column::Accepted.eq(true))
+                                .or(contact::Column::AToB
+                                    .ne(a_to_b)
+                                    .and(contact::Column::Accepted.eq(false))),
+                        ),
+                )
+                .exec(&*tx)
+                .await?;
+            if result.rows_affected == 0 {
+                Err(anyhow!("no such contact request"))?
+            } else {
+                Ok(())
+            }
+        })
+        .await
+    }
+
+    pub async fn respond_to_contact_request(
+        &self,
+        responder_id: UserId,
+        requester_id: UserId,
+        accept: bool,
+    ) -> Result<()> {
+        self.transaction(|tx| async move {
+            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
+                (responder_id, requester_id, false)
+            } else {
+                (requester_id, responder_id, true)
+            };
+            let rows_affected = if accept {
+                let result = contact::Entity::update_many()
+                    .set(contact::ActiveModel {
+                        accepted: ActiveValue::set(true),
+                        should_notify: ActiveValue::set(true),
+                        ..Default::default()
+                    })
+                    .filter(
+                        contact::Column::UserIdA
+                            .eq(id_a)
+                            .and(contact::Column::UserIdB.eq(id_b))
+                            .and(contact::Column::AToB.eq(a_to_b)),
+                    )
+                    .exec(&*tx)
+                    .await?;
+                result.rows_affected
+            } else {
+                let result = contact::Entity::delete_many()
+                    .filter(
+                        contact::Column::UserIdA
+                            .eq(id_a)
+                            .and(contact::Column::UserIdB.eq(id_b))
+                            .and(contact::Column::AToB.eq(a_to_b))
+                            .and(contact::Column::Accepted.eq(false)),
+                    )
+                    .exec(&*tx)
+                    .await?;
+
+                result.rows_affected
+            };
+
+            if rows_affected == 1 {
+                Ok(())
+            } else {
+                Err(anyhow!("no such contact request"))?
+            }
+        })
+        .await
+    }
+
+    pub fn fuzzy_like_string(string: &str) -> String {
+        let mut result = String::with_capacity(string.len() * 2 + 1);
+        for c in string.chars() {
+            if c.is_alphanumeric() {
+                result.push('%');
+                result.push(c);
+            }
+        }
+        result.push('%');
+        result
     }
     pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result<Vec<User>> {
-        test_support!(self, {
+        self.transaction(|tx| async {
+            let tx = tx;
             let like_string = Self::fuzzy_like_string(name_query);
             let query = "
                 SELECT users.*
@@ -217,71 +636,195 @@ impl Db {
                 ORDER BY github_login <-> $2
                 LIMIT $3
             ";
-            Ok(sqlx::query_as(query)
-                .bind(like_string)
-                .bind(name_query)
-                .bind(limit as i32)
-                .fetch_all(&self.pool)
+
+            Ok(user::Entity::find()
+                .from_raw_sql(Statement::from_sql_and_values(
+                    self.pool.get_database_backend(),
+                    query.into(),
+                    vec![like_string.into(), name_query.into(), limit.into()],
+                ))
+                .all(&*tx)
                 .await?)
         })
+        .await
     }
-    pub async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
-        test_support!(self, {
-            let query = "
-                SELECT users.*
-                FROM users
-                WHERE users.id = ANY ($1)
-            ";
-            Ok(sqlx::query_as(query)
-                .bind(&ids.into_iter().map(|id| id.0).collect::<Vec<_>>())
-                .fetch_all(&self.pool)
-                .await?)
+ // signups + + pub async fn create_signup(&self, signup: &NewSignup) -> Result<()> { + self.transaction(|tx| async move { + signup::Entity::insert(signup::ActiveModel { + email_address: ActiveValue::set(signup.email_address.clone()), + email_confirmation_code: ActiveValue::set(random_email_confirmation_code()), + email_confirmation_sent: ActiveValue::set(false), + platform_mac: ActiveValue::set(signup.platform_mac), + platform_windows: ActiveValue::set(signup.platform_windows), + platform_linux: ActiveValue::set(signup.platform_linux), + platform_unknown: ActiveValue::set(false), + editor_features: ActiveValue::set(Some(signup.editor_features.clone())), + programming_languages: ActiveValue::set(Some(signup.programming_languages.clone())), + device_id: ActiveValue::set(signup.device_id.clone()), + added_to_mailing_list: ActiveValue::set(signup.added_to_mailing_list), + ..Default::default() + }) + .on_conflict( + OnConflict::column(signup::Column::EmailAddress) + .update_column(signup::Column::EmailAddress) + .to_owned(), + ) + .exec(&*tx) + .await?; + Ok(()) }) + .await } - pub async fn get_user_metrics_id(&self, id: UserId) -> Result { - test_support!(self, { + pub async fn get_waitlist_summary(&self) -> Result { + self.transaction(|tx| async move { let query = " - SELECT metrics_id::text - FROM users - WHERE id = $1 + SELECT + COUNT(*) as count, + COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count, + COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count, + COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count, + COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count + FROM ( + SELECT * + FROM signups + WHERE + NOT email_confirmation_sent + ) AS unsent "; - Ok(sqlx::query_scalar(query) - .bind(id) - .fetch_one(&self.pool) + Ok( + WaitlistSummary::find_by_statement(Statement::from_sql_and_values( + self.pool.get_database_backend(), + query.into(), + vec![], + )) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("invalid result"))?, + ) + }) + .await + } + + pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { + let emails = invites + .iter() + .map(|s| s.email_address.as_str()) + .collect::>(); + self.transaction(|tx| async { + let tx = tx; + signup::Entity::update_many() + .filter(signup::Column::EmailAddress.is_in(emails.iter().copied())) + .set(signup::ActiveModel { + email_confirmation_sent: ActiveValue::set(true), + ..Default::default() + }) + .exec(&*tx) + .await?; + Ok(()) + }) + .await + } + + pub async fn get_unsent_invites(&self, count: usize) -> Result> { + self.transaction(|tx| async move { + Ok(signup::Entity::find() + .select_only() + .column(signup::Column::EmailAddress) + .column(signup::Column::EmailConfirmationCode) + .filter( + signup::Column::EmailConfirmationSent.eq(false).and( + signup::Column::PlatformMac + .eq(true) + .or(signup::Column::PlatformUnknown.eq(true)), + ), + ) + .order_by_asc(signup::Column::CreatedAt) + .limit(count as u64) + .into_model() + .all(&*tx) .await?) 
}) + .await } - pub async fn create_user( + // invite codes + + pub async fn create_invite_from_code( &self, + code: &str, email_address: &str, - admin: bool, - params: NewUserParams, - ) -> Result { - test_support!(self, { - let query = " - INSERT INTO users (email_address, github_login, github_user_id, admin) - VALUES ($1, $2, $3, $4) - ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login - RETURNING id, metrics_id::text - "; + device_id: Option<&str>, + ) -> Result { + self.transaction(|tx| async move { + let existing_user = user::Entity::find() + .filter(user::Column::EmailAddress.eq(email_address)) + .one(&*tx) + .await?; + + if existing_user.is_some() { + Err(anyhow!("email address is already in use"))?; + } - let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) - .bind(email_address) - .bind(params.github_login) - .bind(params.github_user_id) - .bind(admin) - .fetch_one(&self.pool) + let inviting_user_with_invites = match user::Entity::find() + .filter( + user::Column::InviteCode + .eq(code) + .and(user::Column::InviteCount.gt(0)), + ) + .one(&*tx) + .await? + { + Some(inviting_user) => inviting_user, + None => { + return Err(Error::Http( + StatusCode::UNAUTHORIZED, + "unable to find an invite code with invites remaining".to_string(), + ))? + } + }; + user::Entity::update_many() + .filter( + user::Column::Id + .eq(inviting_user_with_invites.id) + .and(user::Column::InviteCount.gt(0)), + ) + .col_expr( + user::Column::InviteCount, + Expr::col(user::Column::InviteCount).sub(1), + ) + .exec(&*tx) .await?; - Ok(NewUserResult { - user_id, - metrics_id, - signup_device_id: None, - inviting_user_id: None, + + let signup = signup::Entity::insert(signup::ActiveModel { + email_address: ActiveValue::set(email_address.into()), + email_confirmation_code: ActiveValue::set(random_email_confirmation_code()), + email_confirmation_sent: ActiveValue::set(false), + inviting_user_id: ActiveValue::set(Some(inviting_user_with_invites.id)), + platform_linux: ActiveValue::set(false), + platform_mac: ActiveValue::set(false), + platform_windows: ActiveValue::set(false), + platform_unknown: ActiveValue::set(true), + device_id: ActiveValue::set(device_id.map(|device_id| device_id.into())), + ..Default::default() + }) + .on_conflict( + OnConflict::column(signup::Column::EmailAddress) + .update_column(signup::Column::InvitingUserId) + .to_owned(), + ) + .exec_with_returning(&*tx) + .await?; + + Ok(Invite { + email_address: signup.email_address, + email_confirmation_code: signup.email_confirmation_code, }) }) + .await } pub async fn create_user_from_invite( @@ -289,829 +832,1248 @@ impl Db { invite: &Invite, user: NewUserParams, ) -> Result> { - test_support!(self, { - let mut tx = self.pool.begin().await?; - - let (signup_id, existing_user_id, inviting_user_id, signup_device_id): ( - i32, - Option, - Option, - Option, - ) = sqlx::query_as( - " - SELECT id, user_id, inviting_user_id, device_id - FROM signups - WHERE - email_address = $1 AND - email_confirmation_code = $2 - ", - ) - .bind(&invite.email_address) - .bind(&invite.email_confirmation_code) - .fetch_optional(&mut tx) - .await? - .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; + self.transaction(|tx| async { + let tx = tx; + let signup = signup::Entity::find() + .filter( + signup::Column::EmailAddress + .eq(invite.email_address.as_str()) + .and( + signup::Column::EmailConfirmationCode + .eq(invite.email_confirmation_code.as_str()), + ), + ) + .one(&*tx) + .await? 
+ .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; - if existing_user_id.is_some() { + if signup.user_id.is_some() { return Ok(None); } - let (user_id, metrics_id): (UserId, String) = sqlx::query_as( - " - INSERT INTO users - (email_address, github_login, github_user_id, admin, invite_count, invite_code) - VALUES - ($1, $2, $3, FALSE, $4, $5) - ON CONFLICT (github_login) DO UPDATE SET - email_address = excluded.email_address, - github_user_id = excluded.github_user_id, - admin = excluded.admin - RETURNING id, metrics_id::text - ", - ) - .bind(&invite.email_address) - .bind(&user.github_login) - .bind(&user.github_user_id) - .bind(&user.invite_count) - .bind(random_invite_code()) - .fetch_one(&mut tx) - .await?; + let user = user::Entity::insert(user::ActiveModel { + email_address: ActiveValue::set(Some(invite.email_address.clone())), + github_login: ActiveValue::set(user.github_login.clone()), + github_user_id: ActiveValue::set(Some(user.github_user_id)), + admin: ActiveValue::set(false), + invite_count: ActiveValue::set(user.invite_count), + invite_code: ActiveValue::set(Some(random_invite_code())), + metrics_id: ActiveValue::set(Uuid::new_v4()), + ..Default::default() + }) + .on_conflict( + OnConflict::column(user::Column::GithubLogin) + .update_columns([ + user::Column::EmailAddress, + user::Column::GithubUserId, + user::Column::Admin, + ]) + .to_owned(), + ) + .exec_with_returning(&*tx) + .await?; + + let mut signup = signup.into_active_model(); + signup.user_id = ActiveValue::set(Some(user.id)); + let signup = signup.update(&*tx).await?; + + if let Some(inviting_user_id) = signup.inviting_user_id { + contact::Entity::insert(contact::ActiveModel { + user_id_a: ActiveValue::set(inviting_user_id), + user_id_b: ActiveValue::set(user.id), + a_to_b: ActiveValue::set(true), + should_notify: ActiveValue::set(true), + accepted: ActiveValue::set(true), + ..Default::default() + }) + .on_conflict(OnConflict::new().do_nothing().to_owned()) + .exec_without_returning(&*tx) + .await?; + } + + Ok(Some(NewUserResult { + user_id: user.id, + metrics_id: user.metrics_id.to_string(), + inviting_user_id: signup.inviting_user_id, + signup_device_id: signup.device_id, + })) + }) + .await + } + + pub async fn set_invite_count_for_user(&self, id: UserId, count: i32) -> Result<()> { + self.transaction(|tx| async move { + if count > 0 { + user::Entity::update_many() + .filter( + user::Column::Id + .eq(id) + .and(user::Column::InviteCode.is_null()), + ) + .set(user::ActiveModel { + invite_code: ActiveValue::set(Some(random_invite_code())), + ..Default::default() + }) + .exec(&*tx) + .await?; + } + + user::Entity::update_many() + .filter(user::Column::Id.eq(id)) + .set(user::ActiveModel { + invite_count: ActiveValue::set(count), + ..Default::default() + }) + .exec(&*tx) + .await?; + Ok(()) + }) + .await + } + + pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { + self.transaction(|tx| async move { + match user::Entity::find_by_id(id).one(&*tx).await? { + Some(user) if user.invite_code.is_some() => { + Ok(Some((user.invite_code.unwrap(), user.invite_count))) + } + _ => Ok(None), + } + }) + .await + } + + pub async fn get_user_for_invite_code(&self, code: &str) -> Result { + self.transaction(|tx| async move { + user::Entity::find() + .filter(user::Column::InviteCode.eq(code)) + .one(&*tx) + .await? 
+ .ok_or_else(|| { + Error::Http( + StatusCode::NOT_FOUND, + "that invite code does not exist".to_string(), + ) + }) + }) + .await + } + + // rooms + + pub async fn incoming_call_for_user( + &self, + user_id: UserId, + ) -> Result> { + self.transaction(|tx| async move { + let pending_participant = room_participant::Entity::find() + .filter( + room_participant::Column::UserId + .eq(user_id) + .and(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .one(&*tx) + .await?; + + if let Some(pending_participant) = pending_participant { + let room = self.get_room(pending_participant.room_id, &tx).await?; + Ok(Self::build_incoming_call(&room, user_id)) + } else { + Ok(None) + } + }) + .await + } + + pub async fn create_room( + &self, + user_id: UserId, + connection_id: ConnectionId, + live_kit_room: &str, + ) -> Result> { + self.room_transaction(|tx| async move { + let room = room::ActiveModel { + live_kit_room: ActiveValue::set(live_kit_room.into()), + ..Default::default() + } + .insert(&*tx) + .await?; + let room_id = room.id; + + room_participant::ActiveModel { + room_id: ActiveValue::set(room_id), + user_id: ActiveValue::set(user_id), + answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)), + answering_connection_epoch: ActiveValue::set(Some(self.epoch)), + calling_user_id: ActiveValue::set(user_id), + calling_connection_id: ActiveValue::set(connection_id.0 as i32), + calling_connection_epoch: ActiveValue::set(self.epoch), + ..Default::default() + } + .insert(&*tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + Ok((room_id, room)) + }) + .await + } + + pub async fn call( + &self, + room_id: RoomId, + calling_user_id: UserId, + calling_connection_id: ConnectionId, + called_user_id: UserId, + initial_project_id: Option, + ) -> Result> { + self.room_transaction(|tx| async move { + room_participant::ActiveModel { + room_id: ActiveValue::set(room_id), + user_id: ActiveValue::set(called_user_id), + calling_user_id: ActiveValue::set(calling_user_id), + calling_connection_id: ActiveValue::set(calling_connection_id.0 as i32), + calling_connection_epoch: ActiveValue::set(self.epoch), + initial_project_id: ActiveValue::set(initial_project_id), + ..Default::default() + } + .insert(&*tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + let incoming_call = Self::build_incoming_call(&room, called_user_id) + .ok_or_else(|| anyhow!("failed to build incoming call"))?; + Ok((room_id, (room, incoming_call))) + }) + .await + } + + pub async fn call_failed( + &self, + room_id: RoomId, + called_user_id: UserId, + ) -> Result> { + self.room_transaction(|tx| async move { + room_participant::Entity::delete_many() + .filter( + room_participant::Column::RoomId + .eq(room_id) + .and(room_participant::Column::UserId.eq(called_user_id)), + ) + .exec(&*tx) + .await?; + let room = self.get_room(room_id, &tx).await?; + Ok((room_id, room)) + }) + .await + } + + pub async fn decline_call( + &self, + expected_room_id: Option, + user_id: UserId, + ) -> Result> { + self.room_transaction(|tx| async move { + let participant = room_participant::Entity::find() + .filter( + room_participant::Column::UserId + .eq(user_id) + .and(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .one(&*tx) + .await? 
+ .ok_or_else(|| anyhow!("could not decline call"))?; + let room_id = participant.room_id; + + if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) { + return Err(anyhow!("declining call on unexpected room"))?; + } + + room_participant::Entity::delete(participant.into_active_model()) + .exec(&*tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + Ok((room_id, room)) + }) + .await + } + + pub async fn cancel_call( + &self, + expected_room_id: Option, + calling_connection_id: ConnectionId, + called_user_id: UserId, + ) -> Result> { + self.room_transaction(|tx| async move { + let participant = room_participant::Entity::find() + .filter( + room_participant::Column::UserId + .eq(called_user_id) + .and( + room_participant::Column::CallingConnectionId + .eq(calling_connection_id.0 as i32), + ) + .and(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("could not cancel call"))?; + let room_id = participant.room_id; + if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) { + return Err(anyhow!("canceling call on unexpected room"))?; + } + + room_participant::Entity::delete(participant.into_active_model()) + .exec(&*tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + Ok((room_id, room)) + }) + .await + } + + pub async fn join_room( + &self, + room_id: RoomId, + user_id: UserId, + connection_id: ConnectionId, + ) -> Result> { + self.room_transaction(|tx| async move { + let result = room_participant::Entity::update_many() + .filter( + room_participant::Column::RoomId + .eq(room_id) + .and(room_participant::Column::UserId.eq(user_id)) + .and(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .set(room_participant::ActiveModel { + answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)), + answering_connection_epoch: ActiveValue::set(Some(self.epoch)), + ..Default::default() + }) + .exec(&*tx) + .await?; + if result.rows_affected == 0 { + Err(anyhow!("room does not exist or was already joined"))? + } else { + let room = self.get_room(room_id, &tx).await?; + Ok((room_id, room)) + } + }) + .await + } + + pub async fn leave_room(&self, connection_id: ConnectionId) -> Result> { + self.room_transaction(|tx| async move { + let leaving_participant = room_participant::Entity::find() + .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) + .one(&*tx) + .await?; + + if let Some(leaving_participant) = leaving_participant { + // Leave room. + let room_id = leaving_participant.room_id; + room_participant::Entity::delete_by_id(leaving_participant.id) + .exec(&*tx) + .await?; + + // Cancel pending calls initiated by the leaving user. + let called_participants = room_participant::Entity::find() + .filter( + room_participant::Column::CallingConnectionId + .eq(connection_id.0) + .and(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .all(&*tx) + .await?; + room_participant::Entity::delete_many() + .filter( + room_participant::Column::Id + .is_in(called_participants.iter().map(|participant| participant.id)), + ) + .exec(&*tx) + .await?; + let canceled_calls_to_user_ids = called_participants + .into_iter() + .map(|participant| participant.user_id) + .collect(); + + // Detect left projects. 
+ #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryProjectIds { + ProjectId, + } + let project_ids: Vec = project_collaborator::Entity::find() + .select_only() + .column_as( + project_collaborator::Column::ProjectId, + QueryProjectIds::ProjectId, + ) + .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0)) + .into_values::<_, QueryProjectIds>() + .all(&*tx) + .await?; + let mut left_projects = HashMap::default(); + let mut collaborators = project_collaborator::Entity::find() + .filter(project_collaborator::Column::ProjectId.is_in(project_ids)) + .stream(&*tx) + .await?; + while let Some(collaborator) = collaborators.next().await { + let collaborator = collaborator?; + let left_project = + left_projects + .entry(collaborator.project_id) + .or_insert(LeftProject { + id: collaborator.project_id, + host_user_id: Default::default(), + connection_ids: Default::default(), + host_connection_id: Default::default(), + }); + + let collaborator_connection_id = + ConnectionId(collaborator.connection_id as u32); + if collaborator_connection_id != connection_id { + left_project.connection_ids.push(collaborator_connection_id); + } + + if collaborator.is_host { + left_project.host_user_id = collaborator.user_id; + left_project.host_connection_id = + ConnectionId(collaborator.connection_id as u32); + } + } + drop(collaborators); + + // Leave projects. + project_collaborator::Entity::delete_many() + .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0)) + .exec(&*tx) + .await?; - sqlx::query( - " - UPDATE signups - SET user_id = $1 - WHERE id = $2 - ", - ) - .bind(&user_id) - .bind(&signup_id) - .execute(&mut tx) - .await?; + // Unshare projects. + project::Entity::delete_many() + .filter( + project::Column::RoomId + .eq(room_id) + .and(project::Column::HostConnectionId.eq(connection_id.0)), + ) + .exec(&*tx) + .await?; - if let Some(inviting_user_id) = inviting_user_id { - sqlx::query( - " - INSERT INTO contacts - (user_id_a, user_id_b, a_to_b, should_notify, accepted) - VALUES - ($1, $2, TRUE, TRUE, TRUE) - ON CONFLICT DO NOTHING - ", - ) - .bind(inviting_user_id) - .bind(user_id) - .execute(&mut tx) - .await?; - } + let room = self.get_room(room_id, &tx).await?; + if room.participants.is_empty() { + room::Entity::delete_by_id(room_id).exec(&*tx).await?; + } - tx.commit().await?; - Ok(Some(NewUserResult { - user_id, - metrics_id, - inviting_user_id, - signup_device_id, - })) - }) - } + let left_room = LeftRoom { + room, + left_projects, + canceled_calls_to_user_ids, + }; - pub async fn create_signup(&self, signup: &Signup) -> Result<()> { - test_support!(self, { - sqlx::query( - " - INSERT INTO signups - ( - email_address, - email_confirmation_code, - email_confirmation_sent, - platform_linux, - platform_mac, - platform_windows, - platform_unknown, - editor_features, - programming_languages, - device_id, - added_to_mailing_list - ) - VALUES - ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8, $9) - ON CONFLICT (email_address) DO UPDATE SET - email_address = excluded.email_address - RETURNING id - ", - ) - .bind(&signup.email_address) - .bind(&random_email_confirmation_code()) - .bind(&signup.platform_linux) - .bind(&signup.platform_mac) - .bind(&signup.platform_windows) - .bind(&signup.editor_features) - .bind(&signup.programming_languages) - .bind(&signup.device_id) - .bind(&signup.added_to_mailing_list) - .execute(&self.pool) - .await?; - Ok(()) + if left_room.room.participants.is_empty() { + self.rooms.remove(&room_id); + } + + Ok((room_id, left_room)) + } 
else { + Err(anyhow!("could not leave room"))? + } }) + .await } - pub async fn create_invite_from_code( + pub async fn update_room_participant_location( &self, - code: &str, - email_address: &str, - device_id: Option<&str>, - ) -> Result { - test_support!(self, { - let mut tx = self.pool.begin().await?; - - let existing_user: Option = sqlx::query_scalar( - " - SELECT id - FROM users - WHERE email_address = $1 - ", - ) - .bind(email_address) - .fetch_optional(&mut tx) - .await?; - if existing_user.is_some() { - Err(anyhow!("email address is already in use"))?; + room_id: RoomId, + connection_id: ConnectionId, + location: proto::ParticipantLocation, + ) -> Result> { + self.room_transaction(|tx| async { + let tx = tx; + let location_kind; + let location_project_id; + match location + .variant + .as_ref() + .ok_or_else(|| anyhow!("invalid location"))? + { + proto::participant_location::Variant::SharedProject(project) => { + location_kind = 0; + location_project_id = Some(ProjectId::from_proto(project.id)); + } + proto::participant_location::Variant::UnsharedProject(_) => { + location_kind = 1; + location_project_id = None; + } + proto::participant_location::Variant::External(_) => { + location_kind = 2; + location_project_id = None; + } } - let inviting_user_id_with_invites: Option = sqlx::query_scalar( - " - UPDATE users - SET invite_count = invite_count - 1 - WHERE invite_code = $1 AND invite_count > 0 - RETURNING id - ", - ) - .bind(code) - .fetch_optional(&mut tx) - .await?; - - let Some(inviter_id) = inviting_user_id_with_invites else { - return Err(Error::Http( - StatusCode::UNAUTHORIZED, - "unable to find an invite code with invites remaining".to_string(), - )); - }; - - let email_confirmation_code: String = sqlx::query_scalar( - " - INSERT INTO signups - ( - email_address, - email_confirmation_code, - email_confirmation_sent, - inviting_user_id, - platform_linux, - platform_mac, - platform_windows, - platform_unknown, - device_id + let result = room_participant::Entity::update_many() + .filter( + room_participant::Column::RoomId + .eq(room_id) + .and(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)), ) - VALUES - ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4) - ON CONFLICT (email_address) - DO UPDATE SET - inviting_user_id = excluded.inviting_user_id - RETURNING email_confirmation_code - ", - ) - .bind(&email_address) - .bind(&random_email_confirmation_code()) - .bind(&inviter_id) - .bind(&device_id) - .fetch_one(&mut tx) - .await?; - - tx.commit().await?; + .set(room_participant::ActiveModel { + location_kind: ActiveValue::set(Some(location_kind)), + location_project_id: ActiveValue::set(location_project_id), + ..Default::default() + }) + .exec(&*tx) + .await?; - Ok(Invite { - email_address: email_address.into(), - email_confirmation_code, - }) + if result.rows_affected == 1 { + let room = self.get_room(room_id, &tx).await?; + Ok((room_id, room)) + } else { + Err(anyhow!("could not update room participant location"))? 
+ } }) + .await } - pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { - test_support!(self, { - let emails = invites + fn build_incoming_call( + room: &proto::Room, + called_user_id: UserId, + ) -> Option { + let pending_participant = room + .pending_participants + .iter() + .find(|participant| participant.user_id == called_user_id.to_proto())?; + + Some(proto::IncomingCall { + room_id: room.id, + calling_user_id: pending_participant.calling_user_id, + participant_user_ids: room + .participants .iter() - .map(|s| s.email_address.as_str()) - .collect::>(); - sqlx::query( - " - UPDATE signups - SET email_confirmation_sent = TRUE - WHERE email_address = ANY ($1) - ", - ) - .bind(&emails) - .execute(&self.pool) - .await?; - Ok(()) + .map(|participant| participant.user_id) + .collect(), + initial_project: room.participants.iter().find_map(|participant| { + let initial_project_id = pending_participant.initial_project_id?; + participant + .projects + .iter() + .find(|project| project.id == initial_project_id) + .cloned() + }), }) } -} - -impl Db -where - D: sqlx::Database + sqlx::migrate::MigrateDatabase, - D::Connection: sqlx::migrate::Migrate, - for<'a> >::Arguments: sqlx::IntoArguments<'a, D>, - for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>, - for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>, - D::QueryResult: RowsAffected, - String: sqlx::Type, - i32: sqlx::Type, - i64: sqlx::Type, - bool: sqlx::Type, - str: sqlx::Type, - Uuid: sqlx::Type, - sqlx::types::Json: sqlx::Type, - OffsetDateTime: sqlx::Type, - PrimitiveDateTime: sqlx::Type, - usize: sqlx::ColumnIndex, - for<'a> &'a str: sqlx::ColumnIndex, - for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> Option: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>, -{ - pub async fn migrate( - &self, - migrations_path: &Path, - ignore_checksum_mismatch: bool, - ) -> anyhow::Result> { - let migrations = MigrationSource::resolve(migrations_path) - .await - .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; - - let mut conn = self.pool.acquire().await?; - conn.ensure_migrations_table().await?; - let applied_migrations: HashMap<_, _> = conn - .list_applied_migrations() + async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result { + let db_room = room::Entity::find_by_id(room_id) + .one(tx) .await? 
- .into_iter() - .map(|m| (m.version, m)) - .collect(); + .ok_or_else(|| anyhow!("could not find room"))?; - let mut new_migrations = Vec::new(); - for migration in migrations { - match applied_migrations.get(&migration.version) { - Some(applied_migration) => { - if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch - { - Err(anyhow!( - "checksum mismatch for applied migration {}", - migration.description - ))?; + let mut db_participants = db_room + .find_related(room_participant::Entity) + .stream(tx) + .await?; + let mut participants = HashMap::default(); + let mut pending_participants = Vec::new(); + while let Some(db_participant) = db_participants.next().await { + let db_participant = db_participant?; + if let Some(answering_connection_id) = db_participant.answering_connection_id { + let location = match ( + db_participant.location_kind, + db_participant.location_project_id, + ) { + (Some(0), Some(project_id)) => { + Some(proto::participant_location::Variant::SharedProject( + proto::participant_location::SharedProject { + id: project_id.to_proto(), + }, + )) } - } - None => { - let elapsed = conn.apply(&migration).await?; - new_migrations.push((migration, elapsed)); - } + (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject( + Default::default(), + )), + _ => Some(proto::participant_location::Variant::External( + Default::default(), + )), + }; + participants.insert( + answering_connection_id, + proto::Participant { + user_id: db_participant.user_id.to_proto(), + peer_id: answering_connection_id as u32, + projects: Default::default(), + location: Some(proto::ParticipantLocation { variant: location }), + }, + ); + } else { + pending_participants.push(proto::PendingParticipant { + user_id: db_participant.user_id.to_proto(), + calling_user_id: db_participant.calling_user_id.to_proto(), + initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()), + }); } } + drop(db_participants); - Ok(new_migrations) - } + let mut db_projects = db_room + .find_related(project::Entity) + .find_with_related(worktree::Entity) + .stream(tx) + .await?; - pub fn fuzzy_like_string(string: &str) -> String { - let mut result = String::with_capacity(string.len() * 2 + 1); - for c in string.chars() { - if c.is_alphanumeric() { - result.push('%'); - result.push(c); + while let Some(row) = db_projects.next().await { + let (db_project, db_worktree) = row?; + if let Some(participant) = participants.get_mut(&db_project.host_connection_id) { + let project = if let Some(project) = participant + .projects + .iter_mut() + .find(|project| project.id == db_project.id.to_proto()) + { + project + } else { + participant.projects.push(proto::ParticipantProject { + id: db_project.id.to_proto(), + worktree_root_names: Default::default(), + }); + participant.projects.last_mut().unwrap() + }; + + if let Some(db_worktree) = db_worktree { + project.worktree_root_names.push(db_worktree.root_name); + } } } - result.push('%'); - result - } - - // users - pub async fn get_all_users(&self, page: u32, limit: u32) -> Result> { - test_support!(self, { - let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2"; - Ok(sqlx::query_as(query) - .bind(limit as i32) - .bind((page * limit) as i32) - .fetch_all(&self.pool) - .await?) 
+ Ok(proto::Room { + id: db_room.id.to_proto(), + live_kit_room: db_room.live_kit_room, + participants: participants.into_values().collect(), + pending_participants, }) } - pub async fn get_user_by_id(&self, id: UserId) -> Result> { - test_support!(self, { - let query = " - SELECT users.* - FROM users - WHERE id = $1 - LIMIT 1 - "; - Ok(sqlx::query_as(query) - .bind(&id) - .fetch_optional(&self.pool) - .await?) - }) - } + // projects - pub async fn get_users_with_no_invites( - &self, - invited_by_another_user: bool, - ) -> Result> { - test_support!(self, { - let query = format!( - " - SELECT users.* - FROM users - WHERE invite_count = 0 - AND inviter_id IS{} NULL - ", - if invited_by_another_user { " NOT" } else { "" } - ); + pub async fn project_count_excluding_admins(&self) -> Result { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + Count, + } - Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?) + self.transaction(|tx| async move { + Ok(project::Entity::find() + .select_only() + .column_as(project::Column::Id.count(), QueryAs::Count) + .inner_join(user::Entity) + .filter(user::Column::Admin.eq(false)) + .into_values::<_, QueryAs>() + .one(&*tx) + .await? + .unwrap_or(0) as usize) }) + .await } - pub async fn get_user_by_github_account( + pub async fn share_project( &self, - github_login: &str, - github_user_id: Option, - ) -> Result> { - test_support!(self, { - if let Some(github_user_id) = github_user_id { - let mut user = sqlx::query_as::<_, User>( - " - UPDATE users - SET github_login = $1 - WHERE github_user_id = $2 - RETURNING * - ", - ) - .bind(github_login) - .bind(github_user_id) - .fetch_optional(&self.pool) - .await?; + room_id: RoomId, + connection_id: ConnectionId, + worktrees: &[proto::WorktreeMetadata], + ) -> Result> { + self.room_transaction(|tx| async move { + let participant = room_participant::Entity::find() + .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) + .one(&*tx) + .await? 
+ .ok_or_else(|| anyhow!("could not find participant"))?; + if participant.room_id != room_id { + return Err(anyhow!("shared project on unexpected room"))?; + } - if user.is_none() { - user = sqlx::query_as::<_, User>( - " - UPDATE users - SET github_user_id = $1 - WHERE github_login = $2 - RETURNING * - ", - ) - .bind(github_user_id) - .bind(github_login) - .fetch_optional(&self.pool) - .await?; - } + let project = project::ActiveModel { + room_id: ActiveValue::set(participant.room_id), + host_user_id: ActiveValue::set(participant.user_id), + host_connection_id: ActiveValue::set(connection_id.0 as i32), + host_connection_epoch: ActiveValue::set(self.epoch), + ..Default::default() + } + .insert(&*tx) + .await?; - Ok(user) - } else { - let user = sqlx::query_as( - " - SELECT * FROM users - WHERE github_login = $1 - LIMIT 1 - ", - ) - .bind(github_login) - .fetch_optional(&self.pool) + if !worktrees.is_empty() { + worktree::Entity::insert_many(worktrees.iter().map(|worktree| { + worktree::ActiveModel { + id: ActiveValue::set(worktree.id as i64), + project_id: ActiveValue::set(project.id), + abs_path: ActiveValue::set(worktree.abs_path.clone()), + root_name: ActiveValue::set(worktree.root_name.clone()), + visible: ActiveValue::set(worktree.visible), + scan_id: ActiveValue::set(0), + is_complete: ActiveValue::set(false), + } + })) + .exec(&*tx) .await?; - Ok(user) } - }) - } - pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { - test_support!(self, { - let query = "UPDATE users SET admin = $1 WHERE id = $2"; - Ok(sqlx::query(query) - .bind(is_admin) - .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?) - }) - } + project_collaborator::ActiveModel { + project_id: ActiveValue::set(project.id), + connection_id: ActiveValue::set(connection_id.0 as i32), + connection_epoch: ActiveValue::set(self.epoch), + user_id: ActiveValue::set(participant.user_id), + replica_id: ActiveValue::set(ReplicaId(0)), + is_host: ActiveValue::set(true), + ..Default::default() + } + .insert(&*tx) + .await?; - pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { - test_support!(self, { - let query = "UPDATE users SET connected_once = $1 WHERE id = $2"; - Ok(sqlx::query(query) - .bind(connected_once) - .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?) + let room = self.get_room(room_id, &tx).await?; + Ok((room_id, (project.id, room))) }) + .await } - pub async fn destroy_user(&self, id: UserId) -> Result<()> { - test_support!(self, { - let query = "DELETE FROM access_tokens WHERE user_id = $1;"; - sqlx::query(query) - .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?; - let query = "DELETE FROM users WHERE id = $1;"; - Ok(sqlx::query(query) - .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?) + pub async fn unshare_project( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result)>> { + self.room_transaction(|tx| async move { + let guest_connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; + + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("project not found"))?; + if project.host_connection_id == connection_id.0 as i32 { + let room_id = project.room_id; + project::Entity::delete(project.into_active_model()) + .exec(&*tx) + .await?; + let room = self.get_room(room_id, &tx).await?; + Ok((room_id, (room, guest_connection_ids))) + } else { + Err(anyhow!("cannot unshare a project hosted by another user"))? 
+ } }) + .await } - // signups + pub async fn update_project( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + worktrees: &[proto::WorktreeMetadata], + ) -> Result)>> { + self.room_transaction(|tx| async move { + let project = project::Entity::find_by_id(project_id) + .filter(project::Column::HostConnectionId.eq(connection_id.0)) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + + if !worktrees.is_empty() { + worktree::Entity::insert_many(worktrees.iter().map(|worktree| { + worktree::ActiveModel { + id: ActiveValue::set(worktree.id as i64), + project_id: ActiveValue::set(project.id), + abs_path: ActiveValue::set(worktree.abs_path.clone()), + root_name: ActiveValue::set(worktree.root_name.clone()), + visible: ActiveValue::set(worktree.visible), + scan_id: ActiveValue::set(0), + is_complete: ActiveValue::set(false), + } + })) + .on_conflict( + OnConflict::columns([worktree::Column::ProjectId, worktree::Column::Id]) + .update_column(worktree::Column::RootName) + .to_owned(), + ) + .exec(&*tx) + .await?; + } - pub async fn get_waitlist_summary(&self) -> Result { - test_support!(self, { - Ok(sqlx::query_as( - " - SELECT - COUNT(*) as count, - COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count, - COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count, - COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count, - COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count - FROM ( - SELECT * - FROM signups - WHERE - NOT email_confirmation_sent - ) AS unsent - ", - ) - .fetch_one(&self.pool) - .await?) - }) - } + worktree::Entity::delete_many() + .filter( + worktree::Column::ProjectId.eq(project.id).and( + worktree::Column::Id + .is_not_in(worktrees.iter().map(|worktree| worktree.id as i64)), + ), + ) + .exec(&*tx) + .await?; - pub async fn get_unsent_invites(&self, count: usize) -> Result> { - test_support!(self, { - Ok(sqlx::query_as( - " - SELECT - email_address, email_confirmation_code - FROM signups - WHERE - NOT email_confirmation_sent AND - (platform_mac OR platform_unknown) - ORDER BY - created_at - LIMIT $1 - ", - ) - .bind(count as i32) - .fetch_all(&self.pool) - .await?) + let guest_connection_ids = self.project_guest_connection_ids(project.id, &tx).await?; + let room = self.get_room(project.room_id, &tx).await?; + Ok((project.room_id, (room, guest_connection_ids))) }) + .await } - // invite codes + pub async fn update_worktree( + &self, + update: &proto::UpdateWorktree, + connection_id: ConnectionId, + ) -> Result>> { + self.room_transaction(|tx| async move { + let project_id = ProjectId::from_proto(update.project_id); + let worktree_id = update.worktree_id as i64; + + // Ensure the update comes from the host. + let project = project::Entity::find_by_id(project_id) + .filter(project::Column::HostConnectionId.eq(connection_id.0)) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + let room_id = project.room_id; + + // Update metadata. 
+ worktree::Entity::update(worktree::ActiveModel { + id: ActiveValue::set(worktree_id), + project_id: ActiveValue::set(project_id), + root_name: ActiveValue::set(update.root_name.clone()), + scan_id: ActiveValue::set(update.scan_id as i64), + is_complete: ActiveValue::set(update.is_last_update), + abs_path: ActiveValue::set(update.abs_path.clone()), + ..Default::default() + }) + .exec(&*tx) + .await?; - pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { - test_support!(self, { - let mut tx = self.pool.begin().await?; - if count > 0 { - sqlx::query( - " - UPDATE users - SET invite_code = $1 - WHERE id = $2 AND invite_code IS NULL - ", + if !update.updated_entries.is_empty() { + worktree_entry::Entity::insert_many(update.updated_entries.iter().map(|entry| { + let mtime = entry.mtime.clone().unwrap_or_default(); + worktree_entry::ActiveModel { + project_id: ActiveValue::set(project_id), + worktree_id: ActiveValue::set(worktree_id), + id: ActiveValue::set(entry.id as i64), + is_dir: ActiveValue::set(entry.is_dir), + path: ActiveValue::set(entry.path.clone()), + inode: ActiveValue::set(entry.inode as i64), + mtime_seconds: ActiveValue::set(mtime.seconds as i64), + mtime_nanos: ActiveValue::set(mtime.nanos as i32), + is_symlink: ActiveValue::set(entry.is_symlink), + is_ignored: ActiveValue::set(entry.is_ignored), + } + })) + .on_conflict( + OnConflict::columns([ + worktree_entry::Column::ProjectId, + worktree_entry::Column::WorktreeId, + worktree_entry::Column::Id, + ]) + .update_columns([ + worktree_entry::Column::IsDir, + worktree_entry::Column::Path, + worktree_entry::Column::Inode, + worktree_entry::Column::MtimeSeconds, + worktree_entry::Column::MtimeNanos, + worktree_entry::Column::IsSymlink, + worktree_entry::Column::IsIgnored, + ]) + .to_owned(), ) - .bind(random_invite_code()) - .bind(id) - .execute(&mut tx) + .exec(&*tx) .await?; } - sqlx::query( - " - UPDATE users - SET invite_count = $1 - WHERE id = $2 - ", - ) - .bind(count as i32) - .bind(id) - .execute(&mut tx) - .await?; - tx.commit().await?; - Ok(()) - }) - } - - pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { - test_support!(self, { - let result: Option<(String, i32)> = sqlx::query_as( - " - SELECT invite_code, invite_count - FROM users - WHERE id = $1 AND invite_code IS NOT NULL - ", - ) - .bind(id) - .fetch_optional(&self.pool) - .await?; - if let Some((code, count)) = result { - Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?))) - } else { - Ok(None) + if !update.removed_entries.is_empty() { + worktree_entry::Entity::delete_many() + .filter( + worktree_entry::Column::ProjectId + .eq(project_id) + .and(worktree_entry::Column::WorktreeId.eq(worktree_id)) + .and( + worktree_entry::Column::Id + .is_in(update.removed_entries.iter().map(|id| *id as i64)), + ), + ) + .exec(&*tx) + .await?; } - }) - } - pub async fn get_user_for_invite_code(&self, code: &str) -> Result { - test_support!(self, { - sqlx::query_as( - " - SELECT * - FROM users - WHERE invite_code = $1 - ", - ) - .bind(code) - .fetch_optional(&self.pool) - .await? 
- .ok_or_else(|| { - Error::Http( - StatusCode::NOT_FOUND, - "that invite code does not exist".to_string(), - ) - }) + let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; + Ok((room_id, connection_ids)) }) + .await } - // projects + pub async fn update_diagnostic_summary( + &self, + update: &proto::UpdateDiagnosticSummary, + connection_id: ConnectionId, + ) -> Result>> { + self.room_transaction(|tx| async move { + let project_id = ProjectId::from_proto(update.project_id); + let worktree_id = update.worktree_id as i64; + let summary = update + .summary + .as_ref() + .ok_or_else(|| anyhow!("invalid summary"))?; + + // Ensure the update comes from the host. + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + if project.host_connection_id != connection_id.0 as i32 { + return Err(anyhow!("can't update a project hosted by someone else"))?; + } - /// Registers a new project for the given user. - pub async fn register_project(&self, host_user_id: UserId) -> Result { - test_support!(self, { - Ok(sqlx::query_scalar( - " - INSERT INTO projects(host_user_id) - VALUES ($1) - RETURNING id - ", + // Update summary. + worktree_diagnostic_summary::Entity::insert(worktree_diagnostic_summary::ActiveModel { + project_id: ActiveValue::set(project_id), + worktree_id: ActiveValue::set(worktree_id), + path: ActiveValue::set(summary.path.clone()), + language_server_id: ActiveValue::set(summary.language_server_id as i64), + error_count: ActiveValue::set(summary.error_count as i32), + warning_count: ActiveValue::set(summary.warning_count as i32), + ..Default::default() + }) + .on_conflict( + OnConflict::columns([ + worktree_diagnostic_summary::Column::ProjectId, + worktree_diagnostic_summary::Column::WorktreeId, + worktree_diagnostic_summary::Column::Path, + ]) + .update_columns([ + worktree_diagnostic_summary::Column::LanguageServerId, + worktree_diagnostic_summary::Column::ErrorCount, + worktree_diagnostic_summary::Column::WarningCount, + ]) + .to_owned(), ) - .bind(host_user_id) - .fetch_one(&self.pool) - .await - .map(ProjectId)?) + .exec(&*tx) + .await?; + + let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; + Ok((project.room_id, connection_ids)) }) + .await } - /// Unregisters a project for the given project id. - pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> { - test_support!(self, { - sqlx::query( - " - UPDATE projects - SET unregistered = TRUE - WHERE id = $1 - ", + pub async fn start_language_server( + &self, + update: &proto::StartLanguageServer, + connection_id: ConnectionId, + ) -> Result>> { + self.room_transaction(|tx| async move { + let project_id = ProjectId::from_proto(update.project_id); + let server = update + .server + .as_ref() + .ok_or_else(|| anyhow!("invalid language server"))?; + + // Ensure the update comes from the host. + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + if project.host_connection_id != connection_id.0 as i32 { + return Err(anyhow!("can't update a project hosted by someone else"))?; + } + + // Add the newly-started language server. 
+ language_server::Entity::insert(language_server::ActiveModel { + project_id: ActiveValue::set(project_id), + id: ActiveValue::set(server.id as i64), + name: ActiveValue::set(server.name.clone()), + ..Default::default() + }) + .on_conflict( + OnConflict::columns([ + language_server::Column::ProjectId, + language_server::Column::Id, + ]) + .update_column(language_server::Column::Name) + .to_owned(), ) - .bind(project_id) - .execute(&self.pool) + .exec(&*tx) .await?; - Ok(()) + + let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; + Ok((project.room_id, connection_ids)) }) + .await } - // contacts - - pub async fn get_contacts(&self, user_id: UserId) -> Result> { - test_support!(self, { - let query = " - SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify - FROM contacts - WHERE user_id_a = $1 OR user_id_b = $1; - "; + pub async fn join_project( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result> { + self.room_transaction(|tx| async move { + let participant = room_participant::Entity::find() + .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("must join a room first"))?; - let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query) - .bind(user_id) - .fetch(&self.pool); + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + if project.room_id != participant.room_id { + return Err(anyhow!("no such project"))?; + } - let mut contacts = Vec::new(); - while let Some(row) = rows.next().await { - let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?; + let mut collaborators = project + .find_related(project_collaborator::Entity) + .all(&*tx) + .await?; + let replica_ids = collaborators + .iter() + .map(|c| c.replica_id) + .collect::>(); + let mut replica_id = ReplicaId(1); + while replica_ids.contains(&replica_id) { + replica_id.0 += 1; + } + let new_collaborator = project_collaborator::ActiveModel { + project_id: ActiveValue::set(project_id), + connection_id: ActiveValue::set(connection_id.0 as i32), + connection_epoch: ActiveValue::set(self.epoch), + user_id: ActiveValue::set(participant.user_id), + replica_id: ActiveValue::set(replica_id), + is_host: ActiveValue::set(false), + ..Default::default() + } + .insert(&*tx) + .await?; + collaborators.push(new_collaborator); - if user_id_a == user_id { - if accepted { - contacts.push(Contact::Accepted { - user_id: user_id_b, - should_notify: should_notify && a_to_b, - }); - } else if a_to_b { - contacts.push(Contact::Outgoing { user_id: user_id_b }) - } else { - contacts.push(Contact::Incoming { - user_id: user_id_b, - should_notify, + let db_worktrees = project.find_related(worktree::Entity).all(&*tx).await?; + let mut worktrees = db_worktrees + .into_iter() + .map(|db_worktree| { + ( + db_worktree.id as u64, + Worktree { + id: db_worktree.id as u64, + abs_path: db_worktree.abs_path, + root_name: db_worktree.root_name, + visible: db_worktree.visible, + entries: Default::default(), + diagnostic_summaries: Default::default(), + scan_id: db_worktree.scan_id as u64, + is_complete: db_worktree.is_complete, + }, + ) + }) + .collect::>(); + + // Populate worktree entries. 
+ { + let mut db_entries = worktree_entry::Entity::find() + .filter(worktree_entry::Column::ProjectId.eq(project_id)) + .stream(&*tx) + .await?; + while let Some(db_entry) = db_entries.next().await { + let db_entry = db_entry?; + if let Some(worktree) = worktrees.get_mut(&(db_entry.worktree_id as u64)) { + worktree.entries.push(proto::Entry { + id: db_entry.id as u64, + is_dir: db_entry.is_dir, + path: db_entry.path, + inode: db_entry.inode as u64, + mtime: Some(proto::Timestamp { + seconds: db_entry.mtime_seconds as u64, + nanos: db_entry.mtime_nanos as u32, + }), + is_symlink: db_entry.is_symlink, + is_ignored: db_entry.is_ignored, }); } - } else if accepted { - contacts.push(Contact::Accepted { - user_id: user_id_a, - should_notify: should_notify && !a_to_b, - }); - } else if a_to_b { - contacts.push(Contact::Incoming { - user_id: user_id_a, - should_notify, - }); - } else { - contacts.push(Contact::Outgoing { user_id: user_id_a }); } } - contacts.sort_unstable_by_key(|contact| contact.user_id()); - - Ok(contacts) - }) - } + // Populate worktree diagnostic summaries. + { + let mut db_summaries = worktree_diagnostic_summary::Entity::find() + .filter(worktree_diagnostic_summary::Column::ProjectId.eq(project_id)) + .stream(&*tx) + .await?; + while let Some(db_summary) = db_summaries.next().await { + let db_summary = db_summary?; + if let Some(worktree) = worktrees.get_mut(&(db_summary.worktree_id as u64)) { + worktree + .diagnostic_summaries + .push(proto::DiagnosticSummary { + path: db_summary.path, + language_server_id: db_summary.language_server_id as u64, + error_count: db_summary.error_count as u32, + warning_count: db_summary.warning_count as u32, + }); + } + } + } + + // Populate language servers. + let language_servers = project + .find_related(language_server::Entity) + .all(&*tx) + .await?; - pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { - test_support!(self, { - let (id_a, id_b) = if user_id_1 < user_id_2 { - (user_id_1, user_id_2) - } else { - (user_id_2, user_id_1) + let room_id = project.room_id; + let project = Project { + collaborators, + worktrees, + language_servers: language_servers + .into_iter() + .map(|language_server| proto::LanguageServer { + id: language_server.id as u64, + name: language_server.name, + }) + .collect(), }; - - let query = " - SELECT 1 FROM contacts - WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE - LIMIT 1 - "; - Ok(sqlx::query_scalar::<_, i32>(query) - .bind(id_a.0) - .bind(id_b.0) - .fetch_optional(&self.pool) - .await? 
- .is_some()) + Ok((room_id, (project, replica_id as ReplicaId))) }) + .await } - pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { - test_support!(self, { - let (id_a, id_b, a_to_b) = if sender_id < receiver_id { - (sender_id, receiver_id, true) - } else { - (receiver_id, sender_id, false) - }; - let query = " - INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify) - VALUES ($1, $2, $3, FALSE, TRUE) - ON CONFLICT (user_id_a, user_id_b) DO UPDATE - SET - accepted = TRUE, - should_notify = FALSE - WHERE - NOT contacts.accepted AND - ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR - (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a)); - "; - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&self.pool) + pub async fn leave_project( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result> { + self.room_transaction(|tx| async move { + let result = project_collaborator::Entity::delete_many() + .filter( + project_collaborator::Column::ProjectId + .eq(project_id) + .and(project_collaborator::Column::ConnectionId.eq(connection_id.0)), + ) + .exec(&*tx) .await?; - - if result.rows_affected() == 1 { - Ok(()) - } else { - Err(anyhow!("contact already requested"))? + if result.rows_affected == 0 { + Err(anyhow!("not a collaborator on this project"))?; } + + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + let collaborators = project + .find_related(project_collaborator::Entity) + .all(&*tx) + .await?; + let connection_ids = collaborators + .into_iter() + .map(|collaborator| ConnectionId(collaborator.connection_id as u32)) + .collect(); + + let left_project = LeftProject { + id: project_id, + host_user_id: project.host_user_id, + host_connection_id: ConnectionId(project.host_connection_id as u32), + connection_ids, + }; + Ok((project.room_id, left_project)) }) + .await } - pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { - test_support!(self, { - let (id_a, id_b) = if responder_id < requester_id { - (responder_id, requester_id) - } else { - (requester_id, responder_id) - }; - let query = " - DELETE FROM contacts - WHERE user_id_a = $1 AND user_id_b = $2; - "; - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .execute(&self.pool) + pub async fn project_collaborators( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result>> { + self.room_transaction(|tx| async move { + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + let collaborators = project_collaborator::Entity::find() + .filter(project_collaborator::Column::ProjectId.eq(project_id)) + .all(&*tx) .await?; - if result.rows_affected() == 1 { - Ok(()) + if collaborators + .iter() + .any(|collaborator| collaborator.connection_id == connection_id.0 as i32) + { + Ok((project.room_id, collaborators)) } else { - Err(anyhow!("no such contact"))? + Err(anyhow!("no such project"))? 
             }
         })
+        .await
     }
 
-    pub async fn dismiss_contact_notification(
+    pub async fn project_connection_ids(
         &self,
-        user_id: UserId,
-        contact_user_id: UserId,
-    ) -> Result<()> {
-        test_support!(self, {
-            let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
-                (user_id, contact_user_id, true)
-            } else {
-                (contact_user_id, user_id, false)
-            };
-
-            let query = "
-                UPDATE contacts
-                SET should_notify = FALSE
-                WHERE
-                    user_id_a = $1 AND user_id_b = $2 AND
-                    (
-                        (a_to_b = $3 AND accepted) OR
-                        (a_to_b != $3 AND NOT accepted)
-                    );
-            ";
+        project_id: ProjectId,
+        connection_id: ConnectionId,
+    ) -> Result<RoomGuard<HashSet<ConnectionId>>> {
+        self.room_transaction(|tx| async move {
+            #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+            enum QueryAs {
+                ConnectionId,
+            }
-            let result = sqlx::query(query)
-                .bind(id_a.0)
-                .bind(id_b.0)
-                .bind(a_to_b)
-                .execute(&self.pool)
+            let project = project::Entity::find_by_id(project_id)
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("no such project"))?;
+            let mut db_connection_ids = project_collaborator::Entity::find()
+                .select_only()
+                .column_as(
+                    project_collaborator::Column::ConnectionId,
+                    QueryAs::ConnectionId,
+                )
+                .filter(project_collaborator::Column::ProjectId.eq(project_id))
+                .into_values::<i32, QueryAs>()
+                .stream(&*tx)
                 .await?;
-            if result.rows_affected() == 0 {
-                Err(anyhow!("no such contact request"))?;
+            let mut connection_ids = HashSet::default();
+            while let Some(connection_id) = db_connection_ids.next().await {
+                connection_ids.insert(ConnectionId(connection_id? as u32));
             }
-            Ok(())
+            if connection_ids.contains(&connection_id) {
+                Ok((project.room_id, connection_ids))
+            } else {
+                Err(anyhow!("no such project"))?
+            }
         })
+        .await
     }
 
-    pub async fn respond_to_contact_request(
+    async fn project_guest_connection_ids(
         &self,
-        responder_id: UserId,
-        requester_id: UserId,
-        accept: bool,
-    ) -> Result<()> {
-        test_support!(self, {
-            let (id_a, id_b, a_to_b) = if responder_id < requester_id {
-                (responder_id, requester_id, false)
-            } else {
-                (requester_id, responder_id, true)
-            };
-            let result = if accept {
-                let query = "
-                    UPDATE contacts
-                    SET accepted = TRUE, should_notify = TRUE
-                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
-                ";
-                sqlx::query(query)
-                    .bind(id_a.0)
-                    .bind(id_b.0)
-                    .bind(a_to_b)
-                    .execute(&self.pool)
-                    .await?
-            } else {
-                let query = "
-                    DELETE FROM contacts
-                    WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
-                ";
-                sqlx::query(query)
-                    .bind(id_a.0)
-                    .bind(id_b.0)
-                    .bind(a_to_b)
-                    .execute(&self.pool)
-                    .await?
-            };
-            if result.rows_affected() == 1 {
-                Ok(())
-            } else {
-                Err(anyhow!("no such contact request"))?
-            }
-        })
+        project_id: ProjectId,
+        tx: &DatabaseTransaction,
+    ) -> Result<Vec<ConnectionId>> {
+        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+        enum QueryAs {
+            ConnectionId,
+        }
+
+        let mut db_guest_connection_ids = project_collaborator::Entity::find()
+            .select_only()
+            .column_as(
+                project_collaborator::Column::ConnectionId,
+                QueryAs::ConnectionId,
+            )
+            .filter(
+                project_collaborator::Column::ProjectId
+                    .eq(project_id)
+                    .and(project_collaborator::Column::IsHost.eq(false)),
+            )
+            .into_values::<i32, QueryAs>()
+            .stream(tx)
+            .await?;
+
+        let mut guest_connection_ids = Vec::new();
+        while let Some(connection_id) = db_guest_connection_ids.next().await {
+            guest_connection_ids.push(ConnectionId(connection_id?
as u32)); + } + Ok(guest_connection_ids) } // access tokens @@ -1122,54 +2084,246 @@ where access_token_hash: &str, max_access_token_count: usize, ) -> Result<()> { - test_support!(self, { - let insert_query = " - INSERT INTO access_tokens (user_id, hash) - VALUES ($1, $2); - "; - let cleanup_query = " - DELETE FROM access_tokens - WHERE id IN ( - SELECT id from access_tokens - WHERE user_id = $1 - ORDER BY id DESC - LIMIT 10000 - OFFSET $3 - ) - "; + self.transaction(|tx| async { + let tx = tx; - let mut tx = self.pool.begin().await?; - sqlx::query(insert_query) - .bind(user_id.0) - .bind(access_token_hash) - .execute(&mut tx) - .await?; - sqlx::query(cleanup_query) - .bind(user_id.0) - .bind(access_token_hash) - .bind(max_access_token_count as i32) - .execute(&mut tx) + access_token::ActiveModel { + user_id: ActiveValue::set(user_id), + hash: ActiveValue::set(access_token_hash.into()), + ..Default::default() + } + .insert(&*tx) + .await?; + + access_token::Entity::delete_many() + .filter( + access_token::Column::Id.in_subquery( + Query::select() + .column(access_token::Column::Id) + .from(access_token::Entity) + .and_where(access_token::Column::UserId.eq(user_id)) + .order_by(access_token::Column::Id, sea_orm::Order::Desc) + .limit(10000) + .offset(max_access_token_count as u64) + .to_owned(), + ), + ) + .exec(&*tx) .await?; - Ok(tx.commit().await?) + Ok(()) }) + .await } pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { - test_support!(self, { - let query = " - SELECT hash - FROM access_tokens - WHERE user_id = $1 - ORDER BY id DESC - "; - Ok(sqlx::query_scalar(query) - .bind(user_id.0) - .fetch_all(&self.pool) + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + Hash, + } + + self.transaction(|tx| async move { + Ok(access_token::Entity::find() + .select_only() + .column(access_token::Column::Hash) + .filter(access_token::Column::UserId.eq(user_id)) + .order_by_desc(access_token::Column::Id) + .into_values::<_, QueryAs>() + .all(&*tx) .await?) 
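+            // `into_values::<_, QueryAs>()` decodes the single selected `hash` column
+            // straight into plain values, so the result is a list of hashes rather than
+            // full `access_token::Model` rows.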
         })
+        .await
+    }
+
+    async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
+    where
+        F: Send + Fn(TransactionHandle) -> Fut,
+        Fut: Send + Future<Output = Result<T>>,
+    {
+        let body = async {
+            loop {
+                let (tx, result) = self.with_transaction(&f).await?;
+                match result {
+                    Ok(result) => {
+                        match tx.commit().await.map_err(Into::into) {
+                            Ok(()) => return Ok(result),
+                            Err(error) => {
+                                if is_serialization_error(&error) {
+                                    // Retry (don't break the loop)
+                                } else {
+                                    return Err(error);
+                                }
+                            }
+                        }
+                    }
+                    Err(error) => {
+                        tx.rollback().await?;
+                        if is_serialization_error(&error) {
+                            // Retry (don't break the loop)
+                        } else {
+                            return Err(error);
+                        }
+                    }
+                }
+            }
+        };
+
+        self.run(body).await
+    }
+
+    async fn room_transaction<F, Fut, T>(&self, f: F) -> Result<RoomGuard<T>>
+    where
+        F: Send + Fn(TransactionHandle) -> Fut,
+        Fut: Send + Future<Output = Result<(RoomId, T)>>,
+    {
+        let body = async {
+            loop {
+                let (tx, result) = self.with_transaction(&f).await?;
+                match result {
+                    Ok((room_id, data)) => {
+                        let lock = self.rooms.entry(room_id).or_default().clone();
+                        let _guard = lock.lock_owned().await;
+                        match tx.commit().await.map_err(Into::into) {
+                            Ok(()) => {
+                                return Ok(RoomGuard {
+                                    data,
+                                    _guard,
+                                    _not_send: PhantomData,
+                                });
+                            }
+                            Err(error) => {
+                                if is_serialization_error(&error) {
+                                    // Retry (don't break the loop)
+                                } else {
+                                    return Err(error);
+                                }
+                            }
+                        }
+                    }
+                    Err(error) => {
+                        tx.rollback().await?;
+                        if is_serialization_error(&error) {
+                            // Retry (don't break the loop)
+                        } else {
+                            return Err(error);
+                        }
+                    }
+                }
+            }
+        };
+
+        self.run(body).await
+    }
+
+    async fn with_transaction<F, Fut, T>(&self, f: &F) -> Result<(DatabaseTransaction, Result<T>)>
+    where
+        F: Send + Fn(TransactionHandle) -> Fut,
+        Fut: Send + Future<Output = Result<T>>,
+    {
+        let tx = self
+            .pool
+            .begin_with_config(Some(IsolationLevel::Serializable), None)
+            .await?;
+
+        let mut tx = Arc::new(Some(tx));
+        let result = f(TransactionHandle(tx.clone())).await;
+        let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else {
+            return Err(anyhow!("couldn't complete transaction because it's still in use"))?;
+        };
+
+        Ok((tx, result))
+    }
+
+    async fn run<F, T>(&self, future: F) -> T
+    where
+        F: Future<Output = T>,
+    {
+        #[cfg(test)]
+        {
+            if let Some(background) = self.background.as_ref() {
+                background.simulate_random_delay().await;
+            }
+
+            self.runtime.as_ref().unwrap().block_on(future)
+        }
+
+        #[cfg(not(test))]
+        {
+            future.await
+        }
+    }
+}
+
+fn is_serialization_error(error: &Error) -> bool {
+    const SERIALIZATION_FAILURE_CODE: &'static str = "40001";
+    match error {
+        Error::Database(
+            DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error))
+            | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)),
+        ) if error
+            .as_database_error()
+            .and_then(|error| error.code())
+            .as_deref()
+            == Some(SERIALIZATION_FAILURE_CODE) =>
+        {
+            true
+        }
+        _ => false,
+    }
+}
+
+struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
+
+impl Deref for TransactionHandle {
+    type Target = DatabaseTransaction;
+
+    fn deref(&self) -> &Self::Target {
+        self.0.as_ref().as_ref().unwrap()
+    }
+}
+
+pub struct RoomGuard<T> {
+    data: T,
+    _guard: OwnedMutexGuard<()>,
+    _not_send: PhantomData<Rc<()>>,
+}
+
+impl<T> Deref for RoomGuard<T> {
+    type Target = T;
+
+    fn deref(&self) -> &T {
+        &self.data
+    }
+}
+
+impl<T> DerefMut for RoomGuard<T> {
+    fn deref_mut(&mut self) -> &mut T {
+        &mut self.data
     }
 }
+#[derive(Debug, Serialize, Deserialize)]
+pub struct NewUserParams {
+    pub github_login: String,
+    pub github_user_id: i32,
+    pub invite_count: i32,
+}
+
+#[derive(Debug)]
+pub struct NewUserResult {
+    pub user_id: UserId,
+    pub metrics_id: String,
+    pub inviting_user_id: Option<UserId>,
+    pub signup_device_id: Option<String>,
+} + +fn random_invite_code() -> String { + nanoid::nanoid!(16) +} + +fn random_email_confirmation_code() -> String { + nanoid::nanoid!(64) +} + macro_rules! id_type { ($name:ident) => { #[derive( @@ -1182,11 +2336,9 @@ macro_rules! id_type { PartialOrd, Ord, Hash, - sqlx::Type, Serialize, Deserialize, )] - #[sqlx(transparent)] #[serde(transparent)] pub struct $name(pub i32); @@ -1210,114 +2362,125 @@ macro_rules! id_type { self.0.fmt(f) } } - }; -} -id_type!(UserId); -#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] -pub struct User { - pub id: UserId, - pub github_login: String, - pub github_user_id: Option, - pub email_address: Option, - pub admin: bool, - pub invite_code: Option, - pub invite_count: i32, - pub connected_once: bool, -} + impl From<$name> for sea_query::Value { + fn from(value: $name) -> Self { + sea_query::Value::Int(Some(value.0)) + } + } -id_type!(ProjectId); -#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] -pub struct Project { - pub id: ProjectId, - pub host_user_id: UserId, - pub unregistered: bool, -} + impl sea_orm::TryGetable for $name { + fn try_get( + res: &sea_orm::QueryResult, + pre: &str, + col: &str, + ) -> Result { + Ok(Self(i32::try_get(res, pre, col)?)) + } + } -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum Contact { - Accepted { - user_id: UserId, - should_notify: bool, - }, - Outgoing { - user_id: UserId, - }, - Incoming { - user_id: UserId, - should_notify: bool, - }, -} + impl sea_query::ValueType for $name { + fn try_from(v: Value) -> Result { + match v { + Value::TinyInt(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::SmallInt(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::Int(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::BigInt(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::TinyUnsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::SmallUnsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::Unsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::BigUnsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + _ => Err(sea_query::ValueTypeErr), + } + } -impl Contact { - pub fn user_id(&self) -> UserId { - match self { - Contact::Accepted { user_id, .. } => *user_id, - Contact::Outgoing { user_id } => *user_id, - Contact::Incoming { user_id, .. 
} => *user_id, - } - } -} + fn type_name() -> String { + stringify!($name).into() + } -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct IncomingContactRequest { - pub requester_id: UserId, - pub should_notify: bool, -} + fn array_type() -> sea_query::ArrayType { + sea_query::ArrayType::Int + } -#[derive(Clone, Deserialize, Default)] -pub struct Signup { - pub email_address: String, - pub platform_mac: bool, - pub platform_windows: bool, - pub platform_linux: bool, - pub editor_features: Vec, - pub programming_languages: Vec, - pub device_id: Option, - pub added_to_mailing_list: bool, -} + fn column_type() -> sea_query::ColumnType { + sea_query::ColumnType::Integer(None) + } + } -#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)] -pub struct WaitlistSummary { - #[sqlx(default)] - pub count: i64, - #[sqlx(default)] - pub linux_count: i64, - #[sqlx(default)] - pub mac_count: i64, - #[sqlx(default)] - pub windows_count: i64, - #[sqlx(default)] - pub unknown_count: i64, -} + impl sea_orm::TryFromU64 for $name { + fn try_from_u64(n: u64) -> Result { + Ok(Self(n.try_into().map_err(|_| { + DbErr::ConvertFromU64(concat!( + "error converting ", + stringify!($name), + " to u64" + )) + })?)) + } + } -#[derive(Clone, FromRow, PartialEq, Debug, Serialize, Deserialize)] -pub struct Invite { - pub email_address: String, - pub email_confirmation_code: String, + impl sea_query::Nullable for $name { + fn null() -> Value { + Value::Int(None) + } + } + }; } -#[derive(Debug, Serialize, Deserialize)] -pub struct NewUserParams { - pub github_login: String, - pub github_user_id: i32, - pub invite_count: i32, +id_type!(AccessTokenId); +id_type!(ContactId); +id_type!(RoomId); +id_type!(RoomParticipantId); +id_type!(ProjectId); +id_type!(ProjectCollaboratorId); +id_type!(ReplicaId); +id_type!(SignupId); +id_type!(UserId); + +pub struct LeftRoom { + pub room: proto::Room, + pub left_projects: HashMap, + pub canceled_calls_to_user_ids: Vec, } -#[derive(Debug)] -pub struct NewUserResult { - pub user_id: UserId, - pub metrics_id: String, - pub inviting_user_id: Option, - pub signup_device_id: Option, +pub struct Project { + pub collaborators: Vec, + pub worktrees: BTreeMap, + pub language_servers: Vec, } -fn random_invite_code() -> String { - nanoid::nanoid!(16) +pub struct LeftProject { + pub id: ProjectId, + pub host_user_id: UserId, + pub host_connection_id: ConnectionId, + pub connection_ids: Vec, } -fn random_email_confirmation_code() -> String { - nanoid::nanoid!(64) +pub struct Worktree { + pub id: u64, + pub abs_path: String, + pub root_name: String, + pub visible: bool, + pub entries: Vec, + pub diagnostic_summaries: Vec, + pub scan_id: u64, + pub is_complete: bool, } #[cfg(test)] @@ -1330,35 +2493,40 @@ mod test { use lazy_static::lazy_static; use parking_lot::Mutex; use rand::prelude::*; + use sea_orm::ConnectionTrait; use sqlx::migrate::MigrateDatabase; use std::sync::Arc; - pub struct SqliteTestDb { - pub db: Option>>, - pub conn: sqlx::sqlite::SqliteConnection, + pub struct TestDb { + pub db: Option>, + pub connection: Option, } - pub struct PostgresTestDb { - pub db: Option>>, - pub url: String, - } - - impl SqliteTestDb { - pub fn new(background: Arc) -> Self { - let mut rng = StdRng::from_entropy(); - let url = format!("file:zed-test-{}?mode=memory", rng.gen::()); + impl TestDb { + pub fn sqlite(background: Arc) -> Self { + let url = format!("sqlite::memory:"); let runtime = tokio::runtime::Builder::new_current_thread() .enable_io() .enable_time() .build() .unwrap(); - 
let (mut db, conn) = runtime.block_on(async { - let db = Db::::new(&url, 5).await.unwrap(); - let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"); - db.migrate(migrations_path.as_ref(), false).await.unwrap(); - let conn = db.pool.acquire().await.unwrap().detach(); - (db, conn) + let mut db = runtime.block_on(async { + let mut options = ConnectOptions::new(url); + options.max_connections(5); + let db = Database::new(options).await.unwrap(); + let sql = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/migrations.sqlite/20221109000000_test_schema.sql" + )); + db.pool + .execute(sea_orm::Statement::from_string( + db.pool.get_database_backend(), + sql.into(), + )) + .await + .unwrap(); + db }); db.background = Some(background); @@ -1366,17 +2534,11 @@ mod test { Self { db: Some(Arc::new(db)), - conn, + connection: None, } } - pub fn db(&self) -> &Arc> { - self.db.as_ref().unwrap() - } - } - - impl PostgresTestDb { - pub fn new(background: Arc) -> Self { + pub fn postgres(background: Arc) -> Self { lazy_static! { static ref LOCK: Mutex<()> = Mutex::new(()); } @@ -1397,7 +2559,11 @@ mod test { sqlx::Postgres::create_database(&url) .await .expect("failed to create test db"); - let db = Db::::new(&url, 5).await.unwrap(); + let mut options = ConnectOptions::new(url); + options + .max_connections(5) + .idle_timeout(Duration::from_secs(0)); + let db = Database::new(options).await.unwrap(); let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); db.migrate(Path::new(migrations_path), false).await.unwrap(); db @@ -1408,19 +2574,40 @@ mod test { Self { db: Some(Arc::new(db)), - url, + connection: None, } } - pub fn db(&self) -> &Arc> { + pub fn db(&self) -> &Arc { self.db.as_ref().unwrap() } } - impl Drop for PostgresTestDb { + impl Drop for TestDb { fn drop(&mut self) { let db = self.db.take().unwrap(); - db.teardown(&self.url); + if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() { + db.runtime.as_ref().unwrap().block_on(async { + use util::ResultExt; + let query = " + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE + pg_stat_activity.datname = current_database() AND + pid <> pg_backend_pid(); + "; + db.pool + .execute(sea_orm::Statement::from_string( + db.pool.get_database_backend(), + query.into(), + )) + .await + .log_err(); + sqlx::Postgres::drop_database(db.options.get_url()) + .await + .log_err(); + }) + } } } } diff --git a/crates/collab/src/db/access_token.rs b/crates/collab/src/db/access_token.rs new file mode 100644 index 0000000000000000000000000000000000000000..f5caa4843dd43bff501ac87870e367a960dd25ac --- /dev/null +++ b/crates/collab/src/db/access_token.rs @@ -0,0 +1,29 @@ +use super::{AccessTokenId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "access_tokens")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: AccessTokenId, + pub user_id: UserId, + pub hash: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/contact.rs b/crates/collab/src/db/contact.rs new file mode 100644 index 0000000000000000000000000000000000000000..c39d6643b3a4066eb159e8dc87f692d1d5ca3c3c --- 
/dev/null +++ b/crates/collab/src/db/contact.rs @@ -0,0 +1,58 @@ +use super::{ContactId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "contacts")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ContactId, + pub user_id_a: UserId, + pub user_id_b: UserId, + pub a_to_b: bool, + pub should_notify: bool, + pub accepted: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::room_participant::Entity", + from = "Column::UserIdA", + to = "super::room_participant::Column::UserId" + )] + UserARoomParticipant, + #[sea_orm( + belongs_to = "super::room_participant::Entity", + from = "Column::UserIdB", + to = "super::room_participant::Column::UserId" + )] + UserBRoomParticipant, +} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Contact { + Accepted { + user_id: UserId, + should_notify: bool, + busy: bool, + }, + Outgoing { + user_id: UserId, + }, + Incoming { + user_id: UserId, + should_notify: bool, + }, +} + +impl Contact { + pub fn user_id(&self) -> UserId { + match self { + Contact::Accepted { user_id, .. } => *user_id, + Contact::Outgoing { user_id } => *user_id, + Contact::Incoming { user_id, .. } => *user_id, + } + } +} diff --git a/crates/collab/src/db/language_server.rs b/crates/collab/src/db/language_server.rs new file mode 100644 index 0000000000000000000000000000000000000000..d2c045c12146e1c8797c4cbd7e1ae52c52829e98 --- /dev/null +++ b/crates/collab/src/db/language_server.rs @@ -0,0 +1,30 @@ +use super::ProjectId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "language_servers")] +pub struct Model { + #[sea_orm(primary_key)] + pub project_id: ProjectId, + #[sea_orm(primary_key)] + pub id: i64, + pub name: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::project::Entity", + from = "Column::ProjectId", + to = "super::project::Column::Id" + )] + Project, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Project.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/project.rs b/crates/collab/src/db/project.rs new file mode 100644 index 0000000000000000000000000000000000000000..971a8fcefb465114c9703003e2a74f6f38d8c397 --- /dev/null +++ b/crates/collab/src/db/project.rs @@ -0,0 +1,67 @@ +use super::{ProjectId, RoomId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "projects")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ProjectId, + pub room_id: RoomId, + pub host_user_id: UserId, + pub host_connection_id: i32, + pub host_connection_epoch: Uuid, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::HostUserId", + to = "super::user::Column::Id" + )] + HostUser, + #[sea_orm( + belongs_to = "super::room::Entity", + from = "Column::RoomId", + to = "super::room::Column::Id" + )] + Room, + #[sea_orm(has_many = "super::worktree::Entity")] + Worktrees, + #[sea_orm(has_many = "super::project_collaborator::Entity")] + Collaborators, + #[sea_orm(has_many = "super::language_server::Entity")] + LanguageServers, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::HostUser.def() 
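+        // Each `Related` impl on this entity exposes the `RelationDef` that sea-orm's
+        // `find_related` uses, e.g. `project.find_related(project_collaborator::Entity)`
+        // in the queries above.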
+ } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Room.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Worktrees.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Collaborators.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::LanguageServers.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/project_collaborator.rs b/crates/collab/src/db/project_collaborator.rs new file mode 100644 index 0000000000000000000000000000000000000000..5db307f5df27ec07f282207eb253ddae95c44970 --- /dev/null +++ b/crates/collab/src/db/project_collaborator.rs @@ -0,0 +1,33 @@ +use super::{ProjectCollaboratorId, ProjectId, ReplicaId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "project_collaborators")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ProjectCollaboratorId, + pub project_id: ProjectId, + pub connection_id: i32, + pub connection_epoch: Uuid, + pub user_id: UserId, + pub replica_id: ReplicaId, + pub is_host: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::project::Entity", + from = "Column::ProjectId", + to = "super::project::Column::Id" + )] + Project, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Project.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/room.rs b/crates/collab/src/db/room.rs new file mode 100644 index 0000000000000000000000000000000000000000..7dbf03a780adbd69c1d3b492e4bcf82557ae70ab --- /dev/null +++ b/crates/collab/src/db/room.rs @@ -0,0 +1,32 @@ +use super::RoomId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "rooms")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: RoomId, + pub live_kit_room: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::room_participant::Entity")] + RoomParticipant, + #[sea_orm(has_many = "super::project::Entity")] + Project, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::RoomParticipant.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Project.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/room_participant.rs b/crates/collab/src/db/room_participant.rs new file mode 100644 index 0000000000000000000000000000000000000000..783f45aa93e1952be3f5dd2f5efd0d51da6665cd --- /dev/null +++ b/crates/collab/src/db/room_participant.rs @@ -0,0 +1,49 @@ +use super::{ProjectId, RoomId, RoomParticipantId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "room_participants")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: RoomParticipantId, + pub room_id: RoomId, + pub user_id: UserId, + pub answering_connection_id: Option, + pub answering_connection_epoch: Option, + pub location_kind: Option, + pub location_project_id: Option, + pub initial_project_id: Option, + pub calling_user_id: UserId, + pub calling_connection_id: i32, + pub calling_connection_epoch: Uuid, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = 
"super::user::Column::Id" + )] + User, + #[sea_orm( + belongs_to = "super::room::Entity", + from = "Column::RoomId", + to = "super::room::Column::Id" + )] + Room, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Room.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/signup.rs b/crates/collab/src/db/signup.rs new file mode 100644 index 0000000000000000000000000000000000000000..ca219736a8be46cb779efb4326a181e37b6749e5 --- /dev/null +++ b/crates/collab/src/db/signup.rs @@ -0,0 +1,56 @@ +use super::{SignupId, UserId}; +use sea_orm::{entity::prelude::*, FromQueryResult}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "signups")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: SignupId, + pub email_address: String, + pub email_confirmation_code: String, + pub email_confirmation_sent: bool, + pub created_at: DateTime, + pub device_id: Option, + pub user_id: Option, + pub inviting_user_id: Option, + pub platform_mac: bool, + pub platform_linux: bool, + pub platform_windows: bool, + pub platform_unknown: bool, + pub editor_features: Option>, + pub programming_languages: Option>, + pub added_to_mailing_list: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)] +pub struct Invite { + pub email_address: String, + pub email_confirmation_code: String, +} + +#[derive(Clone, Deserialize)] +pub struct NewSignup { + pub email_address: String, + pub platform_mac: bool, + pub platform_windows: bool, + pub platform_linux: bool, + pub editor_features: Vec, + pub programming_languages: Vec, + pub device_id: Option, + pub added_to_mailing_list: bool, +} + +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)] +pub struct WaitlistSummary { + pub count: i64, + pub linux_count: i64, + pub mac_count: i64, + pub windows_count: i64, + pub unknown_count: i64, +} diff --git a/crates/collab/src/db_tests.rs b/crates/collab/src/db/tests.rs similarity index 91% rename from crates/collab/src/db_tests.rs rename to crates/collab/src/db/tests.rs index 6260eadc4a83f259b3dcb06ce650ea50788112cc..9e70ae4b05798e25fc776bd8fb1a0493caa8729d 100644 --- a/crates/collab/src/db_tests.rs +++ b/crates/collab/src/db/tests.rs @@ -1,4 +1,4 @@ -use super::db::*; +use super::*; use gpui::executor::{Background, Deterministic}; use std::sync::Arc; @@ -6,14 +6,14 @@ macro_rules! 
test_both_dbs { ($postgres_test_name:ident, $sqlite_test_name:ident, $db:ident, $body:block) => { #[gpui::test] async fn $postgres_test_name() { - let test_db = PostgresTestDb::new(Deterministic::new(0).build_background()); + let test_db = TestDb::postgres(Deterministic::new(0).build_background()); let $db = test_db.db(); $body } #[gpui::test] async fn $sqlite_test_name() { - let test_db = SqliteTestDb::new(Deterministic::new(0).build_background()); + let test_db = TestDb::sqlite(Deterministic::new(0).build_background()); let $db = test_db.db(); $body } @@ -26,9 +26,10 @@ test_both_dbs!( db, { let mut user_ids = Vec::new(); + let mut user_metric_ids = Vec::new(); for i in 1..=4 { - user_ids.push( - db.create_user( + let user = db + .create_user( &format!("user{i}@example.com"), false, NewUserParams { @@ -38,9 +39,9 @@ test_both_dbs!( }, ) .await - .unwrap() - .user_id, - ); + .unwrap(); + user_ids.push(user.user_id); + user_metric_ids.push(user.metrics_id); } assert_eq!( @@ -52,6 +53,7 @@ test_both_dbs!( github_user_id: Some(1), email_address: Some("user1@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[0].parse().unwrap(), ..Default::default() }, User { @@ -60,6 +62,7 @@ test_both_dbs!( github_user_id: Some(2), email_address: Some("user2@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[1].parse().unwrap(), ..Default::default() }, User { @@ -68,6 +71,7 @@ test_both_dbs!( github_user_id: Some(3), email_address: Some("user3@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[2].parse().unwrap(), ..Default::default() }, User { @@ -76,6 +80,7 @@ test_both_dbs!( github_user_id: Some(4), email_address: Some("user4@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[3].parse().unwrap(), ..Default::default() } ] @@ -258,7 +263,8 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { db.get_contacts(user_1).await.unwrap(), &[Contact::Accepted { user_id: user_2, - should_notify: true + should_notify: true, + busy: false, }], ); assert!(db.has_contact(user_1, user_2).await.unwrap()); @@ -268,6 +274,7 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { &[Contact::Accepted { user_id: user_1, should_notify: false, + busy: false, }] ); @@ -284,6 +291,7 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { &[Contact::Accepted { user_id: user_2, should_notify: true, + busy: false, }] ); @@ -296,6 +304,7 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { &[Contact::Accepted { user_id: user_2, should_notify: false, + busy: false, }] ); @@ -309,10 +318,12 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { Contact::Accepted { user_id: user_2, should_notify: false, + busy: false, }, Contact::Accepted { user_id: user_3, - should_notify: false + should_notify: false, + busy: false, } ] ); @@ -320,7 +331,8 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { db.get_contacts(user_3).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false + should_notify: false, + busy: false, }], ); @@ -335,14 +347,16 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { db.get_contacts(user_2).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false + should_notify: false, + busy: false, }] ); assert_eq!( db.get_contacts(user_3).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false + should_notify: false, + 
busy: false, }], ); }); @@ -390,14 +404,14 @@ test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, { #[test] fn test_fuzzy_like_string() { - assert_eq!(DefaultDb::fuzzy_like_string("abcd"), "%a%b%c%d%"); - assert_eq!(DefaultDb::fuzzy_like_string("x y"), "%x%y%"); - assert_eq!(DefaultDb::fuzzy_like_string(" z "), "%z%"); + assert_eq!(Database::fuzzy_like_string("abcd"), "%a%b%c%d%"); + assert_eq!(Database::fuzzy_like_string("x y"), "%x%y%"); + assert_eq!(Database::fuzzy_like_string(" z "), "%z%"); } #[gpui::test] async fn test_fuzzy_search_users() { - let test_db = PostgresTestDb::new(build_background_executor()); + let test_db = TestDb::postgres(build_background_executor()); let db = test_db.db(); for (i, github_login) in [ "California", @@ -433,7 +447,7 @@ async fn test_fuzzy_search_users() { &["rhode-island", "colorado", "oregon"], ); - async fn fuzzy_search_user_names(db: &Db, query: &str) -> Vec { + async fn fuzzy_search_user_names(db: &Database, query: &str) -> Vec { db.fuzzy_search_users(query, 10) .await .unwrap() @@ -445,7 +459,7 @@ async fn test_fuzzy_search_users() { #[gpui::test] async fn test_invite_codes() { - let test_db = PostgresTestDb::new(build_background_executor()); + let test_db = TestDb::postgres(build_background_executor()); let db = test_db.db(); let NewUserResult { user_id: user1, .. } = db @@ -504,14 +518,16 @@ async fn test_invite_codes() { db.get_contacts(user1).await.unwrap(), [Contact::Accepted { user_id: user2, - should_notify: true + should_notify: true, + busy: false, }] ); assert_eq!( db.get_contacts(user2).await.unwrap(), [Contact::Accepted { user_id: user1, - should_notify: false + should_notify: false, + busy: false, }] ); assert_eq!( @@ -550,11 +566,13 @@ async fn test_invite_codes() { [ Contact::Accepted { user_id: user2, - should_notify: true + should_notify: true, + busy: false, }, Contact::Accepted { user_id: user3, - should_notify: true + should_notify: true, + busy: false, } ] ); @@ -562,7 +580,8 @@ async fn test_invite_codes() { db.get_contacts(user3).await.unwrap(), [Contact::Accepted { user_id: user1, - should_notify: false + should_notify: false, + busy: false, }] ); assert_eq!( @@ -607,15 +626,18 @@ async fn test_invite_codes() { [ Contact::Accepted { user_id: user2, - should_notify: true + should_notify: true, + busy: false, }, Contact::Accepted { user_id: user3, - should_notify: true + should_notify: true, + busy: false, }, Contact::Accepted { user_id: user4, - should_notify: true + should_notify: true, + busy: false, } ] ); @@ -623,7 +645,8 @@ async fn test_invite_codes() { db.get_contacts(user4).await.unwrap(), [Contact::Accepted { user_id: user1, - should_notify: false + should_notify: false, + busy: false, }] ); assert_eq!( @@ -641,7 +664,7 @@ async fn test_invite_codes() { #[gpui::test] async fn test_signups() { - let test_db = PostgresTestDb::new(build_background_executor()); + let test_db = TestDb::postgres(build_background_executor()); let db = test_db.db(); let usernames = (0..8).map(|i| format!("person-{i}")).collect::>(); @@ -649,7 +672,7 @@ async fn test_signups() { let all_signups = usernames .iter() .enumerate() - .map(|(i, username)| Signup { + .map(|(i, username)| NewSignup { email_address: format!("{username}@example.com"), platform_mac: true, platform_linux: i % 2 == 0, @@ -659,7 +682,7 @@ async fn test_signups() { device_id: Some(format!("device_id_{i}")), added_to_mailing_list: i != 0, // One user failed to subscribe }) - .collect::>(); + .collect::>(); // people sign up on the waitlist for signup in 
&all_signups { diff --git a/crates/collab/src/db/user.rs b/crates/collab/src/db/user.rs new file mode 100644 index 0000000000000000000000000000000000000000..c2b157bd0a758880fd6fe64b079fa8760b59df5c --- /dev/null +++ b/crates/collab/src/db/user.rs @@ -0,0 +1,49 @@ +use super::UserId; +use sea_orm::entity::prelude::*; +use serde::Serialize; + +#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel, Serialize)] +#[sea_orm(table_name = "users")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: UserId, + pub github_login: String, + pub github_user_id: Option, + pub email_address: Option, + pub admin: bool, + pub invite_code: Option, + pub invite_count: i32, + pub inviter_id: Option, + pub connected_once: bool, + pub metrics_id: Uuid, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::access_token::Entity")] + AccessToken, + #[sea_orm(has_one = "super::room_participant::Entity")] + RoomParticipant, + #[sea_orm(has_many = "super::project::Entity")] + HostedProjects, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::AccessToken.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::RoomParticipant.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::HostedProjects.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/worktree.rs b/crates/collab/src/db/worktree.rs new file mode 100644 index 0000000000000000000000000000000000000000..b9f0f97dee05b71558a050fa808b62f56b2aefd1 --- /dev/null +++ b/crates/collab/src/db/worktree.rs @@ -0,0 +1,34 @@ +use super::ProjectId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "worktrees")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i64, + #[sea_orm(primary_key)] + pub project_id: ProjectId, + pub abs_path: String, + pub root_name: String, + pub visible: bool, + pub scan_id: i64, + pub is_complete: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::project::Entity", + from = "Column::ProjectId", + to = "super::project::Column::Id" + )] + Project, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Project.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/worktree_diagnostic_summary.rs b/crates/collab/src/db/worktree_diagnostic_summary.rs new file mode 100644 index 0000000000000000000000000000000000000000..f3dd8083fb57d9e863ea51de4e4de26b2d594a61 --- /dev/null +++ b/crates/collab/src/db/worktree_diagnostic_summary.rs @@ -0,0 +1,21 @@ +use super::ProjectId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "worktree_diagnostic_summaries")] +pub struct Model { + #[sea_orm(primary_key)] + pub project_id: ProjectId, + #[sea_orm(primary_key)] + pub worktree_id: i64, + #[sea_orm(primary_key)] + pub path: String, + pub language_server_id: i64, + pub error_count: i32, + pub warning_count: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/worktree_entry.rs b/crates/collab/src/db/worktree_entry.rs new file mode 100644 index 0000000000000000000000000000000000000000..413821201a20dd392713f43ec3b4163d7ff31f88 --- /dev/null +++ b/crates/collab/src/db/worktree_entry.rs @@ -0,0 +1,25 @@ 
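+// Mirrors the host's `proto::Entry` data for filesystem entries in a shared worktree;
+// rows are keyed by the composite primary key (project_id, worktree_id, id).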
+use super::ProjectId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "worktree_entries")] +pub struct Model { + #[sea_orm(primary_key)] + pub project_id: ProjectId, + #[sea_orm(primary_key)] + pub worktree_id: i64, + #[sea_orm(primary_key)] + pub id: i64, + pub is_dir: bool, + pub path: String, + pub inode: i64, + pub mtime_seconds: i64, + pub mtime_nanos: i32, + pub is_symlink: bool, + pub is_ignored: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index b97b4dad3385efbe40f01fcec7278bbec584ae22..c1610d71cd4cb7820990097526ebcd08558f8e59 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -1,5 +1,5 @@ use crate::{ - db::{NewUserParams, ProjectId, SqliteTestDb as TestDb, UserId}, + db::{self, NewUserParams, TestDb, UserId}, rpc::{Executor, Server}, AppState, }; @@ -31,9 +31,7 @@ use language::{ use live_kit_client::MacOSDisplay; use lsp::{self, FakeLanguageServer}; use parking_lot::Mutex; -use project::{ - search::SearchQuery, DiagnosticSummary, Project, ProjectPath, ProjectStore, WorktreeId, -}; +use project::{search::SearchQuery, DiagnosticSummary, Project, ProjectPath, WorktreeId}; use rand::prelude::*; use serde_json::json; use settings::{Formatter, Settings}; @@ -72,8 +70,6 @@ async fn test_basic_calls( deterministic.forbid_parking(); let mut server = TestServer::start(cx_a.background()).await; - let start = std::time::Instant::now(); - let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -105,7 +101,7 @@ async fn test_basic_calls( // User B receives the call. let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming()); let call_b = incoming_call_b.next().await.unwrap().unwrap(); - assert_eq!(call_b.caller.github_login, "user_a"); + assert_eq!(call_b.calling_user.github_login, "user_a"); // User B connects via another client and also receives a ring on the newly-connected client. let _client_b2 = server.create_client(cx_b2, "user_b").await; @@ -113,7 +109,7 @@ async fn test_basic_calls( let mut incoming_call_b2 = active_call_b2.read_with(cx_b2, |call, _| call.incoming()); deterministic.run_until_parked(); let call_b2 = incoming_call_b2.next().await.unwrap().unwrap(); - assert_eq!(call_b2.caller.github_login, "user_a"); + assert_eq!(call_b2.calling_user.github_login, "user_a"); // User B joins the room using the first client. active_call_b @@ -166,7 +162,7 @@ async fn test_basic_calls( // User C receives the call, but declines it. let call_c = incoming_call_c.next().await.unwrap().unwrap(); - assert_eq!(call_c.caller.github_login, "user_b"); + assert_eq!(call_c.calling_user.github_login, "user_b"); active_call_c.update(cx_c, |call, _| call.decline_incoming().unwrap()); assert!(incoming_call_c.next().await.unwrap().is_none()); @@ -259,8 +255,6 @@ async fn test_basic_calls( pending: Default::default() } ); - - eprintln!("finished test {:?}", start.elapsed()); } #[gpui::test(iterations = 10)] @@ -309,7 +303,7 @@ async fn test_room_uniqueness( // User B receives the call from user A. 
let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming()); let call_b1 = incoming_call_b.next().await.unwrap().unwrap(); - assert_eq!(call_b1.caller.github_login, "user_a"); + assert_eq!(call_b1.calling_user.github_login, "user_a"); // Ensure calling users A and B from client C fails. active_call_c @@ -368,7 +362,7 @@ async fn test_room_uniqueness( .unwrap(); deterministic.run_until_parked(); let call_b2 = incoming_call_b.next().await.unwrap().unwrap(); - assert_eq!(call_b2.caller.github_login, "user_c"); + assert_eq!(call_b2.calling_user.github_login, "user_c"); } #[gpui::test(iterations = 10)] @@ -696,7 +690,7 @@ async fn test_share_project( let incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming()); deterministic.run_until_parked(); let call = incoming_call_b.borrow().clone().unwrap(); - assert_eq!(call.caller.github_login, "user_a"); + assert_eq!(call.calling_user.github_login, "user_a"); let initial_project = call.initial_project.unwrap(); active_call_b .update(cx_b, |call, cx| call.accept_incoming(cx)) @@ -767,7 +761,7 @@ async fn test_share_project( let incoming_call_c = active_call_c.read_with(cx_c, |call, _| call.incoming()); deterministic.run_until_parked(); let call = incoming_call_c.borrow().clone().unwrap(); - assert_eq!(call.caller.github_login, "user_b"); + assert_eq!(call.calling_user.github_login, "user_b"); let initial_project = call.initial_project.unwrap(); active_call_c .update(cx_c, |call, cx| call.accept_incoming(cx)) @@ -2292,7 +2286,6 @@ async fn test_leaving_project( project_id, client_b.client.clone(), client_b.user_store.clone(), - client_b.project_store.clone(), client_b.language_registry.clone(), FakeFs::new(cx.background()), cx, @@ -2416,12 +2409,6 @@ async fn test_collaborating_with_diagnostics( // Wait for server to see the diagnostics update. deterministic.run_until_parked(); - { - let store = server.store.lock().await; - let project = store.project(ProjectId::from_proto(project_id)).unwrap(); - let worktree = project.worktrees.get(&worktree_id.to_proto()).unwrap(); - assert!(!worktree.diagnostic_summaries.is_empty()); - } // Ensure client B observes the new diagnostics. project_b.read_with(cx_b, |project, cx| { @@ -2443,7 +2430,10 @@ async fn test_collaborating_with_diagnostics( // Join project as client C and observe the diagnostics. let project_c = client_c.build_remote_project(project_id, cx_c).await; - let project_c_diagnostic_summaries = Rc::new(RefCell::new(Vec::new())); + let project_c_diagnostic_summaries = + Rc::new(RefCell::new(project_c.read_with(cx_c, |project, cx| { + project.diagnostic_summaries(cx).collect::>() + }))); project_c.update(cx_c, |_, cx| { let summaries = project_c_diagnostic_summaries.clone(); cx.subscribe(&project_c, { @@ -5627,18 +5617,15 @@ async fn test_random_collaboration( } for user_id in &user_ids { let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap(); - let contacts = server - .store - .lock() - .await - .build_initial_contacts_update(contacts) - .contacts; + let pool = server.connection_pool.lock().await; for contact in contacts { - if contact.online { - assert_ne!( - contact.user_id, removed_guest_id.0 as u64, - "removed guest is still a contact of another peer" - ); + if let db::Contact::Accepted { user_id, .. 
} = contact { + if pool.is_user_online(user_id) { + assert_ne!( + user_id, removed_guest_id, + "removed guest is still a contact of another peer" + ); + } } } } @@ -5830,7 +5817,13 @@ impl TestServer { async fn start(background: Arc) -> Self { static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0); - let test_db = TestDb::new(background.clone()); + let use_postgres = env::var("USE_POSTGRES").ok(); + let use_postgres = use_postgres.as_deref(); + let test_db = if use_postgres == Some("true") || use_postgres == Some("1") { + TestDb::postgres(background.clone()) + } else { + TestDb::sqlite(background.clone()) + }; let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst); let live_kit_server = live_kit_client::TestServer::create( format!("http://livekit.{}.test", live_kit_server_id), @@ -5948,11 +5941,9 @@ impl TestServer { let fs = FakeFs::new(cx.background()); let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx)); - let project_store = cx.add_model(|_| ProjectStore::new()); let app_state = Arc::new(workspace::AppState { client: client.clone(), user_store: user_store.clone(), - project_store: project_store.clone(), languages: Arc::new(LanguageRegistry::new(Task::ready(()))), themes: ThemeRegistry::new((), cx.font_cache()), fs: fs.clone(), @@ -5979,7 +5970,6 @@ impl TestServer { remote_projects: Default::default(), next_root_dir_id: 0, user_store, - project_store, fs, language_registry: Arc::new(LanguageRegistry::test()), buffers: Default::default(), @@ -6085,7 +6075,6 @@ struct TestClient { remote_projects: Vec>, next_root_dir_id: usize, pub user_store: ModelHandle, - pub project_store: ModelHandle, language_registry: Arc, fs: Arc, buffers: HashMap, HashSet>>, @@ -6155,7 +6144,6 @@ impl TestClient { Project::local( self.client.clone(), self.user_store.clone(), - self.project_store.clone(), self.language_registry.clone(), self.fs.clone(), cx, @@ -6183,7 +6171,6 @@ impl TestClient { host_project_id, self.client.clone(), self.user_store.clone(), - self.project_store.clone(), self.language_registry.clone(), FakeFs::new(cx.background()), cx, @@ -6319,7 +6306,6 @@ impl TestClient { remote_project_id, client.client.clone(), client.user_store.clone(), - client.project_store.clone(), client.language_registry.clone(), FakeFs::new(cx.background()), cx.to_async(), @@ -6569,7 +6555,7 @@ impl TestClient { buffers.extend(search.await?.into_keys()); } } - 60..=69 => { + 60..=79 => { let worktree = project .read_with(cx, |project, cx| { project diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 518530c539bdcffae03059732e5e9dba4401ac56..24a9fc6117ce81ea493b742c2c6f7cbd6e8ca5d4 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -1,9 +1,21 @@ +pub mod api; +pub mod auth; +pub mod db; +pub mod env; +#[cfg(test)] +mod integration_tests; +pub mod rpc; + use axum::{http::StatusCode, response::IntoResponse}; +use db::Database; +use serde::Deserialize; +use std::{path::PathBuf, sync::Arc}; pub type Result = std::result::Result; pub enum Error { Http(StatusCode, String), + Database(sea_orm::error::DbErr), Internal(anyhow::Error), } @@ -13,9 +25,9 @@ impl From for Error { } } -impl From for Error { - fn from(error: sqlx::Error) -> Self { - Self::Internal(error.into()) +impl From for Error { + fn from(error: sea_orm::error::DbErr) -> Self { + Self::Database(error) } } @@ -41,6 +53,9 @@ impl IntoResponse for Error { fn into_response(self) -> axum::response::Response { match self { Error::Http(code, message) => (code, 
message).into_response(), + Error::Database(error) => { + (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() + } Error::Internal(error) => { (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() } @@ -52,6 +67,7 @@ impl std::fmt::Debug for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Error::Http(code, message) => (code, message).fmt(f), + Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), } } @@ -61,9 +77,64 @@ impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Error::Http(code, message) => write!(f, "{code}: {message}"), + Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), } } } impl std::error::Error for Error {} + +#[derive(Default, Deserialize)] +pub struct Config { + pub http_port: u16, + pub database_url: String, + pub api_token: String, + pub invite_link_prefix: String, + pub live_kit_server: Option, + pub live_kit_key: Option, + pub live_kit_secret: Option, + pub rust_log: Option, + pub log_json: Option, +} + +#[derive(Default, Deserialize)] +pub struct MigrateConfig { + pub database_url: String, + pub migrations_path: Option, +} + +pub struct AppState { + pub db: Arc, + pub live_kit_client: Option>, + pub config: Config, +} + +impl AppState { + pub async fn new(config: Config) -> Result> { + let mut db_options = db::ConnectOptions::new(config.database_url.clone()); + db_options.max_connections(5); + let db = Database::new(db_options).await?; + let live_kit_client = if let Some(((server, key), secret)) = config + .live_kit_server + .as_ref() + .zip(config.live_kit_key.as_ref()) + .zip(config.live_kit_secret.as_ref()) + { + Some(Arc::new(live_kit_server::api::LiveKitClient::new( + server.clone(), + key.clone(), + secret.clone(), + )) as Arc) + } else { + None + }; + + let this = Self { + db: Arc::new(db), + live_kit_client, + config, + }; + Ok(Arc::new(this)) + } +} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index dc98a2ee6855c072f5adc9ed95dbad38626eca48..a288e0f3ce83fe8c7a0656f108f15c6088021d68 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -1,86 +1,18 @@ -mod api; -mod auth; -mod db; -mod env; -mod rpc; - -#[cfg(test)] -mod db_tests; -#[cfg(test)] -mod integration_tests; - -use crate::rpc::ResultExt as _; use anyhow::anyhow; use axum::{routing::get, Router}; -use collab::{Error, Result}; -use db::DefaultDb as Db; -use serde::Deserialize; +use collab::{db, env, AppState, Config, MigrateConfig, Result}; +use db::Database; use std::{ env::args, net::{SocketAddr, TcpListener}, - path::{Path, PathBuf}, - sync::Arc, - time::Duration, + path::Path, }; -use tokio::signal; use tracing_log::LogTracer; use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer}; use util::ResultExt; const VERSION: &'static str = env!("CARGO_PKG_VERSION"); -#[derive(Default, Deserialize)] -pub struct Config { - pub http_port: u16, - pub database_url: String, - pub api_token: String, - pub invite_link_prefix: String, - pub live_kit_server: Option, - pub live_kit_key: Option, - pub live_kit_secret: Option, - pub rust_log: Option, - pub log_json: Option, -} - -#[derive(Default, Deserialize)] -pub struct MigrateConfig { - pub database_url: String, - pub migrations_path: Option, -} - -pub struct AppState { - db: Arc, - live_kit_client: Option>, - config: Config, -} - -impl AppState { - async fn new(config: Config) -> Result> { - let db = 
Db::new(&config.database_url, 5).await?; - let live_kit_client = if let Some(((server, key), secret)) = config - .live_kit_server - .as_ref() - .zip(config.live_kit_key.as_ref()) - .zip(config.live_kit_secret.as_ref()) - { - Some(Arc::new(live_kit_server::api::LiveKitClient::new( - server.clone(), - key.clone(), - secret.clone(), - )) as Arc) - } else { - None - }; - - let this = Self { - db: Arc::new(db), - live_kit_client, - config, - }; - Ok(Arc::new(this)) - } -} - #[tokio::main] async fn main() -> Result<()> { if let Err(error) = env::load_dotenv() { @@ -96,7 +28,9 @@ async fn main() -> Result<()> { } Some("migrate") => { let config = envy::from_env::().expect("error loading config"); - let db = Db::new(&config.database_url, 5).await?; + let mut db_options = db::ConnectOptions::new(config.database_url.clone()); + db_options.max_connections(5); + let db = Database::new(db_options).await?; let migrations_path = config .migrations_path @@ -118,18 +52,19 @@ async fn main() -> Result<()> { init_tracing(&config); let state = AppState::new(config).await?; + state.db.clear_stale_data().await?; + let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port)) .expect("failed to bind TCP listener"); - let rpc_server = rpc::Server::new(state.clone()); + let rpc_server = collab::rpc::Server::new(state.clone()); - let app = api::routes(rpc_server.clone(), state.clone()) - .merge(rpc::routes(rpc_server.clone())) + let app = collab::api::routes(rpc_server.clone(), state.clone()) + .merge(collab::rpc::routes(rpc_server.clone())) .merge(Router::new().route("/", get(handle_root))); axum::Server::from_tcp(listener)? .serve(app.into_make_service_with_connect_info::()) - .with_graceful_shutdown(graceful_shutdown(rpc_server, state)) .await?; } _ => { @@ -174,52 +109,3 @@ pub fn init_tracing(config: &Config) -> Option<()> { None } - -async fn graceful_shutdown(rpc_server: Arc, state: Arc) { - let ctrl_c = async { - signal::ctrl_c() - .await - .expect("failed to install Ctrl+C handler"); - }; - - #[cfg(unix)] - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to install signal handler") - .recv() - .await; - }; - - #[cfg(not(unix))] - let terminate = std::future::pending::<()>(); - - tokio::select! 
{ - _ = ctrl_c => {}, - _ = terminate => {}, - } - - if let Some(live_kit) = state.live_kit_client.as_ref() { - let deletions = rpc_server - .store() - .await - .rooms() - .values() - .map(|room| { - let name = room.live_kit_room.clone(); - async { - live_kit.delete_room(name).await.trace_err(); - } - }) - .collect::>(); - - tracing::info!("deleting all live-kit rooms"); - if let Err(_) = tokio::time::timeout( - Duration::from_secs(10), - futures::future::join_all(deletions), - ) - .await - { - tracing::error!("timed out waiting for live-kit room deletion"); - } - } -} diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 960b7704f1ca17284be10e402f6d571056e8d751..0136a5fec6b1326aace79dc11ea1f6f310c3b705 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,8 +1,8 @@ -mod store; +mod connection_pool; use crate::{ auth, - db::{self, ProjectId, User, UserId}, + db::{self, Database, ProjectId, RoomId, User, UserId}, AppState, Result, }; use anyhow::anyhow; @@ -23,6 +23,7 @@ use axum::{ Extension, Router, TypedHeader, }; use collections::{HashMap, HashSet}; +pub use connection_pool::ConnectionPool; use futures::{ channel::oneshot, future::{self, BoxFuture}, @@ -38,8 +39,10 @@ use rpc::{ use serde::{Serialize, Serializer}; use std::{ any::TypeId, + fmt, future::Future, marker::PhantomData, + mem, net::SocketAddr, ops::{Deref, DerefMut}, rc::Rc, @@ -49,7 +52,6 @@ use std::{ }, time::Duration, }; -pub use store::{Store, Worktree}; use tokio::{ sync::{Mutex, MutexGuard}, time::Sleep, @@ -68,10 +70,10 @@ lazy_static! { } type MessageHandler = - Box, Box) -> BoxFuture<'static, ()>>; + Box, Session) -> BoxFuture<'static, ()>>; struct Response { - server: Arc, + peer: Arc, receipt: Receipt, responded: Arc, } @@ -79,14 +81,66 @@ struct Response { impl Response { fn send(self, payload: R::Response) -> Result<()> { self.responded.store(true, SeqCst); - self.server.peer.respond(self.receipt, payload)?; + self.peer.respond(self.receipt, payload)?; Ok(()) } } +#[derive(Clone)] +struct Session { + user_id: UserId, + connection_id: ConnectionId, + db: Arc>, + peer: Arc, + connection_pool: Arc>, + live_kit_client: Option>, +} + +impl Session { + async fn db(&self) -> MutexGuard { + #[cfg(test)] + tokio::task::yield_now().await; + let guard = self.db.lock().await; + #[cfg(test)] + tokio::task::yield_now().await; + guard + } + + async fn connection_pool(&self) -> ConnectionPoolGuard<'_> { + #[cfg(test)] + tokio::task::yield_now().await; + let guard = self.connection_pool.lock().await; + #[cfg(test)] + tokio::task::yield_now().await; + ConnectionPoolGuard { + guard, + _not_send: PhantomData, + } + } +} + +impl fmt::Debug for Session { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Session") + .field("user_id", &self.user_id) + .field("connection_id", &self.connection_id) + .finish() + } +} + +struct DbHandle(Arc); + +impl Deref for DbHandle { + type Target = Database; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + pub struct Server { peer: Arc, - pub(crate) store: Mutex, + pub(crate) connection_pool: Arc>, app_state: Arc, handlers: HashMap, } @@ -100,8 +154,8 @@ pub trait Executor: Send + Clone { #[derive(Clone)] pub struct RealExecutor; -pub(crate) struct StoreGuard<'a> { - guard: MutexGuard<'a, Store>, +pub(crate) struct ConnectionPoolGuard<'a> { + guard: MutexGuard<'a, ConnectionPool>, _not_send: PhantomData>, } @@ -109,7 +163,7 @@ pub(crate) struct StoreGuard<'a> { pub struct ServerSnapshot<'a> { peer: &'a Peer, 
#[serde(serialize_with = "serialize_deref")] - store: StoreGuard<'a>, + connection_pool: ConnectionPoolGuard<'a>, } pub fn serialize_deref(value: &T, serializer: S) -> Result @@ -126,81 +180,79 @@ impl Server { let mut server = Self { peer: Peer::new(), app_state, - store: Default::default(), + connection_pool: Default::default(), handlers: Default::default(), }; server - .add_request_handler(Server::ping) - .add_request_handler(Server::create_room) - .add_request_handler(Server::join_room) - .add_message_handler(Server::leave_room) - .add_request_handler(Server::call) - .add_request_handler(Server::cancel_call) - .add_message_handler(Server::decline_call) - .add_request_handler(Server::update_participant_location) - .add_request_handler(Server::share_project) - .add_message_handler(Server::unshare_project) - .add_request_handler(Server::join_project) - .add_message_handler(Server::leave_project) - .add_message_handler(Server::update_project) - .add_request_handler(Server::update_worktree) - .add_message_handler(Server::start_language_server) - .add_message_handler(Server::update_language_server) - .add_message_handler(Server::update_diagnostic_summary) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler( - Server::forward_project_request::, - ) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_message_handler(Server::create_buffer_for_peer) - .add_request_handler(Server::update_buffer) - .add_message_handler(Server::update_buffer_file) - .add_message_handler(Server::buffer_reloaded) - .add_message_handler(Server::buffer_saved) - .add_request_handler(Server::save_buffer) - .add_request_handler(Server::get_users) - .add_request_handler(Server::fuzzy_search_users) - .add_request_handler(Server::request_contact) - .add_request_handler(Server::remove_contact) - .add_request_handler(Server::respond_to_contact_request) - .add_request_handler(Server::follow) - .add_message_handler(Server::unfollow) - .add_message_handler(Server::update_followers) - .add_message_handler(Server::update_diff_base) - .add_request_handler(Server::get_private_user_info); + .add_request_handler(ping) + .add_request_handler(create_room) + .add_request_handler(join_room) + .add_message_handler(leave_room) + .add_request_handler(call) + .add_request_handler(cancel_call) + .add_message_handler(decline_call) + .add_request_handler(update_participant_location) + .add_request_handler(share_project) + 
.add_message_handler(unshare_project) + .add_request_handler(join_project) + .add_message_handler(leave_project) + .add_request_handler(update_project) + .add_request_handler(update_worktree) + .add_message_handler(start_language_server) + .add_message_handler(update_language_server) + .add_message_handler(update_diagnostic_summary) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_message_handler(create_buffer_for_peer) + .add_request_handler(update_buffer) + .add_message_handler(update_buffer_file) + .add_message_handler(buffer_reloaded) + .add_message_handler(buffer_saved) + .add_request_handler(save_buffer) + .add_request_handler(get_users) + .add_request_handler(fuzzy_search_users) + .add_request_handler(request_contact) + .add_request_handler(remove_contact) + .add_request_handler(respond_to_contact_request) + .add_request_handler(follow) + .add_message_handler(unfollow) + .add_message_handler(update_followers) + .add_message_handler(update_diff_base) + .add_request_handler(get_private_user_info); Arc::new(server) } - fn add_message_handler(&mut self, handler: F) -> &mut Self + fn add_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, TypedEnvelope) -> Fut, + F: 'static + Send + Sync + Fn(TypedEnvelope, Session) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { let prev_handler = self.handlers.insert( TypeId::of::(), - Box::new(move |server, envelope| { + Box::new(move |envelope, session| { let envelope = envelope.into_any().downcast::>().unwrap(); let span = info_span!( "handle message", @@ -212,7 +264,7 @@ impl Server { "message received" ); }); - let future = (handler)(server, *envelope); + let future = (handler)(*envelope, session); async move { if let Err(error) = future.await { tracing::error!(%error, "error handling message"); @@ -228,26 +280,35 @@ impl Server { self } - /// Handle a request while holding a lock to the store. This is useful when we're registering - /// a connection but we want to respond on the connection before anybody else can send on it. 
+ fn add_message_handler(&mut self, handler: F) -> &mut Self + where + F: 'static + Send + Sync + Fn(M, Session) -> Fut, + Fut: 'static + Send + Future>, + M: EnvelopedMessage, + { + self.add_handler(move |envelope, session| handler(envelope.payload, session)); + self + } + fn add_request_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, TypedEnvelope, Response) -> Fut, + F: 'static + Send + Sync + Fn(M, Response, Session) -> Fut, Fut: Send + Future>, M: RequestMessage, { let handler = Arc::new(handler); - self.add_message_handler(move |server, envelope| { + self.add_handler(move |envelope, session| { let receipt = envelope.receipt(); let handler = handler.clone(); async move { + let peer = session.peer.clone(); let responded = Arc::new(AtomicBool::default()); let response = Response { - server: server.clone(), + peer: peer.clone(), responded: responded.clone(), - receipt: envelope.receipt(), + receipt, }; - match (handler)(server.clone(), envelope, response).await { + match (handler)(envelope.payload, response, session).await { Ok(()) => { if responded.load(std::sync::atomic::Ordering::SeqCst) { Ok(()) @@ -256,7 +317,7 @@ impl Server { } } Err(error) => { - server.peer.respond_with_error( + peer.respond_with_error( receipt, proto::Error { message: error.to_string(), @@ -277,7 +338,7 @@ impl Server { mut send_connection_id: Option>, executor: E, ) -> impl Future> { - let mut this = self.clone(); + let this = self.clone(); let user_id = user.id; let login = user.github_login; let span = info_span!("handle connection", %user_id, %login, %address); @@ -313,22 +374,31 @@ impl Server { ).await?; { - let mut store = this.store().await; - let incoming_call = store.add_connection(connection_id, user_id, user.admin); - if let Some(incoming_call) = incoming_call { - this.peer.send(connection_id, incoming_call)?; - } - - this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?; + let mut pool = this.connection_pool.lock().await; + pool.add_connection(connection_id, user_id, user.admin); + this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?; if let Some((code, count)) = invite_code { this.peer.send(connection_id, proto::UpdateInviteInfo { url: format!("{}{}", this.app_state.config.invite_link_prefix, code), - count, + count: count as u32, })?; } } - this.update_user_contacts(user_id).await?; + + if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? 
{ + this.peer.send(connection_id, incoming_call)?; + } + + let session = Session { + user_id, + connection_id, + db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))), + peer: this.peer.clone(), + connection_pool: this.connection_pool.clone(), + live_kit_client: this.app_state.live_kit_client.clone() + }; + update_user_contacts(user_id, &session).await?; let handle_io = handle_io.fuse(); futures::pin_mut!(handle_io); @@ -360,7 +430,7 @@ impl Server { let span_enter = span.enter(); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { let is_background = message.is_background(); - let handle_message = (handler)(this.clone(), message); + let handle_message = (handler)(message, session.clone()); drop(span_enter); let handle_message = handle_message.instrument(span); @@ -382,7 +452,7 @@ impl Server { drop(foreground_message_handlers); tracing::info!(%user_id, %login, %connection_id, %address, "signing out"); - if let Err(error) = this.sign_out(connection_id).await { + if let Err(error) = sign_out(session).await { tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out"); } @@ -390,78 +460,6 @@ impl Server { }.instrument(span) } - #[instrument(skip(self), err)] - async fn sign_out(self: &mut Arc, connection_id: ConnectionId) -> Result<()> { - self.peer.disconnect(connection_id); - - let mut projects_to_unshare = Vec::new(); - let mut contacts_to_update = HashSet::default(); - let mut room_left = None; - { - let mut store = self.store().await; - - #[cfg(test)] - let removed_connection = store.remove_connection(connection_id).unwrap(); - #[cfg(not(test))] - let removed_connection = store.remove_connection(connection_id)?; - - for project in removed_connection.hosted_projects { - projects_to_unshare.push(project.id); - broadcast(connection_id, project.guests.keys().copied(), |conn_id| { - self.peer.send( - conn_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - ) - }); - } - - for project in removed_connection.guest_projects { - broadcast(connection_id, project.connection_ids, |conn_id| { - self.peer.send( - conn_id, - proto::RemoveProjectCollaborator { - project_id: project.id.to_proto(), - peer_id: connection_id.0, - }, - ) - }); - } - - if let Some(room) = removed_connection.room { - self.room_updated(&room); - room_left = Some(self.room_left(&room, connection_id)); - } - - contacts_to_update.insert(removed_connection.user_id); - for connection_id in removed_connection.canceled_call_connection_ids { - self.peer - .send(connection_id, proto::CallCanceled {}) - .trace_err(); - contacts_to_update.extend(store.user_id_for_connection(connection_id).ok()); - } - }; - - if let Some(room_left) = room_left { - room_left.await.trace_err(); - } - - for user_id in contacts_to_update { - self.update_user_contacts(user_id).await.trace_err(); - } - - for project_id in projects_to_unshare { - self.app_state - .db - .unregister_project(project_id) - .await - .trace_err(); - } - - Ok(()) - } - pub async fn invite_code_redeemed( self: &Arc, inviter_id: UserId, @@ -469,9 +467,9 @@ impl Server { ) -> Result<()> { if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? 
{ if let Some(code) = &user.invite_code { - let store = self.store().await; - let invitee_contact = store.contact_for_user(invitee_id, true); - for connection_id in store.connection_ids_for_user(inviter_id) { + let pool = self.connection_pool.lock().await; + let invitee_contact = contact_for_user(invitee_id, true, false, &pool); + for connection_id in pool.user_connection_ids(inviter_id) { self.peer.send( connection_id, proto::UpdateContacts { @@ -495,8 +493,8 @@ impl Server { pub async fn invite_count_updated(self: &Arc, user_id: UserId) -> Result<()> { if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? { if let Some(invite_code) = &user.invite_code { - let store = self.store().await; - for connection_id in store.connection_ids_for_user(user_id) { + let pool = self.connection_pool.lock().await; + for connection_id in pool.user_connection_ids(user_id) { self.peer.send( connection_id, proto::UpdateInviteInfo { @@ -513,1263 +511,1220 @@ impl Server { Ok(()) } - async fn ping( - self: Arc, - _: TypedEnvelope, - response: Response, - ) -> Result<()> { - response.send(proto::Ack {})?; - Ok(()) + pub async fn snapshot<'a>(self: &'a Arc) -> ServerSnapshot<'a> { + ServerSnapshot { + connection_pool: ConnectionPoolGuard { + guard: self.connection_pool.lock().await, + _not_send: PhantomData, + }, + peer: &self.peer, + } } +} - async fn create_room( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let user_id; - let room; - { - let mut store = self.store().await; - user_id = store.user_id_for_connection(request.sender_id)?; - room = store.create_room(request.sender_id)?.clone(); - } +impl<'a> Deref for ConnectionPoolGuard<'a> { + type Target = ConnectionPool; - let live_kit_connection_info = - if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { - if let Some(_) = live_kit - .create_room(room.live_kit_room.clone()) - .await - .trace_err() - { - if let Some(token) = live_kit - .room_token(&room.live_kit_room, &request.sender_id.to_string()) - .trace_err() - { - Some(proto::LiveKitConnectionInfo { - server_url: live_kit.url().into(), - token, - }) - } else { - None - } - } else { - None - } - } else { - None - }; + fn deref(&self) -> &Self::Target { + &*self.guard + } +} - response.send(proto::CreateRoomResponse { - room: Some(room), - live_kit_connection_info, - })?; - self.update_user_contacts(user_id).await?; - Ok(()) +impl<'a> DerefMut for ConnectionPoolGuard<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.guard } +} - async fn join_room( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let user_id; - { - let mut store = self.store().await; - user_id = store.user_id_for_connection(request.sender_id)?; - let (room, recipient_connection_ids) = - store.join_room(request.payload.id, request.sender_id)?; - for recipient_id in recipient_connection_ids { - self.peer - .send(recipient_id, proto::CallCanceled {}) - .trace_err(); - } +impl<'a> Drop for ConnectionPoolGuard<'a> { + fn drop(&mut self) { + #[cfg(test)] + self.check_invariants(); + } +} - let live_kit_connection_info = - if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { - if let Some(token) = live_kit - .room_token(&room.live_kit_room, &request.sender_id.to_string()) - .trace_err() - { - Some(proto::LiveKitConnectionInfo { - server_url: live_kit.url().into(), - token, - }) - } else { - None - } - } else { - None - }; +impl Executor for RealExecutor { + type Sleep = Sleep; + + fn spawn_detached>(&self, future: F) { 
+ tokio::task::spawn(future); + } + + fn sleep(&self, duration: Duration) -> Self::Sleep { + tokio::time::sleep(duration) + } +} - response.send(proto::JoinRoomResponse { - room: Some(room.clone()), - live_kit_connection_info, - })?; - self.room_updated(room); +fn broadcast( + sender_id: ConnectionId, + receiver_ids: impl IntoIterator, + mut f: F, +) where + F: FnMut(ConnectionId) -> anyhow::Result<()>, +{ + for receiver_id in receiver_ids { + if receiver_id != sender_id { + f(receiver_id).trace_err(); } - self.update_user_contacts(user_id).await?; - Ok(()) } +} - async fn leave_room(self: Arc, message: TypedEnvelope) -> Result<()> { - let mut contacts_to_update = HashSet::default(); - let room_left; - { - let mut store = self.store().await; - let user_id = store.user_id_for_connection(message.sender_id)?; - let left_room = store.leave_room(message.payload.id, message.sender_id)?; - contacts_to_update.insert(user_id); +lazy_static! { + static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version"); +} - for project in left_room.unshared_projects { - for connection_id in project.connection_ids() { - self.peer.send( - connection_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - )?; - } - } +pub struct ProtocolVersion(u32); - for project in left_room.left_projects { - if project.remove_collaborator { - for connection_id in project.connection_ids { - self.peer.send( - connection_id, - proto::RemoveProjectCollaborator { - project_id: project.id.to_proto(), - peer_id: message.sender_id.0, - }, - )?; - } +impl Header for ProtocolVersion { + fn name() -> &'static HeaderName { + &ZED_PROTOCOL_VERSION + } - self.peer.send( - message.sender_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - )?; - } - } + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + let version = values + .next() + .ok_or_else(axum::headers::Error::invalid)? + .to_str() + .map_err(|_| axum::headers::Error::invalid())? 
+ .parse() + .map_err(|_| axum::headers::Error::invalid())?; + Ok(Self(version)) + } - self.room_updated(&left_room.room); - room_left = self.room_left(&left_room.room, message.sender_id); + fn encode>(&self, values: &mut E) { + values.extend([self.0.to_string().parse().unwrap()]); + } +} - for connection_id in left_room.canceled_call_connection_ids { - self.peer - .send(connection_id, proto::CallCanceled {}) - .trace_err(); - contacts_to_update.extend(store.user_id_for_connection(connection_id).ok()); - } - } +pub fn routes(server: Arc) -> Router { + Router::new() + .route("/rpc", get(handle_websocket_request)) + .layer( + ServiceBuilder::new() + .layer(Extension(server.app_state.clone())) + .layer(middleware::from_fn(auth::validate_header)), + ) + .route("/metrics", get(handle_metrics)) + .layer(Extension(server)) +} - room_left.await.trace_err(); - for user_id in contacts_to_update { - self.update_user_contacts(user_id).await?; +pub async fn handle_websocket_request( + TypedHeader(ProtocolVersion(protocol_version)): TypedHeader, + ConnectInfo(socket_address): ConnectInfo, + Extension(server): Extension>, + Extension(user): Extension, + ws: WebSocketUpgrade, +) -> axum::response::Response { + if protocol_version != rpc::PROTOCOL_VERSION { + return ( + StatusCode::UPGRADE_REQUIRED, + "client must be upgraded".to_string(), + ) + .into_response(); + } + let socket_address = socket_address.to_string(); + ws.on_upgrade(move |socket| { + use util::ResultExt; + let socket = socket + .map_ok(to_tungstenite_message) + .err_into() + .with(|message| async move { Ok(to_axum_message(message)) }); + let connection = Connection::new(Box::pin(socket)); + async move { + server + .handle_connection(connection, socket_address, user, None, RealExecutor) + .await + .log_err(); } + }) +} - Ok(()) - } +pub async fn handle_metrics(Extension(server): Extension>) -> Result { + let connections = server + .connection_pool + .lock() + .await + .connections() + .filter(|connection| !connection.admin) + .count(); - async fn call( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let caller_user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let recipient_user_id = UserId::from_proto(request.payload.recipient_user_id); - let initial_project_id = request - .payload - .initial_project_id - .map(ProjectId::from_proto); - if !self - .app_state - .db - .has_contact(caller_user_id, recipient_user_id) - .await? 
- { - return Err(anyhow!("cannot call a user who isn't a contact"))?; - } + METRIC_CONNECTIONS.set(connections as _); - let room_id = request.payload.room_id; - let mut calls = { - let mut store = self.store().await; - let (room, recipient_connection_ids, incoming_call) = store.call( - room_id, - recipient_user_id, - initial_project_id, - request.sender_id, - )?; - self.room_updated(room); - recipient_connection_ids - .into_iter() - .map(|recipient_connection_id| { - self.peer - .request(recipient_connection_id, incoming_call.clone()) - }) - .collect::>() - }; - self.update_user_contacts(recipient_user_id).await?; + let shared_projects = server.app_state.db.project_count_excluding_admins().await?; + METRIC_SHARED_PROJECTS.set(shared_projects as _); - while let Some(call_response) = calls.next().await { - match call_response.as_ref() { - Ok(_) => { - response.send(proto::Ack {})?; - return Ok(()); - } - Err(_) => { - call_response.trace_err(); - } - } - } + let encoder = prometheus::TextEncoder::new(); + let metric_families = prometheus::gather(); + let encoded_metrics = encoder + .encode_to_string(&metric_families) + .map_err(|err| anyhow!("{}", err))?; + Ok(encoded_metrics) +} +#[instrument(err)] +async fn sign_out(session: Session) -> Result<()> { + session.peer.disconnect(session.connection_id); + let decline_calls = { + let mut pool = session.connection_pool().await; + pool.remove_connection(session.connection_id)?; + let mut connections = pool.user_connection_ids(session.user_id); + connections.next().is_none() + }; + + leave_room_for_session(&session).await.trace_err(); + if decline_calls { + if let Some(room) = session + .db() + .await + .decline_call(None, session.user_id) + .await + .trace_err() { - let mut store = self.store().await; - let room = store.call_failed(room_id, recipient_user_id)?; - self.room_updated(&room); + room_updated(&room, &session); } - self.update_user_contacts(recipient_user_id).await?; - - Err(anyhow!("failed to ring call recipient"))? 
} - async fn cancel_call( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let recipient_user_id = UserId::from_proto(request.payload.recipient_user_id); - { - let mut store = self.store().await; - let (room, recipient_connection_ids) = store.cancel_call( - request.payload.room_id, - recipient_user_id, - request.sender_id, - )?; - for recipient_id in recipient_connection_ids { - self.peer - .send(recipient_id, proto::CallCanceled {}) - .trace_err(); - } - self.room_updated(room); - response.send(proto::Ack {})?; - } - self.update_user_contacts(recipient_user_id).await?; - Ok(()) - } - - async fn decline_call( - self: Arc, - message: TypedEnvelope, - ) -> Result<()> { - let recipient_user_id; - { - let mut store = self.store().await; - recipient_user_id = store.user_id_for_connection(message.sender_id)?; - let (room, recipient_connection_ids) = - store.decline_call(message.payload.room_id, message.sender_id)?; - for recipient_id in recipient_connection_ids { - self.peer - .send(recipient_id, proto::CallCanceled {}) - .trace_err(); - } - self.room_updated(room); - } - self.update_user_contacts(recipient_user_id).await?; - Ok(()) - } + update_user_contacts(session.user_id, &session).await?; - async fn update_participant_location( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let room_id = request.payload.room_id; - let location = request - .payload - .location - .ok_or_else(|| anyhow!("invalid location"))?; - let mut store = self.store().await; - let room = store.update_participant_location(room_id, location, request.sender_id)?; - self.room_updated(room); - response.send(proto::Ack {})?; - Ok(()) - } - - fn room_updated(&self, room: &proto::Room) { - for participant in &room.participants { - self.peer - .send( - ConnectionId(participant.peer_id), - proto::RoomUpdated { - room: Some(room.clone()), - }, - ) - .trace_err(); - } - } + Ok(()) +} - fn room_left( - &self, - room: &proto::Room, - connection_id: ConnectionId, - ) -> impl Future> { - let client = self.app_state.live_kit_client.clone(); - let room_name = room.live_kit_room.clone(); - let participant_count = room.participants.len(); - async move { - if let Some(client) = client { - client - .remove_participant(room_name.clone(), connection_id.to_string()) - .await?; +async fn ping(_: proto::Ping, response: Response, _session: Session) -> Result<()> { + response.send(proto::Ack {})?; + Ok(()) +} - if participant_count == 0 { - client.delete_room(room_name).await?; - } +async fn create_room( + _request: proto::CreateRoom, + response: Response, + session: Session, +) -> Result<()> { + let live_kit_room = nanoid::nanoid!(30); + let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() { + if let Some(_) = live_kit + .create_room(live_kit_room.clone()) + .await + .trace_err() + { + if let Some(token) = live_kit + .room_token(&live_kit_room, &session.connection_id.to_string()) + .trace_err() + { + Some(proto::LiveKitConnectionInfo { + server_url: live_kit.url().into(), + token, + }) + } else { + None } - - Ok(()) + } else { + None } - } + } else { + None + }; - async fn share_project( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let user_id = self - .store() + { + let room = session + .db() .await - .user_id_for_connection(request.sender_id)?; - let project_id = self.app_state.db.register_project(user_id).await?; - let mut store = self.store().await; - let room = store.share_project( - 
request.payload.room_id, - project_id, - request.payload.worktrees, - request.sender_id, - )?; - response.send(proto::ShareProjectResponse { - project_id: project_id.to_proto(), - })?; - self.room_updated(room); + .create_room(session.user_id, session.connection_id, &live_kit_room) + .await?; - Ok(()) + response.send(proto::CreateRoomResponse { + room: Some(room.clone()), + live_kit_connection_info, + })?; } - async fn unshare_project( - self: Arc, - message: TypedEnvelope, - ) -> Result<()> { - let project_id = ProjectId::from_proto(message.payload.project_id); - let mut store = self.store().await; - let (room, project) = store.unshare_project(project_id, message.sender_id)?; - broadcast( - message.sender_id, - project.guest_connection_ids(), - |conn_id| self.peer.send(conn_id, message.payload.clone()), - ); - self.room_updated(room); - - Ok(()) - } + update_user_contacts(session.user_id, &session).await?; + Ok(()) +} - async fn update_user_contacts(self: &Arc, user_id: UserId) -> Result<()> { - let contacts = self.app_state.db.get_contacts(user_id).await?; - let store = self.store().await; - let updated_contact = store.contact_for_user(user_id, false); - for contact in contacts { - if let db::Contact::Accepted { - user_id: contact_user_id, - .. - } = contact - { - for contact_conn_id in store.connection_ids_for_user(contact_user_id) { - self.peer - .send( - contact_conn_id, - proto::UpdateContacts { - contacts: vec![updated_contact.clone()], - remove_contacts: Default::default(), - incoming_requests: Default::default(), - remove_incoming_requests: Default::default(), - outgoing_requests: Default::default(), - remove_outgoing_requests: Default::default(), - }, - ) - .trace_err(); - } - } - } - Ok(()) +async fn join_room( + request: proto::JoinRoom, + response: Response, + session: Session, +) -> Result<()> { + let room = { + let room = session + .db() + .await + .join_room( + RoomId::from_proto(request.id), + session.user_id, + session.connection_id, + ) + .await?; + room_updated(&room, &session); + room.clone() + }; + + for connection_id in session + .connection_pool() + .await + .user_connection_ids(session.user_id) + { + session + .peer + .send(connection_id, proto::CallCanceled {}) + .trace_err(); } - async fn join_project( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - - let host_user_id; - let guest_user_id; - let host_connection_id; + let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() { + if let Some(token) = live_kit + .room_token(&room.live_kit_room, &session.connection_id.to_string()) + .trace_err() { - let state = self.store().await; - let project = state.project(project_id)?; - host_user_id = project.host.user_id; - host_connection_id = project.host_connection_id; - guest_user_id = state.user_id_for_connection(request.sender_id)?; - }; - - tracing::info!(%project_id, %host_user_id, %host_connection_id, "join project"); - - let mut store = self.store().await; - let (project, replica_id) = store.join_project(request.sender_id, project_id)?; - let peer_count = project.guests.len(); - let mut collaborators = Vec::with_capacity(peer_count); - collaborators.push(proto::Collaborator { - peer_id: project.host_connection_id.0, - replica_id: 0, - user_id: project.host.user_id.to_proto(), - }); - let worktrees = project - .worktrees - .iter() - .map(|(id, worktree)| proto::WorktreeMetadata { - id: *id, - root_name: worktree.root_name.clone(), - 
visible: worktree.visible, - abs_path: worktree.abs_path.clone(), + Some(proto::LiveKitConnectionInfo { + server_url: live_kit.url().into(), + token, }) - .collect::>(); - - // Add all guests other than the requesting user's own connections as collaborators - for (guest_conn_id, guest) in &project.guests { - if request.sender_id != *guest_conn_id { - collaborators.push(proto::Collaborator { - peer_id: guest_conn_id.0, - replica_id: guest.replica_id as u32, - user_id: guest.user_id.to_proto(), - }); - } - } - - for conn_id in project.connection_ids() { - if conn_id != request.sender_id { - self.peer - .send( - conn_id, - proto::AddProjectCollaborator { - project_id: project_id.to_proto(), - collaborator: Some(proto::Collaborator { - peer_id: request.sender_id.0, - replica_id: replica_id as u32, - user_id: guest_user_id.to_proto(), - }), - }, - ) - .trace_err(); - } + } else { + None } + } else { + None + }; - // First, we send the metadata associated with each worktree. - response.send(proto::JoinProjectResponse { - worktrees: worktrees.clone(), - replica_id: replica_id as u32, - collaborators: collaborators.clone(), - language_servers: project.language_servers.clone(), - })?; - - for (worktree_id, worktree) in &project.worktrees { - #[cfg(any(test, feature = "test-support"))] - const MAX_CHUNK_SIZE: usize = 2; - #[cfg(not(any(test, feature = "test-support")))] - const MAX_CHUNK_SIZE: usize = 256; + response.send(proto::JoinRoomResponse { + room: Some(room), + live_kit_connection_info, + })?; - // Stream this worktree's entries. - let message = proto::UpdateWorktree { - project_id: project_id.to_proto(), - worktree_id: *worktree_id, - abs_path: worktree.abs_path.clone(), - root_name: worktree.root_name.clone(), - updated_entries: worktree.entries.values().cloned().collect(), - removed_entries: Default::default(), - scan_id: worktree.scan_id, - is_last_update: worktree.is_complete, - }; - for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) { - self.peer.send(request.sender_id, update.clone())?; - } - - // Stream this worktree's diagnostics. - for summary in worktree.diagnostic_summaries.values() { - self.peer.send( - request.sender_id, - proto::UpdateDiagnosticSummary { - project_id: project_id.to_proto(), - worktree_id: *worktree_id, - summary: Some(summary.clone()), - }, - )?; - } - } + update_user_contacts(session.user_id, &session).await?; + Ok(()) +} - for language_server in &project.language_servers { - self.peer.send( - request.sender_id, - proto::UpdateLanguageServer { - project_id: project_id.to_proto(), - language_server_id: language_server.id, - variant: Some( - proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated( - proto::LspDiskBasedDiagnosticsUpdated {}, - ), - ), - }, - )?; - } +async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> { + leave_room_for_session(&session).await +} - Ok(()) +async fn call( + request: proto::Call, + response: Response, + session: Session, +) -> Result<()> { + let room_id = RoomId::from_proto(request.room_id); + let calling_user_id = session.user_id; + let calling_connection_id = session.connection_id; + let called_user_id = UserId::from_proto(request.called_user_id); + let initial_project_id = request.initial_project_id.map(ProjectId::from_proto); + if !session + .db() + .await + .has_contact(calling_user_id, called_user_id) + .await? 
+ { + return Err(anyhow!("cannot call a user who isn't a contact"))?; } - async fn leave_project( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let sender_id = request.sender_id; - let project_id = ProjectId::from_proto(request.payload.project_id); - let project; - { - let mut store = self.store().await; - project = store.leave_project(project_id, sender_id)?; - tracing::info!( - %project_id, - host_user_id = %project.host_user_id, - host_connection_id = %project.host_connection_id, - "leave project" - ); - - if project.remove_collaborator { - broadcast(sender_id, project.connection_ids, |conn_id| { - self.peer.send( - conn_id, - proto::RemoveProjectCollaborator { - project_id: project_id.to_proto(), - peer_id: sender_id.0, - }, - ) - }); + let incoming_call = { + let (room, incoming_call) = &mut *session + .db() + .await + .call( + room_id, + calling_user_id, + calling_connection_id, + called_user_id, + initial_project_id, + ) + .await?; + room_updated(&room, &session); + mem::take(incoming_call) + }; + update_user_contacts(called_user_id, &session).await?; + + let mut calls = session + .connection_pool() + .await + .user_connection_ids(called_user_id) + .map(|connection_id| session.peer.request(connection_id, incoming_call.clone())) + .collect::>(); + + while let Some(call_response) = calls.next().await { + match call_response.as_ref() { + Ok(_) => { + response.send(proto::Ack {})?; + return Ok(()); + } + Err(_) => { + call_response.trace_err(); } } - - Ok(()) - } - - async fn update_project( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - { - let mut state = self.store().await; - let guest_connection_ids = state - .read_project(project_id, request.sender_id)? - .guest_connection_ids(); - let room = - state.update_project(project_id, &request.payload.worktrees, request.sender_id)?; - broadcast(request.sender_id, guest_connection_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); - self.room_updated(room); - }; - - Ok(()) } - async fn update_worktree( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let worktree_id = request.payload.worktree_id; - let connection_ids = self.store().await.update_worktree( - request.sender_id, - project_id, - worktree_id, - &request.payload.root_name, - &request.payload.abs_path, - &request.payload.removed_entries, - &request.payload.updated_entries, - request.payload.scan_id, - request.payload.is_last_update, - )?; - - broadcast(request.sender_id, connection_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); - response.send(proto::Ack {})?; - Ok(()) + { + let room = session + .db() + .await + .call_failed(room_id, called_user_id) + .await?; + room_updated(&room, &session); } + update_user_contacts(called_user_id, &session).await?; - async fn update_diagnostic_summary( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let summary = request - .payload - .summary - .clone() - .ok_or_else(|| anyhow!("invalid summary"))?; - let receiver_ids = self.store().await.update_diagnostic_summary( - ProjectId::from_proto(request.payload.project_id), - request.payload.worktree_id, - request.sender_id, - summary, - )?; + Err(anyhow!("failed to ring user"))? 
+} - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); - Ok(()) +async fn cancel_call( + request: proto::CancelCall, + response: Response, + session: Session, +) -> Result<()> { + let called_user_id = UserId::from_proto(request.called_user_id); + let room_id = RoomId::from_proto(request.room_id); + { + let room = session + .db() + .await + .cancel_call(Some(room_id), session.connection_id, called_user_id) + .await?; + room_updated(&room, &session); } - async fn start_language_server( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let receiver_ids = self.store().await.start_language_server( - ProjectId::from_proto(request.payload.project_id), - request.sender_id, - request - .payload - .server - .clone() - .ok_or_else(|| anyhow!("invalid language server"))?, - )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); - Ok(()) + for connection_id in session + .connection_pool() + .await + .user_connection_ids(called_user_id) + { + session + .peer + .send(connection_id, proto::CallCanceled {}) + .trace_err(); } + response.send(proto::Ack {})?; - async fn update_language_server( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_id, - )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); - Ok(()) - } + update_user_contacts(called_user_id, &session).await?; + Ok(()) +} - async fn forward_project_request( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> - where - T: EntityMessage + RequestMessage, +async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> { + let room_id = RoomId::from_proto(message.room_id); { - let project_id = ProjectId::from_proto(request.payload.remote_entity_id()); - let host_connection_id = self - .store() + let room = session + .db() .await - .read_project(project_id, request.sender_id)? - .host_connection_id; - let payload = self - .peer - .forward_request(request.sender_id, host_connection_id, request.payload) + .decline_call(Some(room_id), session.user_id) .await?; - - // Ensure project still exists by the time we get the response from the host. - self.store() - .await - .read_project(project_id, request.sender_id)?; - - response.send(payload)?; - Ok(()) + room_updated(&room, &session); } - async fn save_buffer( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let host = self - .store() - .await - .read_project(project_id, request.sender_id)? - .host_connection_id; - let response_payload = self + for connection_id in session + .connection_pool() + .await + .user_connection_ids(session.user_id) + { + session .peer - .forward_request(request.sender_id, host, request.payload.clone()) - .await?; - - let mut guests = self - .store() - .await - .read_project(project_id, request.sender_id)? 
- .connection_ids(); - guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id); - broadcast(host, guests, |conn_id| { - self.peer - .forward_send(host, conn_id, response_payload.clone()) - }); - response.send(response_payload)?; - Ok(()) + .send(connection_id, proto::CallCanceled {}) + .trace_err(); } + update_user_contacts(session.user_id, &session).await?; + Ok(()) +} - async fn create_buffer_for_peer( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - self.peer.forward_send( - request.sender_id, - ConnectionId(request.payload.peer_id), - request.payload, - )?; - Ok(()) - } +async fn update_participant_location( + request: proto::UpdateParticipantLocation, + response: Response, + session: Session, +) -> Result<()> { + let room_id = RoomId::from_proto(request.room_id); + let location = request + .location + .ok_or_else(|| anyhow!("invalid location"))?; + let room = session + .db() + .await + .update_room_participant_location(room_id, session.connection_id, location) + .await?; + room_updated(&room, &session); + response.send(proto::Ack {})?; + Ok(()) +} - async fn update_buffer( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let receiver_ids = { - let store = self.store().await; - store.project_connection_ids(project_id, request.sender_id)? - }; +async fn share_project( + request: proto::ShareProject, + response: Response, + session: Session, +) -> Result<()> { + let (project_id, room) = &*session + .db() + .await + .share_project( + RoomId::from_proto(request.room_id), + session.connection_id, + &request.worktrees, + ) + .await?; + response.send(proto::ShareProjectResponse { + project_id: project_id.to_proto(), + })?; + room_updated(&room, &session); - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); - response.send(proto::Ack {})?; - Ok(()) - } + Ok(()) +} - async fn update_buffer_file( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_id, - )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); - Ok(()) - } +async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(message.project_id); - async fn buffer_reloaded( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_id, - )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); - Ok(()) - } + let (room, guest_connection_ids) = &*session + .db() + .await + .unshare_project(project_id, session.connection_id) + .await?; - async fn buffer_saved( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_id, - )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); - Ok(()) - } + broadcast( + session.connection_id, 
+ guest_connection_ids.iter().copied(), + |conn_id| session.peer.send(conn_id, message.clone()), + ); + room_updated(&room, &session); - async fn follow( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let leader_id = ConnectionId(request.payload.leader_id); - let follower_id = request.sender_id; - { - let store = self.store().await; - if !store - .project_connection_ids(project_id, follower_id)? - .contains(&leader_id) - { - Err(anyhow!("no such peer"))?; - } - } + Ok(()) +} - let mut response_payload = self +async fn join_project( + request: proto::JoinProject, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let guest_user_id = session.user_id; + + tracing::info!(%project_id, "join project"); + + let (project, replica_id) = &mut *session + .db() + .await + .join_project(project_id, session.connection_id) + .await?; + + let collaborators = project + .collaborators + .iter() + .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32) + .map(|collaborator| proto::Collaborator { + peer_id: collaborator.connection_id as u32, + replica_id: collaborator.replica_id.0 as u32, + user_id: collaborator.user_id.to_proto(), + }) + .collect::>(); + let worktrees = project + .worktrees + .iter() + .map(|(id, worktree)| proto::WorktreeMetadata { + id: *id, + root_name: worktree.root_name.clone(), + visible: worktree.visible, + abs_path: worktree.abs_path.clone(), + }) + .collect::>(); + + for collaborator in &collaborators { + session .peer - .forward_request(request.sender_id, leader_id, request.payload) - .await?; - response_payload - .views - .retain(|view| view.leader_id != Some(follower_id.0)); - response.send(response_payload)?; - Ok(()) + .send( + ConnectionId(collaborator.peer_id), + proto::AddProjectCollaborator { + project_id: project_id.to_proto(), + collaborator: Some(proto::Collaborator { + peer_id: session.connection_id.0, + replica_id: replica_id.0 as u32, + user_id: guest_user_id.to_proto(), + }), + }, + ) + .trace_err(); } - async fn unfollow(self: Arc, request: TypedEnvelope) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let leader_id = ConnectionId(request.payload.leader_id); - let store = self.store().await; - if !store - .project_connection_ids(project_id, request.sender_id)? - .contains(&leader_id) - { - Err(anyhow!("no such peer"))?; + // First, we send the metadata associated with each worktree. + response.send(proto::JoinProjectResponse { + worktrees: worktrees.clone(), + replica_id: replica_id.0 as u32, + collaborators: collaborators.clone(), + language_servers: project.language_servers.clone(), + })?; + + for (worktree_id, worktree) in mem::take(&mut project.worktrees) { + #[cfg(any(test, feature = "test-support"))] + const MAX_CHUNK_SIZE: usize = 2; + #[cfg(not(any(test, feature = "test-support")))] + const MAX_CHUNK_SIZE: usize = 256; + + // Stream this worktree's entries. 
+ let message = proto::UpdateWorktree { + project_id: project_id.to_proto(), + worktree_id, + abs_path: worktree.abs_path.clone(), + root_name: worktree.root_name, + updated_entries: worktree.entries, + removed_entries: Default::default(), + scan_id: worktree.scan_id, + is_last_update: worktree.is_complete, + }; + for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) { + session.peer.send(session.connection_id, update.clone())?; } - self.peer - .forward_send(request.sender_id, leader_id, request.payload)?; - Ok(()) - } - async fn update_followers( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let store = self.store().await; - let connection_ids = store.project_connection_ids(project_id, request.sender_id)?; - let leader_id = request - .payload - .variant - .as_ref() - .and_then(|variant| match variant { - proto::update_followers::Variant::CreateView(payload) => payload.leader_id, - proto::update_followers::Variant::UpdateView(payload) => payload.leader_id, - proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id, - }); - for follower_id in &request.payload.follower_ids { - let follower_id = ConnectionId(*follower_id); - if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { - self.peer - .forward_send(request.sender_id, follower_id, request.payload.clone())?; - } + // Stream this worktree's diagnostics. + for summary in worktree.diagnostic_summaries { + session.peer.send( + session.connection_id, + proto::UpdateDiagnosticSummary { + project_id: project_id.to_proto(), + worktree_id: worktree.id, + summary: Some(summary), + }, + )?; } - Ok(()) } - async fn get_users( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let user_ids = request - .payload - .user_ids - .into_iter() - .map(UserId::from_proto) - .collect(); - let users = self - .app_state - .db - .get_users_by_ids(user_ids) - .await? - .into_iter() - .map(|user| proto::User { - id: user.id.to_proto(), - avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), - github_login: user.github_login, - }) - .collect(); - response.send(proto::UsersResponse { users })?; - Ok(()) + for language_server in &project.language_servers { + session.peer.send( + session.connection_id, + proto::UpdateLanguageServer { + project_id: project_id.to_proto(), + language_server_id: language_server.id, + variant: Some( + proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated( + proto::LspDiskBasedDiagnosticsUpdated {}, + ), + ), + }, + )?; } - async fn fuzzy_search_users( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let query = request.payload.query; - let db = &self.app_state.db; - let users = match query.len() { - 0 => vec![], - 1 | 2 => db - .get_user_by_github_account(&query, None) - .await? 
- .into_iter() - .collect(), - _ => db.fuzzy_search_users(&query, 10).await?, - }; - let users = users - .into_iter() - .filter(|user| user.id != user_id) - .map(|user| proto::User { - id: user.id.to_proto(), - avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), - github_login: user.github_login, - }) - .collect(); - response.send(proto::UsersResponse { users })?; - Ok(()) - } + Ok(()) +} - async fn request_contact( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let requester_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let responder_id = UserId::from_proto(request.payload.responder_id); - if requester_id == responder_id { - return Err(anyhow!("cannot add yourself as a contact"))?; - } +async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> { + let sender_id = session.connection_id; + let project_id = ProjectId::from_proto(request.project_id); + + let project = session + .db() + .await + .leave_project(project_id, sender_id) + .await?; + tracing::info!( + %project_id, + host_user_id = %project.host_user_id, + host_connection_id = %project.host_connection_id, + "leave project" + ); + + broadcast( + sender_id, + project.connection_ids.iter().copied(), + |conn_id| { + session.peer.send( + conn_id, + proto::RemoveProjectCollaborator { + project_id: project_id.to_proto(), + peer_id: sender_id.0, + }, + ) + }, + ); - self.app_state - .db - .send_contact_request(requester_id, responder_id) - .await?; + Ok(()) +} - // Update outgoing contact requests of requester - let mut update = proto::UpdateContacts::default(); - update.outgoing_requests.push(responder_id.to_proto()); - for connection_id in self.store().await.connection_ids_for_user(requester_id) { - self.peer.send(connection_id, update.clone())?; - } +async fn update_project( + request: proto::UpdateProject, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let (room, guest_connection_ids) = &*session + .db() + .await + .update_project(project_id, session.connection_id, &request.worktrees) + .await?; + broadcast( + session.connection_id, + guest_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + room_updated(&room, &session); + response.send(proto::Ack {})?; - // Update incoming contact requests of responder - let mut update = proto::UpdateContacts::default(); - update - .incoming_requests - .push(proto::IncomingContactRequest { - requester_id: requester_id.to_proto(), - should_notify: true, - }); - for connection_id in self.store().await.connection_ids_for_user(responder_id) { - self.peer.send(connection_id, update.clone())?; - } + Ok(()) +} - response.send(proto::Ack {})?; - Ok(()) - } +async fn update_worktree( + request: proto::UpdateWorktree, + response: Response, + session: Session, +) -> Result<()> { + let guest_connection_ids = session + .db() + .await + .update_worktree(&request, session.connection_id) + .await?; + + broadcast( + session.connection_id, + guest_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + response.send(proto::Ack {})?; + Ok(()) +} - async fn respond_to_contact_request( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let responder_id = self - .store() +async fn 
update_diagnostic_summary( + message: proto::UpdateDiagnosticSummary, + session: Session, +) -> Result<()> { + let guest_connection_ids = session + .db() + .await + .update_diagnostic_summary(&message, session.connection_id) + .await?; + + broadcast( + session.connection_id, + guest_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, message.clone()) + }, + ); + + Ok(()) +} + +async fn start_language_server( + request: proto::StartLanguageServer, + session: Session, +) -> Result<()> { + let guest_connection_ids = session + .db() + .await + .start_language_server(&request, session.connection_id) + .await?; + + broadcast( + session.connection_id, + guest_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn update_language_server( + request: proto::UpdateLanguageServer, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + session.connection_id, + project_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn forward_project_request( + request: T, + response: Response, + session: Session, +) -> Result<()> +where + T: EntityMessage + RequestMessage, +{ + let project_id = ProjectId::from_proto(request.remote_entity_id()); + let host_connection_id = { + let collaborators = session + .db() .await - .user_id_for_connection(request.sender_id)?; - let requester_id = UserId::from_proto(request.payload.requester_id); - if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 { - self.app_state - .db - .dismiss_contact_notification(responder_id, requester_id) - .await?; - } else { - let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32; - self.app_state - .db - .respond_to_contact_request(responder_id, requester_id, accept) - .await?; - - let store = self.store().await; - // Update responder with new contact - let mut update = proto::UpdateContacts::default(); - if accept { - update - .contacts - .push(store.contact_for_user(requester_id, false)); - } - update - .remove_incoming_requests - .push(requester_id.to_proto()); - for connection_id in store.connection_ids_for_user(responder_id) { - self.peer.send(connection_id, update.clone())?; - } + .project_collaborators(project_id, session.connection_id) + .await?; + ConnectionId( + collaborators + .iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))? 
+ .connection_id as u32, + ) + }; - // Update requester with new contact - let mut update = proto::UpdateContacts::default(); - if accept { - update - .contacts - .push(store.contact_for_user(responder_id, true)); - } - update - .remove_outgoing_requests - .push(responder_id.to_proto()); - for connection_id in store.connection_ids_for_user(requester_id) { - self.peer.send(connection_id, update.clone())?; - } - } + let payload = session + .peer + .forward_request(session.connection_id, host_connection_id, request) + .await?; - response.send(proto::Ack {})?; - Ok(()) - } + response.send(payload)?; + Ok(()) +} - async fn remove_contact( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let requester_id = self - .store() +async fn save_buffer( + request: proto::SaveBuffer, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let host_connection_id = { + let collaborators = session + .db() .await - .user_id_for_connection(request.sender_id)?; - let responder_id = UserId::from_proto(request.payload.user_id); - self.app_state - .db - .remove_contact(requester_id, responder_id) + .project_collaborators(project_id, session.connection_id) .await?; + let host = collaborators + .iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))?; + ConnectionId(host.connection_id as u32) + }; + let response_payload = session + .peer + .forward_request(session.connection_id, host_connection_id, request.clone()) + .await?; + + let mut collaborators = session + .db() + .await + .project_collaborators(project_id, session.connection_id) + .await?; + collaborators + .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32); + let project_connection_ids = collaborators + .iter() + .map(|collaborator| ConnectionId(collaborator.connection_id as u32)); + broadcast(host_connection_id, project_connection_ids, |conn_id| { + session + .peer + .forward_send(host_connection_id, conn_id, response_payload.clone()) + }); + response.send(response_payload)?; + Ok(()) +} - // Update outgoing contact requests of requester - let mut update = proto::UpdateContacts::default(); - update - .remove_outgoing_requests - .push(responder_id.to_proto()); - for connection_id in self.store().await.connection_ids_for_user(requester_id) { - self.peer.send(connection_id, update.clone())?; - } +async fn create_buffer_for_peer( + request: proto::CreateBufferForPeer, + session: Session, +) -> Result<()> { + session.peer.forward_send( + session.connection_id, + ConnectionId(request.peer_id), + request, + )?; + Ok(()) +} - // Update incoming contact requests of responder - let mut update = proto::UpdateContacts::default(); - update - .remove_incoming_requests - .push(requester_id.to_proto()); - for connection_id in self.store().await.connection_ids_for_user(responder_id) { - self.peer.send(connection_id, update.clone())?; - } +async fn update_buffer( + request: proto::UpdateBuffer, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + + broadcast( + session.connection_id, + project_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + response.send(proto::Ack {})?; + Ok(()) +} - 
response.send(proto::Ack {})?; - Ok(()) - } +async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + + broadcast( + session.connection_id, + project_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} - async fn update_diff_base( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_id, - )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); - Ok(()) - } +async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + session.connection_id, + project_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} - async fn get_private_user_info( - self: Arc, - request: TypedEnvelope, - response: Response, - ) -> Result<()> { - let user_id = self - .store() +async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + session.connection_id, + project_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn follow( + request: proto::Follow, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let leader_id = ConnectionId(request.leader_id); + let follower_id = session.connection_id; + { + let project_connection_ids = session + .db() .await - .user_id_for_connection(request.sender_id)?; - let metrics_id = self.app_state.db.get_user_metrics_id(user_id).await?; - let user = self - .app_state - .db - .get_user_by_id(user_id) - .await? 
- .ok_or_else(|| anyhow!("user not found"))?; - response.send(proto::GetPrivateUserInfoResponse { - metrics_id, - staff: user.admin, - })?; - Ok(()) - } + .project_connection_ids(project_id, session.connection_id) + .await?; - pub(crate) async fn store(&self) -> StoreGuard<'_> { - #[cfg(test)] - tokio::task::yield_now().await; - let guard = self.store.lock().await; - #[cfg(test)] - tokio::task::yield_now().await; - StoreGuard { - guard, - _not_send: PhantomData, + if !project_connection_ids.contains(&leader_id) { + Err(anyhow!("no such peer"))?; } } - pub async fn snapshot<'a>(self: &'a Arc) -> ServerSnapshot<'a> { - ServerSnapshot { - store: self.store().await, - peer: &self.peer, - } - } + let mut response_payload = session + .peer + .forward_request(session.connection_id, leader_id, request) + .await?; + response_payload + .views + .retain(|view| view.leader_id != Some(follower_id.0)); + response.send(response_payload)?; + Ok(()) } -impl<'a> Deref for StoreGuard<'a> { - type Target = Store; - - fn deref(&self) -> &Self::Target { - &*self.guard +async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let leader_id = ConnectionId(request.leader_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + if !project_connection_ids.contains(&leader_id) { + Err(anyhow!("no such peer"))?; } + session + .peer + .forward_send(session.connection_id, leader_id, request)?; + Ok(()) } -impl<'a> DerefMut for StoreGuard<'a> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut *self.guard +async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db + .lock() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + + let leader_id = request.variant.as_ref().and_then(|variant| match variant { + proto::update_followers::Variant::CreateView(payload) => payload.leader_id, + proto::update_followers::Variant::UpdateView(payload) => payload.leader_id, + proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id, + }); + for follower_id in &request.follower_ids { + let follower_id = ConnectionId(*follower_id); + if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { + session + .peer + .forward_send(session.connection_id, follower_id, request.clone())?; + } } + Ok(()) } -impl<'a> Drop for StoreGuard<'a> { - fn drop(&mut self) { - #[cfg(test)] - self.check_invariants(); - } +async fn get_users( + request: proto::GetUsers, + response: Response, + session: Session, +) -> Result<()> { + let user_ids = request + .user_ids + .into_iter() + .map(UserId::from_proto) + .collect(); + let users = session + .db() + .await + .get_users_by_ids(user_ids) + .await? 
+ .into_iter() + .map(|user| proto::User { + id: user.id.to_proto(), + avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), + github_login: user.github_login, + }) + .collect(); + response.send(proto::UsersResponse { users })?; + Ok(()) } -impl Executor for RealExecutor { - type Sleep = Sleep; +async fn fuzzy_search_users( + request: proto::FuzzySearchUsers, + response: Response, + session: Session, +) -> Result<()> { + let query = request.query; + let users = match query.len() { + 0 => vec![], + 1 | 2 => session + .db() + .await + .get_user_by_github_account(&query, None) + .await? + .into_iter() + .collect(), + _ => session.db().await.fuzzy_search_users(&query, 10).await?, + }; + let users = users + .into_iter() + .filter(|user| user.id != session.user_id) + .map(|user| proto::User { + id: user.id.to_proto(), + avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), + github_login: user.github_login, + }) + .collect(); + response.send(proto::UsersResponse { users })?; + Ok(()) +} - fn spawn_detached>(&self, future: F) { - tokio::task::spawn(future); +async fn request_contact( + request: proto::RequestContact, + response: Response, + session: Session, +) -> Result<()> { + let requester_id = session.user_id; + let responder_id = UserId::from_proto(request.responder_id); + if requester_id == responder_id { + return Err(anyhow!("cannot add yourself as a contact"))?; } - fn sleep(&self, duration: Duration) -> Self::Sleep { - tokio::time::sleep(duration) + session + .db() + .await + .send_contact_request(requester_id, responder_id) + .await?; + + // Update outgoing contact requests of requester + let mut update = proto::UpdateContacts::default(); + update.outgoing_requests.push(responder_id.to_proto()); + for connection_id in session + .connection_pool() + .await + .user_connection_ids(requester_id) + { + session.peer.send(connection_id, update.clone())?; } -} -fn broadcast( - sender_id: ConnectionId, - receiver_ids: impl IntoIterator, - mut f: F, -) where - F: FnMut(ConnectionId) -> anyhow::Result<()>, -{ - for receiver_id in receiver_ids { - if receiver_id != sender_id { - f(receiver_id).trace_err(); - } + // Update incoming contact requests of responder + let mut update = proto::UpdateContacts::default(); + update + .incoming_requests + .push(proto::IncomingContactRequest { + requester_id: requester_id.to_proto(), + should_notify: true, + }); + for connection_id in session + .connection_pool() + .await + .user_connection_ids(responder_id) + { + session.peer.send(connection_id, update.clone())?; } -} -lazy_static! 
{ - static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version"); + response.send(proto::Ack {})?; + Ok(()) } -pub struct ProtocolVersion(u32); +async fn respond_to_contact_request( + request: proto::RespondToContactRequest, + response: Response<proto::RespondToContactRequest>, + session: Session, +) -> Result<()> { + let responder_id = session.user_id; + let requester_id = UserId::from_proto(request.requester_id); + let db = session.db().await; + if request.response == proto::ContactRequestResponse::Dismiss as i32 { + db.dismiss_contact_notification(responder_id, requester_id) + .await?; + } else { + let accept = request.response == proto::ContactRequestResponse::Accept as i32; -impl Header for ProtocolVersion { - fn name() -> &'static HeaderName { - &ZED_PROTOCOL_VERSION - } + db.respond_to_contact_request(responder_id, requester_id, accept) + .await?; + let busy = db.is_user_busy(requester_id).await?; - fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error> - where - Self: Sized, - I: Iterator<Item = &'i HeaderValue>, - { - let version = values - .next() - .ok_or_else(axum::headers::Error::invalid)? - .to_str() - .map_err(|_| axum::headers::Error::invalid())? - .parse() - .map_err(|_| axum::headers::Error::invalid())?; - Ok(Self(version)) - } + let pool = session.connection_pool().await; + // Update responder with new contact + let mut update = proto::UpdateContacts::default(); + if accept { + update + .contacts + .push(contact_for_user(requester_id, false, busy, &pool)); + } + update + .remove_incoming_requests + .push(requester_id.to_proto()); + for connection_id in pool.user_connection_ids(responder_id) { + session.peer.send(connection_id, update.clone())?; + } - fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) { - values.extend([self.0.to_string().parse().unwrap()]); + // Update requester with new contact + let mut update = proto::UpdateContacts::default(); + if accept { + update + .contacts + .push(contact_for_user(responder_id, true, busy, &pool)); + } + update + .remove_outgoing_requests + .push(responder_id.to_proto()); + for connection_id in pool.user_connection_ids(requester_id) { + session.peer.send(connection_id, update.clone())?; + } } -} -pub fn routes(server: Arc<Server>) -> Router { - Router::new() - .route("/rpc", get(handle_websocket_request)) - .layer( - ServiceBuilder::new() - .layer(Extension(server.app_state.clone())) - .layer(middleware::from_fn(auth::validate_header)), - ) - .route("/metrics", get(handle_metrics)) - .layer(Extension(server)) + response.send(proto::Ack {})?; + Ok(()) } -pub async fn handle_websocket_request( - TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>, - ConnectInfo(socket_address): ConnectInfo<SocketAddr>, - Extension(server): Extension<Arc<Server>>, - Extension(user): Extension<User>, - ws: WebSocketUpgrade, -) -> axum::response::Response { - if protocol_version != rpc::PROTOCOL_VERSION { - return ( - StatusCode::UPGRADE_REQUIRED, - "client must be upgraded".to_string(), - ) - .into_response(); +async fn remove_contact( + request: proto::RemoveContact, + response: Response<proto::RemoveContact>, + session: Session, +) -> Result<()> { + let requester_id = session.user_id; + let responder_id = UserId::from_proto(request.user_id); + let db = session.db().await; + db.remove_contact(requester_id, responder_id).await?; + + let pool = session.connection_pool().await; + // Update outgoing contact requests of requester + let mut update = proto::UpdateContacts::default(); + update + .remove_outgoing_requests + .push(responder_id.to_proto()); + for connection_id in pool.user_connection_ids(requester_id) { + session.peer.send(connection_id,
update.clone())?; } - let socket_address = socket_address.to_string(); - ws.on_upgrade(move |socket| { - use util::ResultExt; - let socket = socket - .map_ok(to_tungstenite_message) - .err_into() - .with(|message| async move { Ok(to_axum_message(message)) }); - let connection = Connection::new(Box::pin(socket)); - async move { - server - .handle_connection(connection, socket_address, user, None, RealExecutor) - .await - .log_err(); - } - }) + + // Update incoming contact requests of responder + let mut update = proto::UpdateContacts::default(); + update + .remove_incoming_requests + .push(requester_id.to_proto()); + for connection_id in pool.user_connection_ids(responder_id) { + session.peer.send(connection_id, update.clone())?; + } + + response.send(proto::Ack {})?; + Ok(()) } -pub async fn handle_metrics(Extension(server): Extension>) -> axum::response::Response { - let metrics = server.store().await.metrics(); - METRIC_CONNECTIONS.set(metrics.connections as _); - METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _); +async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + session.connection_id, + project_connection_ids.iter().copied(), + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} - let encoder = prometheus::TextEncoder::new(); - let metric_families = prometheus::gather(); - match encoder.encode_to_string(&metric_families) { - Ok(string) => (StatusCode::OK, string).into_response(), - Err(error) => ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("failed to encode metrics {:?}", error), - ) - .into_response(), - } +async fn get_private_user_info( + _request: proto::GetPrivateUserInfo, + response: Response, + session: Session, +) -> Result<()> { + let metrics_id = session + .db() + .await + .get_user_metrics_id(session.user_id) + .await?; + let user = session + .db() + .await + .get_user_by_id(session.user_id) + .await? 
+ .ok_or_else(|| anyhow!("user not found"))?; + response.send(proto::GetPrivateUserInfoResponse { + metrics_id, + staff: user.admin, + })?; + Ok(()) } fn to_axum_message(message: TungsteniteMessage) -> AxumMessage { @@ -1800,6 +1755,185 @@ fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage { } } +fn build_initial_contacts_update( + contacts: Vec, + pool: &ConnectionPool, +) -> proto::UpdateContacts { + let mut update = proto::UpdateContacts::default(); + + for contact in contacts { + match contact { + db::Contact::Accepted { + user_id, + should_notify, + busy, + } => { + update + .contacts + .push(contact_for_user(user_id, should_notify, busy, &pool)); + } + db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()), + db::Contact::Incoming { + user_id, + should_notify, + } => update + .incoming_requests + .push(proto::IncomingContactRequest { + requester_id: user_id.to_proto(), + should_notify, + }), + } + } + + update +} + +fn contact_for_user( + user_id: UserId, + should_notify: bool, + busy: bool, + pool: &ConnectionPool, +) -> proto::Contact { + proto::Contact { + user_id: user_id.to_proto(), + online: pool.is_user_online(user_id), + busy, + should_notify, + } +} + +fn room_updated(room: &proto::Room, session: &Session) { + for participant in &room.participants { + session + .peer + .send( + ConnectionId(participant.peer_id), + proto::RoomUpdated { + room: Some(room.clone()), + }, + ) + .trace_err(); + } +} + +async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> { + let db = session.db().await; + let contacts = db.get_contacts(user_id).await?; + let busy = db.is_user_busy(user_id).await?; + + let pool = session.connection_pool().await; + let updated_contact = contact_for_user(user_id, false, busy, &pool); + for contact in contacts { + if let db::Contact::Accepted { + user_id: contact_user_id, + .. 
+ } = contact + { + for contact_conn_id in pool.user_connection_ids(contact_user_id) { + session + .peer + .send( + contact_conn_id, + proto::UpdateContacts { + contacts: vec![updated_contact.clone()], + remove_contacts: Default::default(), + incoming_requests: Default::default(), + remove_incoming_requests: Default::default(), + outgoing_requests: Default::default(), + remove_outgoing_requests: Default::default(), + }, + ) + .trace_err(); + } + } + } + Ok(()) +} + +async fn leave_room_for_session(session: &Session) -> Result<()> { + let mut contacts_to_update = HashSet::default(); + + let canceled_calls_to_user_ids; + let live_kit_room; + let delete_live_kit_room; + { + let mut left_room = session.db().await.leave_room(session.connection_id).await?; + contacts_to_update.insert(session.user_id); + + for project in left_room.left_projects.values() { + for connection_id in &project.connection_ids { + if project.host_user_id == session.user_id { + session + .peer + .send( + *connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + ) + .trace_err(); + } else { + session + .peer + .send( + *connection_id, + proto::RemoveProjectCollaborator { + project_id: project.id.to_proto(), + peer_id: session.connection_id.0, + }, + ) + .trace_err(); + } + } + + session + .peer + .send( + session.connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + ) + .trace_err(); + } + + room_updated(&left_room.room, &session); + canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids); + live_kit_room = mem::take(&mut left_room.room.live_kit_room); + delete_live_kit_room = left_room.room.participants.is_empty(); + } + + { + let pool = session.connection_pool().await; + for canceled_user_id in canceled_calls_to_user_ids { + for connection_id in pool.user_connection_ids(canceled_user_id) { + session + .peer + .send(connection_id, proto::CallCanceled {}) + .trace_err(); + } + contacts_to_update.insert(canceled_user_id); + } + } + + for contact_user_id in contacts_to_update { + update_user_contacts(contact_user_id, &session).await?; + } + + if let Some(live_kit) = session.live_kit_client.as_ref() { + live_kit + .remove_participant(live_kit_room.clone(), session.connection_id.to_string()) + .await + .trace_err(); + + if delete_live_kit_room { + live_kit.delete_room(live_kit_room).await.trace_err(); + } + } + + Ok(()) +} + pub trait ResultExt { type Ok; diff --git a/crates/collab/src/rpc/connection_pool.rs b/crates/collab/src/rpc/connection_pool.rs new file mode 100644 index 0000000000000000000000000000000000000000..ac7632f7da2ae6d4d6beb95aeb298d8e409f8d80 --- /dev/null +++ b/crates/collab/src/rpc/connection_pool.rs @@ -0,0 +1,93 @@ +use crate::db::UserId; +use anyhow::{anyhow, Result}; +use collections::{BTreeMap, HashSet}; +use rpc::ConnectionId; +use serde::Serialize; +use tracing::instrument; + +#[derive(Default, Serialize)] +pub struct ConnectionPool { + connections: BTreeMap, + connected_users: BTreeMap, +} + +#[derive(Default, Serialize)] +struct ConnectedUser { + connection_ids: HashSet, +} + +#[derive(Serialize)] +pub struct Connection { + pub user_id: UserId, + pub admin: bool, +} + +impl ConnectionPool { + #[instrument(skip(self))] + pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) { + self.connections + .insert(connection_id, Connection { user_id, admin }); + let connected_user = self.connected_users.entry(user_id).or_default(); + connected_user.connection_ids.insert(connection_id); + } + 
+ #[instrument(skip(self))] + pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> { + let connection = self + .connections + .get_mut(&connection_id) + .ok_or_else(|| anyhow!("no such connection"))?; + + let user_id = connection.user_id; + let connected_user = self.connected_users.get_mut(&user_id).unwrap(); + connected_user.connection_ids.remove(&connection_id); + if connected_user.connection_ids.is_empty() { + self.connected_users.remove(&user_id); + } + self.connections.remove(&connection_id).unwrap(); + Ok(()) + } + + pub fn connections(&self) -> impl Iterator<Item = &Connection> { + self.connections.values() + } + + pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator<Item = ConnectionId> + '_ { + self.connected_users + .get(&user_id) + .into_iter() + .map(|state| &state.connection_ids) + .flatten() + .copied() + } + + pub fn is_user_online(&self, user_id: UserId) -> bool { + !self + .connected_users + .get(&user_id) + .unwrap_or(&Default::default()) + .connection_ids + .is_empty() + } + + #[cfg(test)] + pub fn check_invariants(&self) { + for (connection_id, connection) in &self.connections { + assert!(self + .connected_users + .get(&connection.user_id) + .unwrap() + .connection_ids + .contains(connection_id)); + } + + for (user_id, state) in &self.connected_users { + for connection_id in &state.connection_ids { + assert_eq!( + self.connections.get(connection_id).unwrap().user_id, + *user_id + ); + } + } + } +} diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs deleted file mode 100644 index 0cd2818eacdf05494a3b7470dd623c2b4cdac296..0000000000000000000000000000000000000000 --- a/crates/collab/src/rpc/store.rs +++ /dev/null @@ -1,1182 +0,0 @@ -use crate::db::{self, ProjectId, UserId}; -use anyhow::{anyhow, Result}; -use collections::{btree_map, BTreeMap, BTreeSet, HashMap, HashSet}; -use nanoid::nanoid; -use rpc::{proto, ConnectionId}; -use serde::Serialize; -use std::{borrow::Cow, mem, path::PathBuf, str}; -use tracing::instrument; -use util::post_inc; - -pub type RoomId = u64; - -#[derive(Default, Serialize)] -pub struct Store { - connections: BTreeMap<ConnectionId, ConnectionState>, - connected_users: BTreeMap<UserId, ConnectedUser>, - next_room_id: RoomId, - rooms: BTreeMap<RoomId, proto::Room>, - projects: BTreeMap<ProjectId, Project>, -} - -#[derive(Default, Serialize)] -struct ConnectedUser { - connection_ids: HashSet<ConnectionId>, - active_call: Option<Call>, -} - -#[derive(Serialize)] -struct ConnectionState { - user_id: UserId, - admin: bool, - projects: BTreeSet<ProjectId>, -} - -#[derive(Copy, Clone, Eq, PartialEq, Serialize)] -pub struct Call { - pub caller_user_id: UserId, - pub room_id: RoomId, - pub connection_id: Option<ConnectionId>, - pub initial_project_id: Option<ProjectId>, -} - -#[derive(Serialize)] -pub struct Project { - pub id: ProjectId, - pub room_id: RoomId, - pub host_connection_id: ConnectionId, - pub host: Collaborator, - pub guests: HashMap<ConnectionId, Collaborator>, - pub active_replica_ids: HashSet<ReplicaId>, - pub worktrees: BTreeMap<u64, Worktree>, - pub language_servers: Vec<proto::LanguageServer>, -} - -#[derive(Serialize)] -pub struct Collaborator { - pub replica_id: ReplicaId, - pub user_id: UserId, - pub admin: bool, -} - -#[derive(Default, Serialize)] -pub struct Worktree { - pub abs_path: Vec<u8>, - pub root_name: String, - pub visible: bool, - #[serde(skip)] - pub entries: BTreeMap<u64, proto::Entry>, - #[serde(skip)] - pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>, - pub scan_id: u64, - pub is_complete: bool, -} - -pub type ReplicaId = u16; - -#[derive(Default)] -pub struct RemovedConnectionState<'a> { - pub user_id: UserId, - pub hosted_projects: Vec<Project>, - pub guest_projects: Vec<LeftProject>, - pub contact_ids: HashSet<UserId>, - pub room: Option<Cow<'a, proto::Room>>, - pub canceled_call_connection_ids: Vec<ConnectionId>,
-} - -pub struct LeftProject { - pub id: ProjectId, - pub host_user_id: UserId, - pub host_connection_id: ConnectionId, - pub connection_ids: Vec, - pub remove_collaborator: bool, -} - -pub struct LeftRoom<'a> { - pub room: Cow<'a, proto::Room>, - pub unshared_projects: Vec, - pub left_projects: Vec, - pub canceled_call_connection_ids: Vec, -} - -#[derive(Copy, Clone)] -pub struct Metrics { - pub connections: usize, - pub shared_projects: usize, -} - -impl Store { - pub fn metrics(&self) -> Metrics { - let connections = self.connections.values().filter(|c| !c.admin).count(); - let mut shared_projects = 0; - for project in self.projects.values() { - if let Some(connection) = self.connections.get(&project.host_connection_id) { - if !connection.admin { - shared_projects += 1; - } - } - } - - Metrics { - connections, - shared_projects, - } - } - - #[instrument(skip(self))] - pub fn add_connection( - &mut self, - connection_id: ConnectionId, - user_id: UserId, - admin: bool, - ) -> Option { - self.connections.insert( - connection_id, - ConnectionState { - user_id, - admin, - projects: Default::default(), - }, - ); - let connected_user = self.connected_users.entry(user_id).or_default(); - connected_user.connection_ids.insert(connection_id); - if let Some(active_call) = connected_user.active_call { - if active_call.connection_id.is_some() { - None - } else { - let room = self.room(active_call.room_id)?; - Some(proto::IncomingCall { - room_id: active_call.room_id, - caller_user_id: active_call.caller_user_id.to_proto(), - participant_user_ids: room - .participants - .iter() - .map(|participant| participant.user_id) - .collect(), - initial_project: active_call - .initial_project_id - .and_then(|id| Self::build_participant_project(id, &self.projects)), - }) - } - } else { - None - } - } - - #[instrument(skip(self))] - pub fn remove_connection( - &mut self, - connection_id: ConnectionId, - ) -> Result { - let connection = self - .connections - .get_mut(&connection_id) - .ok_or_else(|| anyhow!("no such connection"))?; - - let user_id = connection.user_id; - - let mut result = RemovedConnectionState { - user_id, - ..Default::default() - }; - - let connected_user = self.connected_users.get(&user_id).unwrap(); - if let Some(active_call) = connected_user.active_call.as_ref() { - let room_id = active_call.room_id; - if active_call.connection_id == Some(connection_id) { - let left_room = self.leave_room(room_id, connection_id)?; - result.hosted_projects = left_room.unshared_projects; - result.guest_projects = left_room.left_projects; - result.room = Some(Cow::Owned(left_room.room.into_owned())); - result.canceled_call_connection_ids = left_room.canceled_call_connection_ids; - } else if connected_user.connection_ids.len() == 1 { - let (room, _) = self.decline_call(room_id, connection_id)?; - result.room = Some(Cow::Owned(room.clone())); - } - } - - let connected_user = self.connected_users.get_mut(&user_id).unwrap(); - connected_user.connection_ids.remove(&connection_id); - if connected_user.connection_ids.is_empty() { - self.connected_users.remove(&user_id); - } - self.connections.remove(&connection_id).unwrap(); - - Ok(result) - } - - pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result { - Ok(self - .connections - .get(&connection_id) - .ok_or_else(|| anyhow!("unknown connection"))? 
- .user_id) - } - - pub fn connection_ids_for_user( - &self, - user_id: UserId, - ) -> impl Iterator + '_ { - self.connected_users - .get(&user_id) - .into_iter() - .map(|state| &state.connection_ids) - .flatten() - .copied() - } - - pub fn is_user_online(&self, user_id: UserId) -> bool { - !self - .connected_users - .get(&user_id) - .unwrap_or(&Default::default()) - .connection_ids - .is_empty() - } - - fn is_user_busy(&self, user_id: UserId) -> bool { - self.connected_users - .get(&user_id) - .unwrap_or(&Default::default()) - .active_call - .is_some() - } - - pub fn build_initial_contacts_update( - &self, - contacts: Vec, - ) -> proto::UpdateContacts { - let mut update = proto::UpdateContacts::default(); - - for contact in contacts { - match contact { - db::Contact::Accepted { - user_id, - should_notify, - } => { - update - .contacts - .push(self.contact_for_user(user_id, should_notify)); - } - db::Contact::Outgoing { user_id } => { - update.outgoing_requests.push(user_id.to_proto()) - } - db::Contact::Incoming { - user_id, - should_notify, - } => update - .incoming_requests - .push(proto::IncomingContactRequest { - requester_id: user_id.to_proto(), - should_notify, - }), - } - } - - update - } - - pub fn contact_for_user(&self, user_id: UserId, should_notify: bool) -> proto::Contact { - proto::Contact { - user_id: user_id.to_proto(), - online: self.is_user_online(user_id), - busy: self.is_user_busy(user_id), - should_notify, - } - } - - pub fn create_room(&mut self, creator_connection_id: ConnectionId) -> Result<&proto::Room> { - let connection = self - .connections - .get_mut(&creator_connection_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let connected_user = self - .connected_users - .get_mut(&connection.user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - anyhow::ensure!( - connected_user.active_call.is_none(), - "can't create a room with an active call" - ); - - let room_id = post_inc(&mut self.next_room_id); - let room = proto::Room { - id: room_id, - participants: vec![proto::Participant { - user_id: connection.user_id.to_proto(), - peer_id: creator_connection_id.0, - projects: Default::default(), - location: Some(proto::ParticipantLocation { - variant: Some(proto::participant_location::Variant::External( - proto::participant_location::External {}, - )), - }), - }], - pending_participant_user_ids: Default::default(), - live_kit_room: nanoid!(30), - }; - - self.rooms.insert(room_id, room); - connected_user.active_call = Some(Call { - caller_user_id: connection.user_id, - room_id, - connection_id: Some(creator_connection_id), - initial_project_id: None, - }); - Ok(self.rooms.get(&room_id).unwrap()) - } - - pub fn join_room( - &mut self, - room_id: RoomId, - connection_id: ConnectionId, - ) -> Result<(&proto::Room, Vec)> { - let connection = self - .connections - .get_mut(&connection_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let user_id = connection.user_id; - let recipient_connection_ids = self.connection_ids_for_user(user_id).collect::>(); - - let connected_user = self - .connected_users - .get_mut(&user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let active_call = connected_user - .active_call - .as_mut() - .ok_or_else(|| anyhow!("not being called"))?; - anyhow::ensure!( - active_call.room_id == room_id && active_call.connection_id.is_none(), - "not being called on this room" - ); - - let room = self - .rooms - .get_mut(&room_id) - .ok_or_else(|| anyhow!("no such room"))?; - anyhow::ensure!( - room.pending_participant_user_ids - 
.contains(&user_id.to_proto()), - anyhow!("no such room") - ); - room.pending_participant_user_ids - .retain(|pending| *pending != user_id.to_proto()); - room.participants.push(proto::Participant { - user_id: user_id.to_proto(), - peer_id: connection_id.0, - projects: Default::default(), - location: Some(proto::ParticipantLocation { - variant: Some(proto::participant_location::Variant::External( - proto::participant_location::External {}, - )), - }), - }); - active_call.connection_id = Some(connection_id); - - Ok((room, recipient_connection_ids)) - } - - pub fn leave_room(&mut self, room_id: RoomId, connection_id: ConnectionId) -> Result { - let connection = self - .connections - .get_mut(&connection_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let user_id = connection.user_id; - - let connected_user = self - .connected_users - .get(&user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - anyhow::ensure!( - connected_user - .active_call - .map_or(false, |call| call.room_id == room_id - && call.connection_id == Some(connection_id)), - "cannot leave a room before joining it" - ); - - // Given that users can only join one room at a time, we can safely unshare - // and leave all projects associated with the connection. - let mut unshared_projects = Vec::new(); - let mut left_projects = Vec::new(); - for project_id in connection.projects.clone() { - if let Ok((_, project)) = self.unshare_project(project_id, connection_id) { - unshared_projects.push(project); - } else if let Ok(project) = self.leave_project(project_id, connection_id) { - left_projects.push(project); - } - } - self.connected_users.get_mut(&user_id).unwrap().active_call = None; - - let room = self - .rooms - .get_mut(&room_id) - .ok_or_else(|| anyhow!("no such room"))?; - room.participants - .retain(|participant| participant.peer_id != connection_id.0); - - let mut canceled_call_connection_ids = Vec::new(); - room.pending_participant_user_ids - .retain(|pending_participant_user_id| { - if let Some(connected_user) = self - .connected_users - .get_mut(&UserId::from_proto(*pending_participant_user_id)) - { - if let Some(call) = connected_user.active_call.as_ref() { - if call.caller_user_id == user_id { - connected_user.active_call.take(); - canceled_call_connection_ids - .extend(connected_user.connection_ids.iter().copied()); - false - } else { - true - } - } else { - true - } - } else { - true - } - }); - - let room = if room.participants.is_empty() { - Cow::Owned(self.rooms.remove(&room_id).unwrap()) - } else { - Cow::Borrowed(self.rooms.get(&room_id).unwrap()) - }; - - Ok(LeftRoom { - room, - unshared_projects, - left_projects, - canceled_call_connection_ids, - }) - } - - pub fn room(&self, room_id: RoomId) -> Option<&proto::Room> { - self.rooms.get(&room_id) - } - - pub fn rooms(&self) -> &BTreeMap { - &self.rooms - } - - pub fn call( - &mut self, - room_id: RoomId, - recipient_user_id: UserId, - initial_project_id: Option, - from_connection_id: ConnectionId, - ) -> Result<(&proto::Room, Vec, proto::IncomingCall)> { - let caller_user_id = self.user_id_for_connection(from_connection_id)?; - - let recipient_connection_ids = self - .connection_ids_for_user(recipient_user_id) - .collect::>(); - let mut recipient = self - .connected_users - .get_mut(&recipient_user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - anyhow::ensure!( - recipient.active_call.is_none(), - "recipient is already on another call" - ); - - let room = self - .rooms - .get_mut(&room_id) - .ok_or_else(|| anyhow!("no such room"))?; - 
anyhow::ensure!( - room.participants - .iter() - .any(|participant| participant.peer_id == from_connection_id.0), - "no such room" - ); - anyhow::ensure!( - room.pending_participant_user_ids - .iter() - .all(|user_id| UserId::from_proto(*user_id) != recipient_user_id), - "cannot call the same user more than once" - ); - room.pending_participant_user_ids - .push(recipient_user_id.to_proto()); - - if let Some(initial_project_id) = initial_project_id { - let project = self - .projects - .get(&initial_project_id) - .ok_or_else(|| anyhow!("no such project"))?; - anyhow::ensure!(project.room_id == room_id, "no such project"); - } - - recipient.active_call = Some(Call { - caller_user_id, - room_id, - connection_id: None, - initial_project_id, - }); - - Ok(( - room, - recipient_connection_ids, - proto::IncomingCall { - room_id, - caller_user_id: caller_user_id.to_proto(), - participant_user_ids: room - .participants - .iter() - .map(|participant| participant.user_id) - .collect(), - initial_project: initial_project_id - .and_then(|id| Self::build_participant_project(id, &self.projects)), - }, - )) - } - - pub fn call_failed(&mut self, room_id: RoomId, to_user_id: UserId) -> Result<&proto::Room> { - let mut recipient = self - .connected_users - .get_mut(&to_user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - anyhow::ensure!(recipient - .active_call - .map_or(false, |call| call.room_id == room_id - && call.connection_id.is_none())); - recipient.active_call = None; - let room = self - .rooms - .get_mut(&room_id) - .ok_or_else(|| anyhow!("no such room"))?; - room.pending_participant_user_ids - .retain(|user_id| UserId::from_proto(*user_id) != to_user_id); - Ok(room) - } - - pub fn cancel_call( - &mut self, - room_id: RoomId, - recipient_user_id: UserId, - canceller_connection_id: ConnectionId, - ) -> Result<(&proto::Room, HashSet)> { - let canceller_user_id = self.user_id_for_connection(canceller_connection_id)?; - let canceller = self - .connected_users - .get(&canceller_user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let recipient = self - .connected_users - .get(&recipient_user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let canceller_active_call = canceller - .active_call - .as_ref() - .ok_or_else(|| anyhow!("no active call"))?; - let recipient_active_call = recipient - .active_call - .as_ref() - .ok_or_else(|| anyhow!("no active call for recipient"))?; - - anyhow::ensure!( - canceller_active_call.room_id == room_id, - "users are on different calls" - ); - anyhow::ensure!( - recipient_active_call.room_id == room_id, - "users are on different calls" - ); - anyhow::ensure!( - recipient_active_call.connection_id.is_none(), - "recipient has already answered" - ); - let room_id = recipient_active_call.room_id; - let room = self - .rooms - .get_mut(&room_id) - .ok_or_else(|| anyhow!("no such room"))?; - room.pending_participant_user_ids - .retain(|user_id| UserId::from_proto(*user_id) != recipient_user_id); - - let recipient = self.connected_users.get_mut(&recipient_user_id).unwrap(); - recipient.active_call.take(); - - Ok((room, recipient.connection_ids.clone())) - } - - pub fn decline_call( - &mut self, - room_id: RoomId, - recipient_connection_id: ConnectionId, - ) -> Result<(&proto::Room, Vec)> { - let recipient_user_id = self.user_id_for_connection(recipient_connection_id)?; - let recipient = self - .connected_users - .get_mut(&recipient_user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - if let Some(active_call) = recipient.active_call { - 
anyhow::ensure!(active_call.room_id == room_id, "no such room"); - anyhow::ensure!( - active_call.connection_id.is_none(), - "cannot decline a call after joining room" - ); - recipient.active_call.take(); - let recipient_connection_ids = self - .connection_ids_for_user(recipient_user_id) - .collect::>(); - let room = self - .rooms - .get_mut(&active_call.room_id) - .ok_or_else(|| anyhow!("no such room"))?; - room.pending_participant_user_ids - .retain(|user_id| UserId::from_proto(*user_id) != recipient_user_id); - Ok((room, recipient_connection_ids)) - } else { - Err(anyhow!("user is not being called")) - } - } - - pub fn update_participant_location( - &mut self, - room_id: RoomId, - location: proto::ParticipantLocation, - connection_id: ConnectionId, - ) -> Result<&proto::Room> { - let room = self - .rooms - .get_mut(&room_id) - .ok_or_else(|| anyhow!("no such room"))?; - if let Some(proto::participant_location::Variant::SharedProject(project)) = - location.variant.as_ref() - { - anyhow::ensure!( - room.participants - .iter() - .flat_map(|participant| &participant.projects) - .any(|participant_project| participant_project.id == project.id), - "no such project" - ); - } - - let participant = room - .participants - .iter_mut() - .find(|participant| participant.peer_id == connection_id.0) - .ok_or_else(|| anyhow!("no such room"))?; - participant.location = Some(location); - - Ok(room) - } - - pub fn share_project( - &mut self, - room_id: RoomId, - project_id: ProjectId, - worktrees: Vec, - host_connection_id: ConnectionId, - ) -> Result<&proto::Room> { - let connection = self - .connections - .get_mut(&host_connection_id) - .ok_or_else(|| anyhow!("no such connection"))?; - - let room = self - .rooms - .get_mut(&room_id) - .ok_or_else(|| anyhow!("no such room"))?; - let participant = room - .participants - .iter_mut() - .find(|participant| participant.peer_id == host_connection_id.0) - .ok_or_else(|| anyhow!("no such room"))?; - - connection.projects.insert(project_id); - self.projects.insert( - project_id, - Project { - id: project_id, - room_id, - host_connection_id, - host: Collaborator { - user_id: connection.user_id, - replica_id: 0, - admin: connection.admin, - }, - guests: Default::default(), - active_replica_ids: Default::default(), - worktrees: worktrees - .into_iter() - .map(|worktree| { - ( - worktree.id, - Worktree { - root_name: worktree.root_name, - visible: worktree.visible, - abs_path: worktree.abs_path.clone(), - entries: Default::default(), - diagnostic_summaries: Default::default(), - scan_id: Default::default(), - is_complete: Default::default(), - }, - ) - }) - .collect(), - language_servers: Default::default(), - }, - ); - - participant - .projects - .extend(Self::build_participant_project(project_id, &self.projects)); - - Ok(room) - } - - pub fn unshare_project( - &mut self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result<(&proto::Room, Project)> { - match self.projects.entry(project_id) { - btree_map::Entry::Occupied(e) => { - if e.get().host_connection_id == connection_id { - let project = e.remove(); - - if let Some(host_connection) = self.connections.get_mut(&connection_id) { - host_connection.projects.remove(&project_id); - } - - for guest_connection in project.guests.keys() { - if let Some(connection) = self.connections.get_mut(guest_connection) { - connection.projects.remove(&project_id); - } - } - - let room = self - .rooms - .get_mut(&project.room_id) - .ok_or_else(|| anyhow!("no such room"))?; - let participant = room - .participants - 
.iter_mut() - .find(|participant| participant.peer_id == connection_id.0) - .ok_or_else(|| anyhow!("no such room"))?; - participant - .projects - .retain(|project| project.id != project_id.to_proto()); - - Ok((room, project)) - } else { - Err(anyhow!("no such project"))? - } - } - btree_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?, - } - } - - pub fn update_project( - &mut self, - project_id: ProjectId, - worktrees: &[proto::WorktreeMetadata], - connection_id: ConnectionId, - ) -> Result<&proto::Room> { - let project = self - .projects - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - if project.host_connection_id == connection_id { - let mut old_worktrees = mem::take(&mut project.worktrees); - for worktree in worktrees { - if let Some(old_worktree) = old_worktrees.remove(&worktree.id) { - project.worktrees.insert(worktree.id, old_worktree); - } else { - project.worktrees.insert( - worktree.id, - Worktree { - root_name: worktree.root_name.clone(), - visible: worktree.visible, - abs_path: worktree.abs_path.clone(), - entries: Default::default(), - diagnostic_summaries: Default::default(), - scan_id: Default::default(), - is_complete: false, - }, - ); - } - } - - let room = self - .rooms - .get_mut(&project.room_id) - .ok_or_else(|| anyhow!("no such room"))?; - let participant_project = room - .participants - .iter_mut() - .flat_map(|participant| &mut participant.projects) - .find(|project| project.id == project_id.to_proto()) - .ok_or_else(|| anyhow!("no such project"))?; - participant_project.worktree_root_names = worktrees - .iter() - .filter(|worktree| worktree.visible) - .map(|worktree| worktree.root_name.clone()) - .collect(); - - Ok(room) - } else { - Err(anyhow!("no such project"))? - } - } - - pub fn update_diagnostic_summary( - &mut self, - project_id: ProjectId, - worktree_id: u64, - connection_id: ConnectionId, - summary: proto::DiagnosticSummary, - ) -> Result> { - let project = self - .projects - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - if project.host_connection_id == connection_id { - let worktree = project - .worktrees - .get_mut(&worktree_id) - .ok_or_else(|| anyhow!("no such worktree"))?; - worktree - .diagnostic_summaries - .insert(summary.path.clone().into(), summary); - return Ok(project.connection_ids()); - } - - Err(anyhow!("no such worktree"))? - } - - pub fn start_language_server( - &mut self, - project_id: ProjectId, - connection_id: ConnectionId, - language_server: proto::LanguageServer, - ) -> Result> { - let project = self - .projects - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - if project.host_connection_id == connection_id { - project.language_servers.push(language_server); - return Ok(project.connection_ids()); - } - - Err(anyhow!("no such project"))? 
- } - - pub fn join_project( - &mut self, - requester_connection_id: ConnectionId, - project_id: ProjectId, - ) -> Result<(&Project, ReplicaId)> { - let connection = self - .connections - .get_mut(&requester_connection_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let user = self - .connected_users - .get(&connection.user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let active_call = user.active_call.ok_or_else(|| anyhow!("no such project"))?; - anyhow::ensure!( - active_call.connection_id == Some(requester_connection_id), - "no such project" - ); - - let project = self - .projects - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - anyhow::ensure!(project.room_id == active_call.room_id, "no such project"); - - connection.projects.insert(project_id); - let mut replica_id = 1; - while project.active_replica_ids.contains(&replica_id) { - replica_id += 1; - } - project.active_replica_ids.insert(replica_id); - project.guests.insert( - requester_connection_id, - Collaborator { - replica_id, - user_id: connection.user_id, - admin: connection.admin, - }, - ); - - Ok((project, replica_id)) - } - - pub fn leave_project( - &mut self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result { - let project = self - .projects - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - - // If the connection leaving the project is a collaborator, remove it. - let remove_collaborator = if let Some(guest) = project.guests.remove(&connection_id) { - project.active_replica_ids.remove(&guest.replica_id); - true - } else { - false - }; - - if let Some(connection) = self.connections.get_mut(&connection_id) { - connection.projects.remove(&project_id); - } - - Ok(LeftProject { - id: project.id, - host_connection_id: project.host_connection_id, - host_user_id: project.host.user_id, - connection_ids: project.connection_ids(), - remove_collaborator, - }) - } - - #[allow(clippy::too_many_arguments)] - pub fn update_worktree( - &mut self, - connection_id: ConnectionId, - project_id: ProjectId, - worktree_id: u64, - worktree_root_name: &str, - worktree_abs_path: &[u8], - removed_entries: &[u64], - updated_entries: &[proto::Entry], - scan_id: u64, - is_last_update: bool, - ) -> Result> { - let project = self.write_project(project_id, connection_id)?; - - let connection_ids = project.connection_ids(); - let mut worktree = project.worktrees.entry(worktree_id).or_default(); - worktree.root_name = worktree_root_name.to_string(); - worktree.abs_path = worktree_abs_path.to_vec(); - - for entry_id in removed_entries { - worktree.entries.remove(entry_id); - } - - for entry in updated_entries { - worktree.entries.insert(entry.id, entry.clone()); - } - - worktree.scan_id = scan_id; - worktree.is_complete = is_last_update; - Ok(connection_ids) - } - - fn build_participant_project( - project_id: ProjectId, - projects: &BTreeMap, - ) -> Option { - Some(proto::ParticipantProject { - id: project_id.to_proto(), - worktree_root_names: projects - .get(&project_id)? - .worktrees - .values() - .filter(|worktree| worktree.visible) - .map(|worktree| worktree.root_name.clone()) - .collect(), - }) - } - - pub fn project_connection_ids( - &self, - project_id: ProjectId, - acting_connection_id: ConnectionId, - ) -> Result> { - Ok(self - .read_project(project_id, acting_connection_id)? 
- .connection_ids()) - } - - pub fn project(&self, project_id: ProjectId) -> Result<&Project> { - self.projects - .get(&project_id) - .ok_or_else(|| anyhow!("no such project")) - } - - pub fn read_project( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result<&Project> { - let project = self - .projects - .get(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - if project.host_connection_id == connection_id - || project.guests.contains_key(&connection_id) - { - Ok(project) - } else { - Err(anyhow!("no such project"))? - } - } - - fn write_project( - &mut self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result<&mut Project> { - let project = self - .projects - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - if project.host_connection_id == connection_id - || project.guests.contains_key(&connection_id) - { - Ok(project) - } else { - Err(anyhow!("no such project"))? - } - } - - #[cfg(test)] - pub fn check_invariants(&self) { - for (connection_id, connection) in &self.connections { - for project_id in &connection.projects { - let project = &self.projects.get(project_id).unwrap(); - if project.host_connection_id != *connection_id { - assert!(project.guests.contains_key(connection_id)); - } - - for (worktree_id, worktree) in project.worktrees.iter() { - let mut paths = HashMap::default(); - for entry in worktree.entries.values() { - let prev_entry = paths.insert(&entry.path, entry); - assert_eq!( - prev_entry, - None, - "worktree {:?}, duplicate path for entries {:?} and {:?}", - worktree_id, - prev_entry.unwrap(), - entry - ); - } - } - } - - assert!(self - .connected_users - .get(&connection.user_id) - .unwrap() - .connection_ids - .contains(connection_id)); - } - - for (user_id, state) in &self.connected_users { - for connection_id in &state.connection_ids { - assert_eq!( - self.connections.get(connection_id).unwrap().user_id, - *user_id - ); - } - - if let Some(active_call) = state.active_call.as_ref() { - if let Some(active_call_connection_id) = active_call.connection_id { - assert!( - state.connection_ids.contains(&active_call_connection_id), - "call is active on a dead connection" - ); - assert!( - state.connection_ids.contains(&active_call_connection_id), - "call is active on a dead connection" - ); - } - } - } - - for (room_id, room) in &self.rooms { - for pending_user_id in &room.pending_participant_user_ids { - assert!( - self.connected_users - .contains_key(&UserId::from_proto(*pending_user_id)), - "call is active on a user that has disconnected" - ); - } - - for participant in &room.participants { - assert!( - self.connections - .contains_key(&ConnectionId(participant.peer_id)), - "room {} contains participant {:?} that has disconnected", - room_id, - participant - ); - - for participant_project in &participant.projects { - let project = &self.projects[&ProjectId::from_proto(participant_project.id)]; - assert_eq!( - project.room_id, *room_id, - "project was shared on a different room" - ); - } - } - - assert!( - !room.pending_participant_user_ids.is_empty() || !room.participants.is_empty(), - "room can't be empty" - ); - } - - for (project_id, project) in &self.projects { - let host_connection = self.connections.get(&project.host_connection_id).unwrap(); - assert!(host_connection.projects.contains(project_id)); - - for guest_connection_id in project.guests.keys() { - let guest_connection = self.connections.get(guest_connection_id).unwrap(); - assert!(guest_connection.projects.contains(project_id)); - } - 
assert_eq!(project.active_replica_ids.len(), project.guests.len()); - assert_eq!( - project.active_replica_ids, - project - .guests - .values() - .map(|guest| guest.replica_id) - .collect::>(), - ); - - let room = &self.rooms[&project.room_id]; - let room_participant = room - .participants - .iter() - .find(|participant| participant.peer_id == project.host_connection_id.0) - .unwrap(); - assert!( - room_participant - .projects - .iter() - .any(|project| project.id == project_id.to_proto()), - "project was not shared in room" - ); - } - } -} - -impl Project { - pub fn guest_connection_ids(&self) -> Vec { - self.guests.keys().copied().collect() - } - - pub fn connection_ids(&self) -> Vec { - self.guests - .keys() - .copied() - .chain(Some(self.host_connection_id)) - .collect() - } -} diff --git a/crates/collab_ui/src/collab_ui.rs b/crates/collab_ui/src/collab_ui.rs index 964cec0f82a7ac0a74567d70261fe1b18f80a14a..abc62605f93cc34046e3663e41c0e561f14b3cb6 100644 --- a/crates/collab_ui/src/collab_ui.rs +++ b/crates/collab_ui/src/collab_ui.rs @@ -43,7 +43,6 @@ pub fn init(app_state: Arc, cx: &mut MutableAppContext) { project_id, app_state.client.clone(), app_state.user_store.clone(), - app_state.project_store.clone(), app_state.languages.clone(), app_state.fs.clone(), cx.clone(), diff --git a/crates/collab_ui/src/incoming_call_notification.rs b/crates/collab_ui/src/incoming_call_notification.rs index e5c4b27d7e4abee3138d107c827b590261115331..a51fb4891d20ee303d35992ef1c2dbc298dd1562 100644 --- a/crates/collab_ui/src/incoming_call_notification.rs +++ b/crates/collab_ui/src/incoming_call_notification.rs @@ -74,7 +74,7 @@ impl IncomingCallNotification { let active_call = ActiveCall::global(cx); if action.accept { let join = active_call.update(cx, |active_call, cx| active_call.accept_incoming(cx)); - let caller_user_id = self.call.caller.id; + let caller_user_id = self.call.calling_user.id; let initial_project_id = self.call.initial_project.as_ref().map(|project| project.id); cx.spawn_weak(|_, mut cx| async move { join.await?; @@ -105,7 +105,7 @@ impl IncomingCallNotification { .as_ref() .unwrap_or(&default_project); Flex::row() - .with_children(self.call.caller.avatar.clone().map(|avatar| { + .with_children(self.call.calling_user.avatar.clone().map(|avatar| { Image::new(avatar) .with_style(theme.caller_avatar) .aligned() @@ -115,7 +115,7 @@ impl IncomingCallNotification { Flex::column() .with_child( Label::new( - self.call.caller.github_login.clone(), + self.call.calling_user.github_login.clone(), theme.caller_username.text.clone(), ) .contained() diff --git a/crates/gpui/src/executor.rs b/crates/gpui/src/executor.rs index 0639445b0d1f2c35a65fd9777e3d96165bcbb702..876e48351d6e8e224df3dcefd2a953414b4436b9 100644 --- a/crates/gpui/src/executor.rs +++ b/crates/gpui/src/executor.rs @@ -66,21 +66,32 @@ struct DeterministicState { rng: rand::prelude::StdRng, seed: u64, scheduled_from_foreground: collections::HashMap>, - scheduled_from_background: Vec, + scheduled_from_background: Vec, forbid_parking: bool, block_on_ticks: std::ops::RangeInclusive, now: std::time::Instant, next_timer_id: usize, pending_timers: Vec<(usize, std::time::Instant, postage::barrier::Sender)>, waiting_backtrace: Option, + next_runnable_id: usize, + poll_history: Vec, + enable_runnable_backtraces: bool, + runnable_backtraces: collections::HashMap, } #[cfg(any(test, feature = "test-support"))] struct ForegroundRunnable { + id: usize, runnable: Runnable, main: bool, } +#[cfg(any(test, feature = "test-support"))] +struct 
BackgroundRunnable { + id: usize, + runnable: Runnable, +} + #[cfg(any(test, feature = "test-support"))] pub struct Deterministic { state: Arc>, @@ -117,11 +128,29 @@ impl Deterministic { next_timer_id: Default::default(), pending_timers: Default::default(), waiting_backtrace: None, + next_runnable_id: 0, + poll_history: Default::default(), + enable_runnable_backtraces: false, + runnable_backtraces: Default::default(), })), parker: Default::default(), }) } + pub fn runnable_history(&self) -> Vec { + self.state.lock().poll_history.clone() + } + + pub fn enable_runnable_backtrace(&self) { + self.state.lock().enable_runnable_backtraces = true; + } + + pub fn runnable_backtrace(&self, runnable_id: usize) -> backtrace::Backtrace { + let mut backtrace = self.state.lock().runnable_backtraces[&runnable_id].clone(); + backtrace.resolve(); + backtrace + } + pub fn build_background(self: &Arc) -> Arc { Arc::new(Background::Deterministic { executor: self.clone(), @@ -142,6 +171,17 @@ impl Deterministic { main: bool, ) -> AnyLocalTask { let state = self.state.clone(); + let id; + { + let mut state = state.lock(); + id = util::post_inc(&mut state.next_runnable_id); + if state.enable_runnable_backtraces { + state + .runnable_backtraces + .insert(id, backtrace::Backtrace::new_unresolved()); + } + } + let unparker = self.parker.lock().unparker(); let (runnable, task) = async_task::spawn_local(future, move |runnable| { let mut state = state.lock(); @@ -149,7 +189,7 @@ impl Deterministic { .scheduled_from_foreground .entry(cx_id) .or_default() - .push(ForegroundRunnable { runnable, main }); + .push(ForegroundRunnable { id, runnable, main }); unparker.unpark(); }); runnable.schedule(); @@ -158,10 +198,23 @@ impl Deterministic { fn spawn(&self, future: AnyFuture) -> AnyTask { let state = self.state.clone(); + let id; + { + let mut state = state.lock(); + id = util::post_inc(&mut state.next_runnable_id); + if state.enable_runnable_backtraces { + state + .runnable_backtraces + .insert(id, backtrace::Backtrace::new_unresolved()); + } + } + let unparker = self.parker.lock().unparker(); let (runnable, task) = async_task::spawn(future, move |runnable| { let mut state = state.lock(); - state.scheduled_from_background.push(runnable); + state + .scheduled_from_background + .push(BackgroundRunnable { id, runnable }); unparker.unpark(); }); runnable.schedule(); @@ -178,15 +231,27 @@ impl Deterministic { let woken = Arc::new(AtomicBool::new(false)); let state = self.state.clone(); + let id; + { + let mut state = state.lock(); + id = util::post_inc(&mut state.next_runnable_id); + if state.enable_runnable_backtraces { + state + .runnable_backtraces + .insert(id, backtrace::Backtrace::new_unresolved()); + } + } + let unparker = self.parker.lock().unparker(); let (runnable, mut main_task) = unsafe { async_task::spawn_unchecked(main_future, move |runnable| { - let mut state = state.lock(); + let state = &mut *state.lock(); state .scheduled_from_foreground .entry(cx_id) .or_default() .push(ForegroundRunnable { + id: util::post_inc(&mut state.next_runnable_id), runnable, main: true, }); @@ -248,9 +313,10 @@ impl Deterministic { if !state.scheduled_from_background.is_empty() && state.rng.gen() { let background_len = state.scheduled_from_background.len(); let ix = state.rng.gen_range(0..background_len); - let runnable = state.scheduled_from_background.remove(ix); + let background_runnable = state.scheduled_from_background.remove(ix); + state.poll_history.push(background_runnable.id); drop(state); - runnable.run(); + 
background_runnable.runnable.run(); } else if !state.scheduled_from_foreground.is_empty() { let available_cx_ids = state .scheduled_from_foreground @@ -266,6 +332,7 @@ impl Deterministic { if scheduled_from_cx.is_empty() { state.scheduled_from_foreground.remove(&cx_id_to_run); } + state.poll_history.push(foreground_runnable.id); drop(state); @@ -298,9 +365,10 @@ impl Deterministic { let runnable_count = state.scheduled_from_background.len(); let ix = state.rng.gen_range(0..=runnable_count); if ix < state.scheduled_from_background.len() { - let runnable = state.scheduled_from_background.remove(ix); + let background_runnable = state.scheduled_from_background.remove(ix); + state.poll_history.push(background_runnable.id); drop(state); - runnable.run(); + background_runnable.runnable.run(); } else { drop(state); if let Poll::Ready(result) = future.poll(&mut cx) { diff --git a/crates/gpui/src/test.rs b/crates/gpui/src/test.rs index e76b094c9a586951ea0bab55ff3e058553635535..aade1054a8d919590bded33c09dc4c458a6579e6 100644 --- a/crates/gpui/src/test.rs +++ b/crates/gpui/src/test.rs @@ -1,11 +1,13 @@ use crate::{ - elements::Empty, executor, platform, Element, ElementBox, Entity, FontCache, Handle, - LeakDetector, MutableAppContext, Platform, RenderContext, Subscription, TestAppContext, View, + elements::Empty, executor, platform, util::CwdBacktrace, Element, ElementBox, Entity, + FontCache, Handle, LeakDetector, MutableAppContext, Platform, RenderContext, Subscription, + TestAppContext, View, }; use futures::StreamExt; use parking_lot::Mutex; use smol::channel; use std::{ + fmt::Write, panic::{self, RefUnwindSafe}, rc::Rc, sync::{ @@ -29,13 +31,13 @@ pub fn run_test( mut num_iterations: u64, mut starting_seed: u64, max_retries: usize, + detect_nondeterminism: bool, test_fn: &mut (dyn RefUnwindSafe + Fn( &mut MutableAppContext, Rc, Arc, u64, - bool, )), fn_name: String, ) { @@ -60,16 +62,20 @@ pub fn run_test( let platform = Arc::new(platform::test::platform()); let font_system = platform.fonts(); let font_cache = Arc::new(FontCache::new(font_system)); + let mut prev_runnable_history: Option> = None; - loop { - let seed = atomic_seed.fetch_add(1, SeqCst); - let is_last_iteration = seed + 1 >= starting_seed + num_iterations; + for _ in 0..num_iterations { + let seed = atomic_seed.load(SeqCst); if is_randomized { dbg!(seed); } let deterministic = executor::Deterministic::new(seed); + if detect_nondeterminism { + deterministic.enable_runnable_backtrace(); + } + let leak_detector = Arc::new(Mutex::new(LeakDetector::default())); let mut cx = TestAppContext::new( foreground_platform.clone(), @@ -82,13 +88,7 @@ pub fn run_test( fn_name.clone(), ); cx.update(|cx| { - test_fn( - cx, - foreground_platform.clone(), - deterministic.clone(), - seed, - is_last_iteration, - ); + test_fn(cx, foreground_platform.clone(), deterministic.clone(), seed); }); cx.update(|cx| cx.remove_all_windows()); @@ -96,8 +96,64 @@ pub fn run_test( cx.update(|cx| cx.clear_globals()); leak_detector.lock().detect(); - if is_last_iteration { - break; + + if detect_nondeterminism { + let curr_runnable_history = deterministic.runnable_history(); + if let Some(prev_runnable_history) = prev_runnable_history { + let mut prev_entries = prev_runnable_history.iter().fuse(); + let mut curr_entries = curr_runnable_history.iter().fuse(); + + let mut nondeterministic = false; + let mut common_history_prefix = Vec::new(); + let mut prev_history_suffix = Vec::new(); + let mut curr_history_suffix = Vec::new(); + loop { + match (prev_entries.next(), 
curr_entries.next()) { + (None, None) => break, + (None, Some(curr_id)) => curr_history_suffix.push(*curr_id), + (Some(prev_id), None) => prev_history_suffix.push(*prev_id), + (Some(prev_id), Some(curr_id)) => { + if nondeterministic { + prev_history_suffix.push(*prev_id); + curr_history_suffix.push(*curr_id); + } else if prev_id == curr_id { + common_history_prefix.push(*curr_id); + } else { + nondeterministic = true; + prev_history_suffix.push(*prev_id); + curr_history_suffix.push(*curr_id); + } + } + } + } + + if nondeterministic { + let mut error = String::new(); + writeln!(&mut error, "Common prefix: {:?}", common_history_prefix) + .unwrap(); + writeln!(&mut error, "Previous suffix: {:?}", prev_history_suffix) + .unwrap(); + writeln!(&mut error, "Current suffix: {:?}", curr_history_suffix) + .unwrap(); + + let last_common_backtrace = common_history_prefix + .last() + .map(|runnable_id| deterministic.runnable_backtrace(*runnable_id)); + + writeln!( + &mut error, + "Last future that ran on both executions: {:?}", + last_common_backtrace.as_ref().map(CwdBacktrace) + ) + .unwrap(); + panic!("Detected non-determinism.\n{}", error); + } + } + prev_runnable_history = Some(curr_runnable_history); + } + + if !detect_nondeterminism { + atomic_seed.fetch_add(1, SeqCst); } } }); @@ -112,7 +168,7 @@ pub fn run_test( println!("retrying: attempt {}", retries); } else { if is_randomized { - eprintln!("failing seed: {}", atomic_seed.load(SeqCst) - 1); + eprintln!("failing seed: {}", atomic_seed.load(SeqCst)); } panic::resume_unwind(error); } diff --git a/crates/gpui_macros/src/gpui_macros.rs b/crates/gpui_macros/src/gpui_macros.rs index b43bedc64315ffc9f1b7b98845e064e6cc67555d..e28d1711d2ceda9cfcedf28310575e1cbc3cc620 100644 --- a/crates/gpui_macros/src/gpui_macros.rs +++ b/crates/gpui_macros/src/gpui_macros.rs @@ -14,6 +14,7 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { let mut max_retries = 0; let mut num_iterations = 1; let mut starting_seed = 0; + let mut detect_nondeterminism = false; for arg in args { match arg { @@ -26,6 +27,9 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { let key_name = meta.path.get_ident().map(|i| i.to_string()); let result = (|| { match key_name.as_deref() { + Some("detect_nondeterminism") => { + detect_nondeterminism = parse_bool(&meta.lit)? 
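// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch): the harness in gpui/src/test.rs above
// compares the poll history of consecutive iterations that ran with the same
// seed. It keeps the common prefix, and from the first mismatch on it collects
// the two divergent suffixes, then prints them together with the backtrace of
// the last runnable both runs agree on. Extracted into a plain function
// (hypothetical name, same logic as the diff):
fn split_histories(prev: &[usize], curr: &[usize]) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
    let mut common_prefix = Vec::new();
    let mut prev_suffix = Vec::new();
    let mut curr_suffix = Vec::new();
    let mut diverged = false;
    let mut prev_iter = prev.iter().fuse();
    let mut curr_iter = curr.iter().fuse();
    loop {
        match (prev_iter.next(), curr_iter.next()) {
            (None, None) => break,
            (Some(&p), None) => prev_suffix.push(p),
            (None, Some(&c)) => curr_suffix.push(c),
            (Some(&p), Some(&c)) => {
                if !diverged && p == c {
                    common_prefix.push(c);
                } else {
                    diverged = true;
                    prev_suffix.push(p);
                    curr_suffix.push(c);
                }
            }
        }
    }
    (common_prefix, prev_suffix, curr_suffix)
}

fn main() {
    let (prefix, prev, curr) = split_histories(&[0, 1, 2, 3], &[0, 1, 3, 2]);
    assert_eq!(prefix, vec![0, 1]);
    assert_eq!(prev, vec![2, 3]);
    assert_eq!(curr, vec![3, 2]);
    // A non-empty divergence is what triggers "Detected non-determinism".
    assert!(!prev.is_empty() && !curr.is_empty());
}
// ---------------------------------------------------------------------------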
+ } Some("retries") => max_retries = parse_int(&meta.lit)?, Some("iterations") => num_iterations = parse_int(&meta.lit)?, Some("seed") => starting_seed = parse_int(&meta.lit)?, @@ -77,10 +81,6 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),)); continue; } - Some("bool") => { - inner_fn_args.extend(quote!(is_last_iteration,)); - continue; - } Some("Arc") => { if let syn::PathArguments::AngleBracketed(args) = &last_segment.unwrap().arguments @@ -146,7 +146,8 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { #num_iterations as u64, #starting_seed as u64, #max_retries, - &mut |cx, foreground_platform, deterministic, seed, is_last_iteration| { + #detect_nondeterminism, + &mut |cx, foreground_platform, deterministic, seed| { #cx_vars cx.foreground().run(#inner_fn_name(#inner_fn_args)); #cx_teardowns @@ -165,9 +166,6 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { Some("StdRng") => { inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),)); } - Some("bool") => { - inner_fn_args.extend(quote!(is_last_iteration,)); - } _ => {} } } else { @@ -189,7 +187,8 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { #num_iterations as u64, #starting_seed as u64, #max_retries, - &mut |cx, _, _, seed, is_last_iteration| #inner_fn_name(#inner_fn_args), + #detect_nondeterminism, + &mut |cx, _, _, seed| #inner_fn_name(#inner_fn_args), stringify!(#outer_fn_name).to_string(), ); } @@ -209,3 +208,13 @@ fn parse_int(literal: &Lit) -> Result { result.map_err(|err| TokenStream::from(err.into_compile_error())) } + +fn parse_bool(literal: &Lit) -> Result { + let result = if let Lit::Bool(result) = &literal { + Ok(result.value) + } else { + Err(syn::Error::new(literal.span(), "must be a boolean")) + }; + + result.map_err(|err| TokenStream::from(err.into_compile_error())) +} diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index e0cc3cdd0b4f77c8b306648a2df0d7f28f93c830..512ac702d062be924ac5183a9e98d34be80a45f7 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -10,7 +10,11 @@ use anyhow::{anyhow, Context, Result}; use client::{proto, Client, PeerId, TypedEnvelope, UserStore}; use clock::ReplicaId; use collections::{hash_map, BTreeMap, HashMap, HashSet}; -use futures::{future::Shared, AsyncWriteExt, Future, FutureExt, StreamExt, TryFutureExt}; +use futures::{ + channel::{mpsc, oneshot}, + future::Shared, + AsyncWriteExt, Future, FutureExt, StreamExt, TryFutureExt, +}; use gpui::{ AnyModelHandle, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, @@ -45,12 +49,10 @@ use std::{ cell::RefCell, cmp::{self, Ordering}, convert::TryInto, - ffi::OsString, hash::Hash, mem, num::NonZeroU32, ops::Range, - os::unix::{ffi::OsStrExt, prelude::OsStringExt}, path::{Component, Path, PathBuf}, rc::Rc, str, @@ -70,10 +72,6 @@ pub trait Item: Entity { fn entry_id(&self, cx: &AppContext) -> Option; } -pub struct ProjectStore { - projects: Vec>, -} - // Language server state is stored across 3 collections: // language_servers => // a mapping from unique server id to LanguageServerState which can either be a task for a @@ -102,7 +100,6 @@ pub struct Project { next_entry_id: Arc, next_diagnostic_group_id: usize, user_store: ModelHandle, - project_store: ModelHandle, fs: Arc, client_state: Option, collaborators: HashMap, @@ -152,6 +149,8 @@ enum WorktreeHandle { enum ProjectClientState { Local { 
remote_id: u64, + metadata_changed: mpsc::UnboundedSender>, + _maintain_metadata: Task<()>, _detect_unshare: Task>, }, Remote { @@ -412,46 +411,39 @@ impl Project { pub fn local( client: Arc, user_store: ModelHandle, - project_store: ModelHandle, languages: Arc, fs: Arc, cx: &mut MutableAppContext, ) -> ModelHandle { - cx.add_model(|cx: &mut ModelContext| { - let handle = cx.weak_handle(); - project_store.update(cx, |store, cx| store.add_project(handle, cx)); - - Self { - worktrees: Default::default(), - collaborators: Default::default(), - opened_buffers: Default::default(), - shared_buffers: Default::default(), - incomplete_buffers: Default::default(), - loading_buffers: Default::default(), - loading_local_worktrees: Default::default(), - buffer_snapshots: Default::default(), - client_state: None, - opened_buffer: watch::channel(), - client_subscriptions: Vec::new(), - _subscriptions: vec![cx.observe_global::(Self::on_settings_changed)], - _maintain_buffer_languages: Self::maintain_buffer_languages(&languages, cx), - active_entry: None, - languages, - client, - user_store, - project_store, - fs, - next_entry_id: Default::default(), - next_diagnostic_group_id: Default::default(), - language_servers: Default::default(), - language_server_ids: Default::default(), - language_server_statuses: Default::default(), - last_workspace_edits_by_language_server: Default::default(), - language_server_settings: Default::default(), - buffers_being_formatted: Default::default(), - next_language_server_id: 0, - nonce: StdRng::from_entropy().gen(), - } + cx.add_model(|cx: &mut ModelContext| Self { + worktrees: Default::default(), + collaborators: Default::default(), + opened_buffers: Default::default(), + shared_buffers: Default::default(), + incomplete_buffers: Default::default(), + loading_buffers: Default::default(), + loading_local_worktrees: Default::default(), + buffer_snapshots: Default::default(), + client_state: None, + opened_buffer: watch::channel(), + client_subscriptions: Vec::new(), + _subscriptions: vec![cx.observe_global::(Self::on_settings_changed)], + _maintain_buffer_languages: Self::maintain_buffer_languages(&languages, cx), + active_entry: None, + languages, + client, + user_store, + fs, + next_entry_id: Default::default(), + next_diagnostic_group_id: Default::default(), + language_servers: Default::default(), + language_server_ids: Default::default(), + language_server_statuses: Default::default(), + last_workspace_edits_by_language_server: Default::default(), + language_server_settings: Default::default(), + buffers_being_formatted: Default::default(), + next_language_server_id: 0, + nonce: StdRng::from_entropy().gen(), }) } @@ -459,31 +451,28 @@ impl Project { remote_id: u64, client: Arc, user_store: ModelHandle, - project_store: ModelHandle, languages: Arc, fs: Arc, mut cx: AsyncAppContext, ) -> Result, JoinProjectError> { client.authenticate_and_connect(true, &cx).await?; + let subscription = client.subscribe_to_entity(remote_id); let response = client .request(proto::JoinProject { project_id: remote_id, }) .await?; + let this = cx.add_model(|cx| { + let replica_id = response.replica_id as ReplicaId; - let replica_id = response.replica_id as ReplicaId; - - let mut worktrees = Vec::new(); - for worktree in response.worktrees { - let worktree = cx - .update(|cx| Worktree::remote(remote_id, replica_id, worktree, client.clone(), cx)); - worktrees.push(worktree); - } - - let this = cx.add_model(|cx: &mut ModelContext| { - let handle = cx.weak_handle(); - project_store.update(cx, 
|store, cx| store.add_project(handle, cx)); + let mut worktrees = Vec::new(); + for worktree in response.worktrees { + let worktree = cx.update(|cx| { + Worktree::remote(remote_id, replica_id, worktree, client.clone(), cx) + }); + worktrees.push(worktree); + } let mut this = Self { worktrees: Vec::new(), @@ -497,11 +486,10 @@ impl Project { _maintain_buffer_languages: Self::maintain_buffer_languages(&languages, cx), languages, user_store: user_store.clone(), - project_store, fs, next_entry_id: Default::default(), next_diagnostic_group_id: Default::default(), - client_subscriptions: vec![client.add_model_for_remote_entity(remote_id, cx)], + client_subscriptions: Default::default(), _subscriptions: Default::default(), client: client.clone(), client_state: Some(ProjectClientState::Remote { @@ -550,10 +538,11 @@ impl Project { nonce: StdRng::from_entropy().gen(), }; for worktree in worktrees { - this.add_worktree(&worktree, cx); + let _ = this.add_worktree(&worktree, cx); } this }); + let subscription = subscription.set_model(&this, &mut cx); let user_ids = response .collaborators @@ -571,6 +560,7 @@ impl Project { this.update(&mut cx, |this, _| { this.collaborators = collaborators; + this.client_subscriptions.push(subscription); }); Ok(this) @@ -593,9 +583,7 @@ impl Project { let http_client = client::test::FakeHttpClient::with_404_response(); let client = cx.update(|cx| client::Client::new(http_client.clone(), cx)); let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client, cx)); - let project_store = cx.add_model(|_| ProjectStore::new()); - let project = - cx.update(|cx| Project::local(client, user_store, project_store, languages, fs, cx)); + let project = cx.update(|cx| Project::local(client, user_store, languages, fs, cx)); for path in root_paths { let (tree, _) = project .update(cx, |project, cx| { @@ -676,10 +664,6 @@ impl Project { self.user_store.clone() } - pub fn project_store(&self) -> ModelHandle { - self.project_store.clone() - } - #[cfg(any(test, feature = "test-support"))] pub fn check_invariants(&self, cx: &AppContext) { if self.is_local() { @@ -751,53 +735,22 @@ impl Project { } } - fn metadata_changed(&mut self, cx: &mut ModelContext) { - if let Some(ProjectClientState::Local { remote_id, .. }) = &self.client_state { - let project_id = *remote_id; - // Broadcast worktrees only if the project is online. - let worktrees = self - .worktrees - .iter() - .filter_map(|worktree| { - worktree - .upgrade(cx) - .map(|worktree| worktree.read(cx).as_local().unwrap().metadata_proto()) - }) - .collect(); - self.client - .send(proto::UpdateProject { - project_id, - worktrees, - }) - .log_err(); - - let worktrees = self.visible_worktrees(cx).collect::>(); - let scans_complete = futures::future::join_all( - worktrees - .iter() - .filter_map(|worktree| Some(worktree.read(cx).as_local()?.scan_complete())), - ); - - let worktrees = worktrees.into_iter().map(|handle| handle.downgrade()); - - cx.spawn_weak(move |_, cx| async move { - scans_complete.await; - cx.read(|cx| { - for worktree in worktrees { - if let Some(worktree) = worktree - .upgrade(cx) - .and_then(|worktree| worktree.read(cx).as_local()) - { - worktree.send_extension_counts(project_id); - } - } - }) - }) - .detach(); + fn metadata_changed(&mut self, cx: &mut ModelContext) -> impl Future { + let (tx, rx) = oneshot::channel(); + if let Some(ProjectClientState::Local { + metadata_changed, .. 
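// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch; my reading of the API, not shown in the
// diff): `Project::remote` now calls `client.subscribe_to_entity(remote_id)`
// *before* sending `JoinProject`, and only binds the subscription to the model
// with `set_model` once the model exists, pushing it into
// `client_subscriptions` afterwards. The presumed point is that messages for
// the project can start arriving as soon as the server handles the join, so
// the subscription must exist first and the handler can be attached later.
// A toy model of that two-step pattern (all names hypothetical):
use std::{cell::RefCell, collections::HashMap, rc::Rc};

type Message = String;

#[derive(Default)]
struct Client {
    pending: HashMap<u64, Vec<Message>>,              // subscribed, not yet bound
    handlers: HashMap<u64, Box<dyn FnMut(Message)>>,  // bound to a live model
}

impl Client {
    // Step 1: reserve the entity id before the join request goes out.
    fn subscribe_to_entity(&mut self, id: u64) {
        self.pending.entry(id).or_default();
    }

    // Traffic that may arrive while the model is still being constructed.
    fn deliver(&mut self, id: u64, message: Message) {
        if let Some(handler) = self.handlers.get_mut(&id) {
            handler(message);
        } else if let Some(buffer) = self.pending.get_mut(&id) {
            buffer.push(message);
        }
    }

    // Step 2: bind the model and replay anything buffered in the meantime.
    fn set_model(&mut self, id: u64, mut handler: impl FnMut(Message) + 'static) {
        for message in self.pending.remove(&id).unwrap_or_default() {
            handler(message);
        }
        self.handlers.insert(id, Box::new(handler));
    }
}

fn main() {
    let mut client = Client::default();
    client.subscribe_to_entity(42);
    client.deliver(42, "update sent during join".into());

    let seen = Rc::new(RefCell::new(Vec::new()));
    let sink = seen.clone();
    client.set_model(42, move |message| sink.borrow_mut().push(message));
    client.deliver(42, "update sent after join".into());

    assert_eq!(
        *seen.borrow(),
        vec![
            "update sent during join".to_string(),
            "update sent after join".to_string()
        ]
    );
}
// ---------------------------------------------------------------------------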
+ }) = &mut self.client_state + { + let _ = metadata_changed.unbounded_send(tx); } - - self.project_store.update(cx, |_, cx| cx.notify()); cx.notify(); + + async move { + // If the project is shared, this will resolve when the `_maintain_metadata` task has + // a chance to update the metadata. Otherwise, it will resolve right away because `tx` + // will get dropped. + let _ = rx.await; + } } pub fn collaborators(&self) -> &HashMap { @@ -899,7 +852,7 @@ impl Project { .request(proto::CreateProjectEntry { worktree_id: project_path.worktree_id.to_proto(), project_id, - path: project_path.path.as_os_str().as_bytes().to_vec(), + path: project_path.path.to_string_lossy().into(), is_directory, }) .await?; @@ -943,7 +896,7 @@ impl Project { .request(proto::CopyProjectEntry { project_id, entry_id: entry_id.to_proto(), - new_path: new_path.as_os_str().as_bytes().to_vec(), + new_path: new_path.to_string_lossy().into(), }) .await?; let entry = response @@ -986,7 +939,7 @@ impl Project { .request(proto::RenameProjectEntry { project_id, entry_id: entry_id.to_proto(), - new_path: new_path.as_os_str().as_bytes().to_vec(), + new_path: new_path.to_string_lossy().into(), }) .await?; let entry = response @@ -1087,15 +1040,51 @@ impl Project { }); } - self.client_subscriptions - .push(self.client.add_model_for_remote_entity(project_id, cx)); - self.metadata_changed(cx); + self.client_subscriptions.push( + self.client + .subscribe_to_entity(project_id) + .set_model(&cx.handle(), &mut cx.to_async()), + ); + let _ = self.metadata_changed(cx); cx.emit(Event::RemoteIdChanged(Some(project_id))); cx.notify(); let mut status = self.client.status(); + let (metadata_changed_tx, mut metadata_changed_rx) = mpsc::unbounded(); self.client_state = Some(ProjectClientState::Local { remote_id: project_id, + metadata_changed: metadata_changed_tx, + _maintain_metadata: cx.spawn_weak(move |this, cx| async move { + while let Some(tx) = metadata_changed_rx.next().await { + let mut txs = vec![tx]; + while let Ok(Some(next_tx)) = metadata_changed_rx.try_next() { + txs.push(next_tx); + } + + let Some(this) = this.upgrade(&cx) else { break }; + this.read_with(&cx, |this, cx| { + let worktrees = this + .worktrees + .iter() + .filter_map(|worktree| { + worktree.upgrade(cx).map(|worktree| { + worktree.read(cx).as_local().unwrap().metadata_proto() + }) + }) + .collect(); + this.client.request(proto::UpdateProject { + project_id, + worktrees, + }) + }) + .await + .log_err(); + + for tx in txs { + let _ = tx.send(()); + } + } + }), _detect_unshare: cx.spawn_weak(move |this, mut cx| { async move { let is_connected = status.next().await.map_or(false, |s| s.is_connected()); @@ -1145,7 +1134,7 @@ impl Project { } } - self.metadata_changed(cx); + let _ = self.metadata_changed(cx); cx.notify(); self.client.send(proto::UnshareProject { project_id: remote_id, @@ -1634,10 +1623,6 @@ impl Project { operations: vec![language::proto::serialize_operation(operation)], }); cx.background().spawn(request).detach_and_log_err(cx); - } else if let Some(project_id) = self.remote_id() { - let _ = self - .client - .send(proto::RegisterProjectActivity { project_id }); } } BufferEvent::Edited { .. 
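// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch, assumes only the `futures` crate):
// `metadata_changed` no longer sends `proto::UpdateProject` inline. It pushes
// a oneshot sender into an unbounded channel and returns a future that
// resolves once the `_maintain_metadata` task has drained the channel, sent a
// single request covering the whole burst of changes, and fired every pending
// sender (or immediately, if the project is not shared and the sender is
// dropped). The coalescing pattern in isolation:
use futures::{
    channel::{mpsc, oneshot},
    executor::block_on,
    StreamExt,
};

fn main() {
    let (tx, mut rx) = mpsc::unbounded::<oneshot::Sender<()>>();

    // Producer side: three rapid-fire "metadata changed" notifications.
    let mut acks = Vec::new();
    for _ in 0..3 {
        let (ack_tx, ack_rx) = oneshot::channel();
        tx.unbounded_send(ack_tx).unwrap();
        acks.push(ack_rx);
    }
    drop(tx);

    // Consumer side: take one notification, drain whatever else is already
    // queued, do the expensive work once, then acknowledge every caller.
    let mut requests_sent = 0;
    block_on(async {
        while let Some(first) = rx.next().await {
            let mut batch = vec![first];
            while let Ok(Some(next)) = rx.try_next() {
                batch.push(next);
            }
            requests_sent += 1; // stand-in for one proto::UpdateProject request
            for ack in batch {
                let _ = ack.send(());
            }
        }
    });

    assert_eq!(requests_sent, 1);
    for ack in acks {
        // Every caller's returned future still resolves.
        assert!(block_on(ack).is_ok());
    }
}
// ---------------------------------------------------------------------------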
} => { @@ -3429,19 +3414,29 @@ impl Project { position: Some(language::proto::serialize_anchor(&anchor)), version: serialize_version(&source_buffer.version()), }; - cx.spawn_weak(|_, mut cx| async move { + cx.spawn_weak(|this, mut cx| async move { let response = rpc.request(message).await?; - source_buffer_handle - .update(&mut cx, |buffer, _| { - buffer.wait_for_version(deserialize_version(response.version)) - }) - .await; + if this + .upgrade(&cx) + .ok_or_else(|| anyhow!("project was dropped"))? + .read_with(&cx, |this, _| this.is_read_only()) + { + return Err(anyhow!( + "failed to get completions: project was disconnected" + )); + } else { + source_buffer_handle + .update(&mut cx, |buffer, _| { + buffer.wait_for_version(deserialize_version(response.version)) + }) + .await; - let completions = response.completions.into_iter().map(|completion| { - language::proto::deserialize_completion(completion, language.clone()) - }); - futures::future::try_join_all(completions).await + let completions = response.completions.into_iter().map(|completion| { + language::proto::deserialize_completion(completion, language.clone()) + }); + futures::future::try_join_all(completions).await + } }) } else { Task::ready(Ok(Default::default())) @@ -3618,7 +3613,7 @@ impl Project { } else if let Some(project_id) = self.remote_id() { let rpc = self.client.clone(); let version = buffer.version(); - cx.spawn_weak(|_, mut cx| async move { + cx.spawn_weak(|this, mut cx| async move { let response = rpc .request(proto::GetCodeActions { project_id, @@ -3629,17 +3624,27 @@ impl Project { }) .await?; - buffer_handle - .update(&mut cx, |buffer, _| { - buffer.wait_for_version(deserialize_version(response.version)) - }) - .await; + if this + .upgrade(&cx) + .ok_or_else(|| anyhow!("project was dropped"))? 
+ .read_with(&cx, |this, _| this.is_read_only()) + { + return Err(anyhow!( + "failed to get code actions: project was disconnected" + )); + } else { + buffer_handle + .update(&mut cx, |buffer, _| { + buffer.wait_for_version(deserialize_version(response.version)) + }) + .await; - response - .actions - .into_iter() - .map(language::proto::deserialize_code_action) - .collect() + response + .actions + .into_iter() + .map(language::proto::deserialize_code_action) + .collect() + } }) } else { Task::ready(Ok(Default::default())) @@ -4148,9 +4153,13 @@ impl Project { let message = request.to_proto(project_id, buffer); return cx.spawn(|this, cx| async move { let response = rpc.request(message).await?; - request - .response_from_proto(response, this, buffer_handle, cx) - .await + if this.read_with(&cx, |this, _| this.is_read_only()) { + Err(anyhow!("disconnected before completing request")) + } else { + request + .response_from_proto(response, this, buffer_handle, cx) + .await + } }); } Task::ready(Ok(Default::default())) @@ -4228,12 +4237,13 @@ impl Project { }); let worktree = worktree?; - let project_id = project.update(&mut cx, |project, cx| { - project.add_worktree(&worktree, cx); - project.remote_id() - }); + project + .update(&mut cx, |project, cx| project.add_worktree(&worktree, cx)) + .await; - if let Some(project_id) = project_id { + if let Some(project_id) = + project.read_with(&cx, |project, _| project.remote_id()) + { worktree .update(&mut cx, |worktree, cx| { worktree.as_local_mut().unwrap().share(project_id, cx) @@ -4257,7 +4267,11 @@ impl Project { }) } - pub fn remove_worktree(&mut self, id_to_remove: WorktreeId, cx: &mut ModelContext) { + pub fn remove_worktree( + &mut self, + id_to_remove: WorktreeId, + cx: &mut ModelContext, + ) -> impl Future { self.worktrees.retain(|worktree| { if let Some(worktree) = worktree.upgrade(cx) { let id = worktree.read(cx).id(); @@ -4271,11 +4285,14 @@ impl Project { false } }); - self.metadata_changed(cx); - cx.notify(); + self.metadata_changed(cx) } - fn add_worktree(&mut self, worktree: &ModelHandle, cx: &mut ModelContext) { + fn add_worktree( + &mut self, + worktree: &ModelHandle, + cx: &mut ModelContext, + ) -> impl Future { cx.observe(worktree, |_, _, cx| cx.notify()).detach(); if worktree.read(cx).is_local() { cx.subscribe(worktree, |this, worktree, event, cx| match event { @@ -4299,15 +4316,13 @@ impl Project { .push(WorktreeHandle::Weak(worktree.downgrade())); } - self.metadata_changed(cx); cx.observe_release(worktree, |this, worktree, cx| { - this.remove_worktree(worktree.id(), cx); - cx.notify(); + let _ = this.remove_worktree(worktree.id(), cx); }) .detach(); cx.emit(Event::WorktreeAdded); - cx.notify(); + self.metadata_changed(cx) } fn update_local_worktree_buffers( @@ -4624,11 +4639,11 @@ impl Project { } else { let worktree = Worktree::remote(remote_id, replica_id, worktree, client.clone(), cx); - this.add_worktree(&worktree, cx); + let _ = this.add_worktree(&worktree, cx); } } - this.metadata_changed(cx); + let _ = this.metadata_changed(cx); for (id, _) in old_worktrees_by_id { cx.emit(Event::WorktreeRemoved(id)); } @@ -4670,7 +4685,7 @@ impl Project { let entry = worktree .update(&mut cx, |worktree, cx| { let worktree = worktree.as_local_mut().unwrap(); - let path = PathBuf::from(OsString::from_vec(envelope.payload.path)); + let path = PathBuf::from(envelope.payload.path); worktree.create_entry(path, envelope.payload.is_directory, cx) }) .await?; @@ -4694,7 +4709,7 @@ impl Project { let worktree_scan_id = worktree.read_with(&cx, 
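// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch): the completion, code-action, and
// generic LSP request paths above all gain the same guard. After the RPC round
// trip completes, the project is checked again; if it became read-only
// (disconnected) while the request was in flight, the stale response is turned
// into an error instead of waiting on buffer versions that will never arrive.
// The shape of that guard, reduced to its essentials (names hypothetical):
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};

#[derive(Debug)]
struct Disconnected;

async fn request_with_guard(
    disconnected: Arc<AtomicBool>,
    request: impl std::future::Future<Output = String>,
) -> Result<String, Disconnected> {
    let response = request.await;
    // The await can take arbitrarily long; the host may have left meanwhile,
    // so the response must not be applied blindly.
    if disconnected.load(Ordering::SeqCst) {
        Err(Disconnected)
    } else {
        Ok(response)
    }
}

fn main() {
    let flag = Arc::new(AtomicBool::new(false));
    let ok = futures::executor::block_on(request_with_guard(flag.clone(), async {
        "completions".to_string()
    }));
    assert!(ok.is_ok());

    flag.store(true, Ordering::SeqCst);
    let stale = futures::executor::block_on(request_with_guard(flag, async {
        "completions".to_string()
    }));
    assert!(stale.is_err());
}
// ---------------------------------------------------------------------------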
|worktree, _| worktree.scan_id()); let entry = worktree .update(&mut cx, |worktree, cx| { - let new_path = PathBuf::from(OsString::from_vec(envelope.payload.new_path)); + let new_path = PathBuf::from(envelope.payload.new_path); worktree .as_local_mut() .unwrap() @@ -4722,7 +4737,7 @@ impl Project { let worktree_scan_id = worktree.read_with(&cx, |worktree, _| worktree.scan_id()); let entry = worktree .update(&mut cx, |worktree, cx| { - let new_path = PathBuf::from(OsString::from_vec(envelope.payload.new_path)); + let new_path = PathBuf::from(envelope.payload.new_path); worktree .as_local_mut() .unwrap() @@ -5864,48 +5879,6 @@ impl Project { } } -impl ProjectStore { - pub fn new() -> Self { - Self { - projects: Default::default(), - } - } - - pub fn projects<'a>( - &'a self, - cx: &'a AppContext, - ) -> impl 'a + Iterator> { - self.projects - .iter() - .filter_map(|project| project.upgrade(cx)) - } - - fn add_project(&mut self, project: WeakModelHandle, cx: &mut ModelContext) { - if let Err(ix) = self - .projects - .binary_search_by_key(&project.id(), WeakModelHandle::id) - { - self.projects.insert(ix, project); - } - cx.notify(); - } - - fn prune_projects(&mut self, cx: &mut ModelContext) { - let mut did_change = false; - self.projects.retain(|project| { - if project.is_upgradable(cx) { - true - } else { - did_change = true; - false - } - }); - if did_change { - cx.notify(); - } - } -} - impl WorktreeHandle { pub fn upgrade(&self, cx: &AppContext) -> Option> { match self { @@ -5984,16 +5957,10 @@ impl<'a> Iterator for PathMatchCandidateSetIter<'a> { } } -impl Entity for ProjectStore { - type Event = (); -} - impl Entity for Project { type Event = Event; - fn release(&mut self, cx: &mut gpui::MutableAppContext) { - self.project_store.update(cx, ProjectStore::prune_projects); - + fn release(&mut self, _: &mut gpui::MutableAppContext) { match &self.client_state { Some(ProjectClientState::Local { remote_id, .. 
}) => { self.client diff --git a/crates/project/src/project_tests.rs b/crates/project/src/project_tests.rs index dfb699fdbb55ece4e32e7c7a362709ff6ddf92ff..a36831857f0ffed548fabe2d57ddb5f28aaac5fb 100644 --- a/crates/project/src/project_tests.rs +++ b/crates/project/src/project_tests.rs @@ -2166,7 +2166,11 @@ async fn test_rescan_and_remote_updates( proto::WorktreeMetadata { id: initial_snapshot.id().to_proto(), root_name: initial_snapshot.root_name().into(), - abs_path: initial_snapshot.abs_path().as_os_str().as_bytes().to_vec(), + abs_path: initial_snapshot + .abs_path() + .as_os_str() + .to_string_lossy() + .into(), visible: true, }, rpc.clone(), diff --git a/crates/project/src/worktree.rs b/crates/project/src/worktree.rs index 3bab90d5e34f188b923337d84cadd6397e0e8017..4781e17541a936e7e58cd19fc324a974da918076 100644 --- a/crates/project/src/worktree.rs +++ b/crates/project/src/worktree.rs @@ -41,7 +41,6 @@ use std::{ future::Future, mem, ops::{Deref, DerefMut}, - os::unix::prelude::{OsStrExt, OsStringExt}, path::{Path, PathBuf}, sync::{atomic::AtomicUsize, Arc}, task::Poll, @@ -83,6 +82,7 @@ pub struct RemoteWorktree { replica_id: ReplicaId, diagnostic_summaries: TreeMap, visible: bool, + disconnected: bool, } #[derive(Clone)] @@ -168,7 +168,7 @@ enum ScanState { struct ShareState { project_id: u64, snapshots_tx: watch::Sender, - _maintain_remote_snapshot: Option>>, + _maintain_remote_snapshot: Task>, } pub enum Event { @@ -222,7 +222,7 @@ impl Worktree { let root_name = worktree.root_name.clone(); let visible = worktree.visible; - let abs_path = PathBuf::from(OsString::from_vec(worktree.abs_path)); + let abs_path = PathBuf::from(worktree.abs_path); let snapshot = Snapshot { id: WorktreeId(remote_id as usize), abs_path: Arc::from(abs_path.deref()), @@ -248,6 +248,7 @@ impl Worktree { client: client.clone(), diagnostic_summaries: Default::default(), visible, + disconnected: false, }) }); @@ -660,7 +661,7 @@ impl LocalWorktree { id: self.id().to_proto(), root_name: self.root_name().to_string(), visible: self.visible, - abs_path: self.abs_path().as_os_str().as_bytes().to_vec(), + abs_path: self.abs_path().as_os_str().to_string_lossy().into(), } } @@ -972,11 +973,10 @@ impl LocalWorktree { let _ = share_tx.send(Ok(())); } else { let (snapshots_tx, mut snapshots_rx) = watch::channel_with(self.snapshot()); - let rpc = self.client.clone(); let worktree_id = cx.model_id() as u64; for (path, summary) in self.diagnostic_summaries.iter() { - if let Err(e) = rpc.send(proto::UpdateDiagnosticSummary { + if let Err(e) = self.client.send(proto::UpdateDiagnosticSummary { project_id, worktree_id, summary: Some(summary.to_proto(&path.0)), @@ -986,15 +986,14 @@ impl LocalWorktree { } let maintain_remote_snapshot = cx.background().spawn({ - let rpc = rpc; - + let rpc = self.client.clone(); async move { let mut prev_snapshot = match snapshots_rx.recv().await { Some(snapshot) => { let update = proto::UpdateWorktree { project_id, worktree_id, - abs_path: snapshot.abs_path().as_os_str().as_bytes().to_vec(), + abs_path: snapshot.abs_path().to_string_lossy().into(), root_name: snapshot.root_name().to_string(), updated_entries: snapshot .entries_by_path @@ -1034,10 +1033,11 @@ impl LocalWorktree { } .log_err() }); + self.share = Some(ShareState { project_id, snapshots_tx, - _maintain_remote_snapshot: Some(maintain_remote_snapshot), + _maintain_remote_snapshot: maintain_remote_snapshot, }); } @@ -1055,25 +1055,6 @@ impl LocalWorktree { pub fn is_shared(&self) -> bool { self.share.is_some() } - - pub fn 
send_extension_counts(&self, project_id: u64) { - let mut extensions = Vec::new(); - let mut counts = Vec::new(); - - for (extension, count) in self.extension_counts() { - extensions.push(extension.to_string_lossy().to_string()); - counts.push(*count as u32); - } - - self.client - .send(proto::UpdateWorktreeExtensions { - project_id, - worktree_id: self.id().to_proto(), - extensions, - counts, - }) - .log_err(); - } } impl RemoteWorktree { @@ -1090,6 +1071,7 @@ impl RemoteWorktree { pub fn disconnected_from_host(&mut self) { self.updates_tx.take(); self.snapshot_subscriptions.clear(); + self.disconnected = true; } pub fn update_from_remote(&mut self, update: proto::UpdateWorktree) { @@ -1104,10 +1086,12 @@ impl RemoteWorktree { self.scan_id > scan_id || (self.scan_id == scan_id && self.is_complete) } - fn wait_for_snapshot(&mut self, scan_id: usize) -> impl Future { + fn wait_for_snapshot(&mut self, scan_id: usize) -> impl Future> { let (tx, rx) = oneshot::channel(); if self.observed_snapshot(scan_id) { let _ = tx.send(()); + } else if self.disconnected { + drop(tx); } else { match self .snapshot_subscriptions @@ -1118,7 +1102,8 @@ impl RemoteWorktree { } async move { - let _ = rx.await; + rx.await?; + Ok(()) } } @@ -1147,7 +1132,7 @@ impl RemoteWorktree { ) -> Task> { let wait_for_snapshot = self.wait_for_snapshot(scan_id); cx.spawn(|this, mut cx| async move { - wait_for_snapshot.await; + wait_for_snapshot.await?; this.update(&mut cx, |worktree, _| { let worktree = worktree.as_remote_mut().unwrap(); let mut snapshot = worktree.background_snapshot.lock(); @@ -1166,7 +1151,7 @@ impl RemoteWorktree { ) -> Task> { let wait_for_snapshot = self.wait_for_snapshot(scan_id); cx.spawn(|this, mut cx| async move { - wait_for_snapshot.await; + wait_for_snapshot.await?; this.update(&mut cx, |worktree, _| { let worktree = worktree.as_remote_mut().unwrap(); let mut snapshot = worktree.background_snapshot.lock(); @@ -1404,7 +1389,7 @@ impl LocalSnapshot { proto::UpdateWorktree { project_id, worktree_id: self.id().to_proto(), - abs_path: self.abs_path().as_os_str().as_bytes().to_vec(), + abs_path: self.abs_path().to_string_lossy().into(), root_name, updated_entries: self.entries_by_path.iter().map(Into::into).collect(), removed_entries: Default::default(), @@ -1472,7 +1457,7 @@ impl LocalSnapshot { proto::UpdateWorktree { project_id, worktree_id, - abs_path: self.abs_path().as_os_str().as_bytes().to_vec(), + abs_path: self.abs_path().to_string_lossy().into(), root_name: self.root_name().to_string(), updated_entries, removed_entries, @@ -2951,7 +2936,7 @@ impl<'a> From<&'a Entry> for proto::Entry { Self { id: entry.id.to_proto(), is_dir: entry.is_dir(), - path: entry.path.as_os_str().as_bytes().to_vec(), + path: entry.path.to_string_lossy().into(), inode: entry.inode, mtime: Some(entry.mtime.into()), is_symlink: entry.is_symlink, @@ -2969,14 +2954,10 @@ impl<'a> TryFrom<(&'a CharBag, proto::Entry)> for Entry { EntryKind::Dir } else { let mut char_bag = *root_char_bag; - char_bag.extend( - String::from_utf8_lossy(&entry.path) - .chars() - .map(|c| c.to_ascii_lowercase()), - ); + char_bag.extend(entry.path.chars().map(|c| c.to_ascii_lowercase())); EntryKind::File(char_bag) }; - let path: Arc = PathBuf::from(OsString::from_vec(entry.path)).into(); + let path: Arc = PathBuf::from(entry.path).into(); Ok(Entry { id: ProjectEntryId::from_proto(entry.id), kind, diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 
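// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch, `futures` crate only):
// `RemoteWorktree::wait_for_snapshot` now returns a Result. When the worktree
// is already disconnected it drops the oneshot sender without sending, and the
// awaiting side's `rx.await?` surfaces that as an error instead of hanging
// forever; the callers shown above simply propagate it with
// `wait_for_snapshot.await?`. The underlying oneshot behavior:
use futures::{channel::oneshot, executor::block_on};

fn main() {
    // Snapshot observed: sender fires, receiver resolves Ok.
    let (tx, rx) = oneshot::channel::<()>();
    let _ = tx.send(());
    assert!(block_on(rx).is_ok());

    // Disconnected: sender dropped without sending, receiver resolves to
    // Err(Canceled), which `rx.await?` turns into a failed task.
    let (tx, rx) = oneshot::channel::<()>();
    drop(tx);
    assert!(block_on(rx).is_err());
}
// ---------------------------------------------------------------------------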
6bfe7124c9ce936fdee3e11b697f3d3925a0cf22..f7c7bfd6bcc7b5c881fa7b5c8c533419db016d6a 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -48,9 +48,7 @@ message Envelope { OpenBufferForSymbolResponse open_buffer_for_symbol_response = 40; UpdateProject update_project = 41; - RegisterProjectActivity register_project_activity = 42; UpdateWorktree update_worktree = 43; - UpdateWorktreeExtensions update_worktree_extensions = 44; CreateProjectEntry create_project_entry = 45; RenameProjectEntry rename_project_entry = 46; @@ -158,14 +156,12 @@ message JoinRoomResponse { optional LiveKitConnectionInfo live_kit_connection_info = 2; } -message LeaveRoom { - uint64 id = 1; -} +message LeaveRoom {} message Room { uint64 id = 1; repeated Participant participants = 2; - repeated uint64 pending_participant_user_ids = 3; + repeated PendingParticipant pending_participants = 3; string live_kit_room = 4; } @@ -176,6 +172,12 @@ message Participant { ParticipantLocation location = 4; } +message PendingParticipant { + uint64 user_id = 1; + uint64 calling_user_id = 2; + optional uint64 initial_project_id = 3; +} + message ParticipantProject { uint64 id = 1; repeated string worktree_root_names = 2; @@ -199,13 +201,13 @@ message ParticipantLocation { message Call { uint64 room_id = 1; - uint64 recipient_user_id = 2; + uint64 called_user_id = 2; optional uint64 initial_project_id = 3; } message IncomingCall { uint64 room_id = 1; - uint64 caller_user_id = 2; + uint64 calling_user_id = 2; repeated uint64 participant_user_ids = 3; optional ParticipantProject initial_project = 4; } @@ -214,7 +216,7 @@ message CallCanceled {} message CancelCall { uint64 room_id = 1; - uint64 recipient_user_id = 2; + uint64 called_user_id = 2; } message DeclineCall { @@ -253,10 +255,6 @@ message UpdateProject { repeated WorktreeMetadata worktrees = 2; } -message RegisterProjectActivity { - uint64 project_id = 1; -} - message JoinProject { uint64 project_id = 1; } @@ -280,33 +278,26 @@ message UpdateWorktree { repeated uint64 removed_entries = 5; uint64 scan_id = 6; bool is_last_update = 7; - bytes abs_path = 8; -} - -message UpdateWorktreeExtensions { - uint64 project_id = 1; - uint64 worktree_id = 2; - repeated string extensions = 3; - repeated uint32 counts = 4; + string abs_path = 8; } message CreateProjectEntry { uint64 project_id = 1; uint64 worktree_id = 2; - bytes path = 3; + string path = 3; bool is_directory = 4; } message RenameProjectEntry { uint64 project_id = 1; uint64 entry_id = 2; - bytes new_path = 3; + string new_path = 3; } message CopyProjectEntry { uint64 project_id = 1; uint64 entry_id = 2; - bytes new_path = 3; + string new_path = 3; } message DeleteProjectEntry { @@ -894,7 +885,7 @@ message File { message Entry { uint64 id = 1; bool is_dir = 2; - bytes path = 3; + string path = 3; uint64 inode = 4; Timestamp mtime = 5; bool is_symlink = 6; @@ -1078,7 +1069,7 @@ message WorktreeMetadata { uint64 id = 1; string root_name = 2; bool visible = 3; - bytes abs_path = 4; + string abs_path = 4; } message UpdateDiffBase { diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 4dbade4fec7969164ba80eb13dc3592cfb1c1bda..66ba6a40292d87d13a83943b0e40239ee37b526d 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -24,7 +24,7 @@ use std::{ }; use tracing::instrument; -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)] +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)] pub struct ConnectionId(pub u32); impl fmt::Display 
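// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch): the zed.proto change above switches the
// path fields (`abs_path`, `path`, `new_path`) from `bytes` to `string`, and
// the Rust side drops the Unix-only `OsStrExt::as_bytes` /
// `OsString::from_vec` round trip in favor of `to_string_lossy()` when
// encoding and `PathBuf::from(String)` when decoding. Non-UTF-8 path bytes are
// now lossily replaced rather than preserved; that is the portability
// trade-off this patch accepts. In isolation (hypothetical helper names):
use std::path::{Path, PathBuf};

fn encode_path(path: &Path) -> String {
    // Sent over the wire as a protobuf `string`.
    path.to_string_lossy().into()
}

fn decode_path(wire: String) -> PathBuf {
    PathBuf::from(wire)
}

fn main() {
    let original = Path::new("crates/project/src/worktree.rs");
    let wire = encode_path(original);
    assert_eq!(decode_path(wire), original.to_path_buf());
}
// ---------------------------------------------------------------------------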
for ConnectionId { diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 11bbaaf5ffcbdeff96906033b1cacae6f62e48f0..6d9bc9a0aa348af8c1a14f442323fcf06064688e 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -140,12 +140,11 @@ messages!( (OpenBufferResponse, Background), (PerformRename, Background), (PerformRenameResponse, Background), + (Ping, Foreground), (PrepareRename, Background), (PrepareRenameResponse, Background), (ProjectEntryResponse, Foreground), (RemoveContact, Foreground), - (Ping, Foreground), - (RegisterProjectActivity, Foreground), (ReloadBuffers, Foreground), (ReloadBuffersResponse, Foreground), (RemoveProjectCollaborator, Foreground), @@ -175,7 +174,6 @@ messages!( (UpdateParticipantLocation, Foreground), (UpdateProject, Foreground), (UpdateWorktree, Foreground), - (UpdateWorktreeExtensions, Background), (UpdateDiffBase, Background), (GetPrivateUserInfo, Foreground), (GetPrivateUserInfoResponse, Foreground), @@ -231,6 +229,7 @@ request_messages!( (Test, Test), (UpdateBuffer, Ack), (UpdateParticipantLocation, Ack), + (UpdateProject, Ack), (UpdateWorktree, Ack), ); @@ -262,7 +261,6 @@ entity_messages!( OpenBufferForSymbol, PerformRename, PrepareRename, - RegisterProjectActivity, ReloadBuffers, RemoveProjectCollaborator, RenameProjectEntry, @@ -278,7 +276,6 @@ entity_messages!( UpdateLanguageServer, UpdateProject, UpdateWorktree, - UpdateWorktreeExtensions, UpdateDiffBase ); diff --git a/crates/rpc/src/rpc.rs b/crates/rpc/src/rpc.rs index b6aef64677b6f06716a6ea40d9b52a42017c3543..5ca5711d9ca8c43cd5f1979ee76ea11e61053bec 100644 --- a/crates/rpc/src/rpc.rs +++ b/crates/rpc/src/rpc.rs @@ -6,4 +6,4 @@ pub use conn::Connection; pub use peer::*; mod macros; -pub const PROTOCOL_VERSION: u32 = 39; +pub const PROTOCOL_VERSION: u32 = 40; diff --git a/crates/sqlez/Cargo.toml b/crates/sqlez/Cargo.toml index c6c018b9244aa38444a6ef320aa6facd7d13f046..78bf83dc303cbda505347aefea9a7c7c7875e822 100644 --- a/crates/sqlez/Cargo.toml +++ b/crates/sqlez/Cargo.toml @@ -8,7 +8,7 @@ edition = "2021" [dependencies] anyhow = { version = "1.0.38", features = ["backtrace"] } indoc = "1.0.7" -libsqlite3-sys = { version = "0.25.2", features = ["bundled"] } +libsqlite3-sys = { version = "0.24", features = ["bundled"] } smol = "1.2" thread_local = "1.1.4" lazy_static = "1.4" diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 2327c517632965427fe7ca824f98262f8804e528..a0c353b3f808bf1f1a5c9a9909f2047139916449 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -53,7 +53,7 @@ pub use persistence::{ WorkspaceDb, }; use postage::prelude::Stream; -use project::{Project, ProjectEntryId, ProjectPath, ProjectStore, Worktree, WorktreeId}; +use project::{Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId}; use serde::Deserialize; use settings::{Autosave, DockAnchor, Settings}; use shared_screen::SharedScreen; @@ -372,7 +372,6 @@ pub struct AppState { pub themes: Arc, pub client: Arc, pub user_store: ModelHandle, - pub project_store: ModelHandle, pub fs: Arc, pub build_window_options: fn() -> WindowOptions<'static>, pub initialize_workspace: fn(&mut Workspace, &Arc, &mut ViewContext), @@ -392,7 +391,6 @@ impl AppState { let languages = Arc::new(LanguageRegistry::test()); let http_client = client::test::FakeHttpClient::with_404_response(); let client = Client::new(http_client.clone(), cx); - let project_store = cx.add_model(|_| ProjectStore::new()); let user_store = cx.add_model(|cx| 
UserStore::new(client.clone(), http_client, cx)); let themes = ThemeRegistry::new((), cx.font_cache().clone()); Arc::new(Self { @@ -401,7 +399,6 @@ impl AppState { fs, languages, user_store, - project_store, initialize_workspace: |_, _, _| {}, build_window_options: Default::default, default_item_factory: |_, _| unimplemented!(), @@ -663,7 +660,6 @@ impl Workspace { let project_handle = Project::local( app_state.client.clone(), app_state.user_store.clone(), - app_state.project_store.clone(), app_state.languages.clone(), app_state.fs.clone(), cx, @@ -1035,8 +1031,10 @@ impl Workspace { RemoveWorktreeFromProject(worktree_id): &RemoveWorktreeFromProject, cx: &mut ViewContext, ) { - self.project + let future = self + .project .update(cx, |project, cx| project.remove_worktree(*worktree_id, cx)); + cx.foreground().spawn(future).detach(); } fn project_path_for_path( @@ -2866,9 +2864,9 @@ mod tests { ); // Remove a project folder - project.update(cx, |project, cx| { - project.remove_worktree(worktree_id, cx); - }); + project + .update(cx, |project, cx| project.remove_worktree(worktree_id, cx)) + .await; assert_eq!( cx.current_window_title(window_id).as_deref(), Some("one.txt — root2") diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 97a19b6d86f2aac54dc2520c22ed3fbff90f06b9..4163841d455d3e4e30620a28b80a0e5fdba16ca9 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -23,7 +23,7 @@ use isahc::{config::Configurable, Request}; use language::LanguageRegistry; use log::LevelFilter; use parking_lot::Mutex; -use project::{Fs, HomeDir, ProjectStore}; +use project::{Fs, HomeDir}; use serde_json::json; use settings::{ self, settings_file::SettingsFile, KeymapFileContent, Settings, SettingsFileContent, @@ -139,8 +139,6 @@ fn main() { }) .detach(); - let project_store = cx.add_model(|_| ProjectStore::new()); - client.start_telemetry(); client.report_event("start app", Default::default()); @@ -149,7 +147,6 @@ fn main() { themes, client: client.clone(), user_store, - project_store, fs, build_window_options, initialize_workspace,
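// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch): because `Project::remove_worktree` and
// `add_worktree` now return the future produced by `metadata_changed`, callers
// pick one of the two styles visible above: UI code spawns and detaches the
// future (`cx.foreground().spawn(future).detach()`), while tests await it so
// assertions only run after the metadata update has gone out. In plain
// std/futures terms (stub name hypothetical):
use futures::executor::block_on;
use std::future::Future;

fn remove_worktree_stub() -> impl Future<Output = ()> {
    async {
        // Stand-in for "notify the metadata task and wait for its ack".
    }
}

fn main() {
    // Test style: block until the removal has fully propagated.
    block_on(remove_worktree_stub());

    // UI style would instead hand the same future to an executor and detach
    // it, keeping the event handler itself synchronous.
}
// ---------------------------------------------------------------------------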